Commit 4efa3320 (unverified), authored by Da Zheng, committed by GitHub
Browse files

[Distributed] Move server start code to initialize. (#2002)

* move server start code to initialize.

* fix.

* fix lint.

* fix examples.

* add more checks.
parent af990989
......@@ -260,10 +260,9 @@ def run(args, device, data):
print(profiler.output_text(unicode=True, color=True))
def main(args):
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
if not args.standalone:
th.distributed.init_process_group(backend='gloo')
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
print('rank:', g.rank())
......
......@@ -418,9 +418,9 @@ def run(args, device, data):
th.save(pred, 'emb.pt')
def main(args):
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
if not args.standalone:
th.distributed.init_process_group(backend='gloo')
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
print('rank:', g.rank())
print('number of edges', g.number_of_edges())
......@@ -452,7 +452,7 @@ if __name__ == '__main__':
parser.add_argument('--id', type=int, help='the partition id')
parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num-servers', type=int, help='Server count on each machine.')
parser.add_argument('--num-servers', type=int, default=1, help='Server count on each machine.')
parser.add_argument('--n-classes', type=int, help='the number of classes')
parser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training")
......
......@@ -15,22 +15,3 @@ from .kvstore import KVServer, KVClient
from .server_state import ServerState
from .dist_dataloader import DistDataLoader
from .graph_services import sample_neighbors, in_subgraph, find_edges
if os.environ.get('DGL_ROLE', 'client') == 'server':
assert os.environ.get('DGL_SERVER_ID') is not None, \
'Please define DGL_SERVER_ID to run DistGraph server'
assert os.environ.get('DGL_IP_CONFIG') is not None, \
'Please define DGL_IP_CONFIG to run DistGraph server'
assert os.environ.get('DGL_NUM_SERVER') is not None, \
'Please define DGL_NUM_SERVER to run DistGraph server'
assert os.environ.get('DGL_NUM_CLIENT') is not None, \
'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server'
SERV = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'))
SERV.start()
sys.exit()
......@@ -5,6 +5,7 @@ import traceback
import atexit
import time
import os
import sys
from . import rpc
from .constants import MAX_QUEUE_SIZE
......@@ -63,6 +64,26 @@ def initialize(ip_config, num_servers=1, num_workers=0,
num_worker_threads: int
The number of threads in a worker process.
"""
if os.environ.get('DGL_ROLE', 'client') == 'server':
from .dist_graph import DistGraphServer
assert os.environ.get('DGL_SERVER_ID') is not None, \
'Please define DGL_SERVER_ID to run DistGraph server'
assert os.environ.get('DGL_IP_CONFIG') is not None, \
'Please define DGL_IP_CONFIG to run DistGraph server'
assert os.environ.get('DGL_NUM_SERVER') is not None, \
'Please define DGL_NUM_SERVER to run DistGraph server'
assert os.environ.get('DGL_NUM_CLIENT') is not None, \
'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server'
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'))
serv.start()
sys.exit()
else:
rpc.reset()
ctx = mp.get_context("spawn")
global SAMPLER_POOL
......@@ -76,6 +97,8 @@ def initialize(ip_config, num_servers=1, num_workers=0,
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = num_workers
if not is_standalone:
assert num_servers is not None and num_servers > 0, \
'The number of servers per machine must be specified with a positive number.'
connect_to_server(ip_config, num_servers, max_queue_size, net_type)
init_role('default')
init_kvstore(ip_config, num_servers, 'default')
......
......@@ -125,8 +125,18 @@ def main():
the same machine. By default, it is 1.')
args, udf_command = parser.parse_known_args()
assert len(udf_command) == 1, 'Please provide user command line.'
assert args.num_trainers > 0, '--num_trainers must be a positive number.'
assert args.num_samplers >= 0
assert args.num_trainers is not None and args.num_trainers > 0, \
'--num_trainers must be a positive number.'
assert args.num_samplers is not None and args.num_samplers >= 0, \
'--num_samplers must be a non-negative number.'
assert args.num_servers is not None and args.num_servers > 0, \
'--num_servers must be a positive number.'
assert args.num_server_threads > 0, '--num_server_threads must be a positive number.'
assert args.workspace is not None, 'A user has to specify a workspace with --workspace.'
assert args.part_config is not None, \
'A user has to specify a partition configuration file with --part_config.'
assert args.ip_config is not None, \
'A user has to specify an IP configuration file with --ip_config.'
udf_command = str(udf_command[0])
if 'python' not in udf_command:
raise RuntimeError("DGL launching script can only support Python executable file.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment