Unverified commit 4efa3320, authored by Da Zheng and committed by GitHub
Browse files

[Distributed] Move server start code to initialize. (#2002)

* move server start code to initialize.

* fix.

* fix lint.

* fix examples.

* add more checks.
parent af990989
...@@ -260,10 +260,9 @@ def run(args, device, data): ...@@ -260,10 +260,9 @@ def run(args, device, data):
print(profiler.output_text(unicode=True, color=True)) print(profiler.output_text(unicode=True, color=True))
def main(args): def main(args):
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
if not args.standalone: if not args.standalone:
th.distributed.init_process_group(backend='gloo') th.distributed.init_process_group(backend='gloo')
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config) g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
print('rank:', g.rank()) print('rank:', g.rank())
......
...@@ -418,9 +418,9 @@ def run(args, device, data): ...@@ -418,9 +418,9 @@ def run(args, device, data):
th.save(pred, 'emb.pt') th.save(pred, 'emb.pt')
def main(args): def main(args):
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
if not args.standalone: if not args.standalone:
th.distributed.init_process_group(backend='gloo') th.distributed.init_process_group(backend='gloo')
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config) g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
print('rank:', g.rank()) print('rank:', g.rank())
print('number of edges', g.number_of_edges()) print('number of edges', g.number_of_edges())
...@@ -452,7 +452,7 @@ if __name__ == '__main__': ...@@ -452,7 +452,7 @@ if __name__ == '__main__':
parser.add_argument('--id', type=int, help='the partition id') parser.add_argument('--id', type=int, help='the partition id')
parser.add_argument('--ip_config', type=str, help='The file for IP configuration') parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
parser.add_argument('--part_config', type=str, help='The path to the partition config file') parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num-servers', type=int, help='Server count on each machine.') parser.add_argument('--num-servers', type=int, default=1, help='Server count on each machine.')
parser.add_argument('--n-classes', type=int, help='the number of classes') parser.add_argument('--n-classes', type=int, help='the number of classes')
parser.add_argument('--gpu', type=int, default=0, parser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training") help="GPU device ID. Use -1 for CPU training")
......
...@@ -15,22 +15,3 @@ from .kvstore import KVServer, KVClient ...@@ -15,22 +15,3 @@ from .kvstore import KVServer, KVClient
from .server_state import ServerState from .server_state import ServerState
from .dist_dataloader import DistDataLoader from .dist_dataloader import DistDataLoader
from .graph_services import sample_neighbors, in_subgraph, find_edges from .graph_services import sample_neighbors, in_subgraph, find_edges
if os.environ.get('DGL_ROLE', 'client') == 'server':
assert os.environ.get('DGL_SERVER_ID') is not None, \
'Please define DGL_SERVER_ID to run DistGraph server'
assert os.environ.get('DGL_IP_CONFIG') is not None, \
'Please define DGL_IP_CONFIG to run DistGraph server'
assert os.environ.get('DGL_NUM_SERVER') is not None, \
'Please define DGL_NUM_SERVER to run DistGraph server'
assert os.environ.get('DGL_NUM_CLIENT') is not None, \
'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server'
SERV = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'))
SERV.start()
sys.exit()
...@@ -5,6 +5,7 @@ import traceback ...@@ -5,6 +5,7 @@ import traceback
import atexit import atexit
import time import time
import os import os
import sys
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE from .constants import MAX_QUEUE_SIZE
...@@ -63,22 +64,44 @@ def initialize(ip_config, num_servers=1, num_workers=0, ...@@ -63,22 +64,44 @@ def initialize(ip_config, num_servers=1, num_workers=0,
num_worker_threads: int num_worker_threads: int
The number of threads in a worker process. The number of threads in a worker process.
""" """
rpc.reset() if os.environ.get('DGL_ROLE', 'client') == 'server':
ctx = mp.get_context("spawn") from .dist_graph import DistGraphServer
global SAMPLER_POOL assert os.environ.get('DGL_SERVER_ID') is not None, \
global NUM_SAMPLER_WORKERS 'Please define DGL_SERVER_ID to run DistGraph server'
is_standalone = os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone' assert os.environ.get('DGL_IP_CONFIG') is not None, \
if num_workers > 0 and not is_standalone: 'Please define DGL_IP_CONFIG to run DistGraph server'
SAMPLER_POOL = ctx.Pool(num_workers, initializer=_init_rpc, assert os.environ.get('DGL_NUM_SERVER') is not None, \
initargs=(ip_config, num_servers, max_queue_size, 'Please define DGL_NUM_SERVER to run DistGraph server'
net_type, 'sampler', num_worker_threads)) assert os.environ.get('DGL_NUM_CLIENT') is not None, \
'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server'
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'))
serv.start()
sys.exit()
else: else:
SAMPLER_POOL = None rpc.reset()
NUM_SAMPLER_WORKERS = num_workers ctx = mp.get_context("spawn")
if not is_standalone: global SAMPLER_POOL
connect_to_server(ip_config, num_servers, max_queue_size, net_type) global NUM_SAMPLER_WORKERS
init_role('default') is_standalone = os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone'
init_kvstore(ip_config, num_servers, 'default') if num_workers > 0 and not is_standalone:
SAMPLER_POOL = ctx.Pool(num_workers, initializer=_init_rpc,
initargs=(ip_config, num_servers, max_queue_size,
net_type, 'sampler', num_worker_threads))
else:
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = num_workers
if not is_standalone:
assert num_servers is not None and num_servers > 0, \
'The number of servers per machine must be specified with a positive number.'
connect_to_server(ip_config, num_servers, max_queue_size, net_type)
init_role('default')
init_kvstore(ip_config, num_servers, 'default')
def finalize_client(): def finalize_client():
"""Release resources of this client.""" """Release resources of this client."""
......
...@@ -125,8 +125,18 @@ def main(): ...@@ -125,8 +125,18 @@ def main():
the same machine. By default, it is 1.') the same machine. By default, it is 1.')
args, udf_command = parser.parse_known_args() args, udf_command = parser.parse_known_args()
assert len(udf_command) == 1, 'Please provide user command line.' assert len(udf_command) == 1, 'Please provide user command line.'
assert args.num_trainers > 0, '--num_trainers must be a positive number.' assert args.num_trainers is not None and args.num_trainers > 0, \
assert args.num_samplers >= 0 '--num_trainers must be a positive number.'
assert args.num_samplers is not None and args.num_samplers >= 0, \
'--num_samplers must be a non-negative number.'
assert args.num_servers is not None and args.num_servers > 0, \
'--num_servers must be a positive number.'
assert args.num_server_threads > 0, '--num_server_threads must be a positive number.'
assert args.workspace is not None, 'A user has to specify a workspace with --workspace.'
assert args.part_config is not None, \
'A user has to specify a partition configuration file with --part_config.'
assert args.ip_config is not None, \
'A user has to specify an IP configuration file with --ip_config.'
udf_command = str(udf_command[0]) udf_command = str(udf_command[0])
if 'python' not in udf_command: if 'python' not in udf_command:
raise RuntimeError("DGL launching script can only support Python executable file.") raise RuntimeError("DGL launching script can only support Python executable file.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment