Unverified Commit 8292bf32 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] remove deprecated arguments from initialize() (#4284)

parent 6e1be69a
......@@ -173,28 +173,19 @@ class CustomPool:
self.process_list[i].join()
def initialize(ip_config, num_servers=1, num_workers=0,
max_queue_size=MAX_QUEUE_SIZE, net_type='socket',
num_worker_threads=1):
def initialize(ip_config, max_queue_size=MAX_QUEUE_SIZE,
net_type='socket', num_worker_threads=1):
"""Initialize DGL's distributed module
This function initializes DGL's distributed module. It acts differently in server
or client modes. In the server mode, it runs the server code and never returns.
In the client mode, it builds connections with servers for communication and
creates worker processes for distributed sampling. `num_workers` specifies
the number of sampling worker processes per trainer process.
Users also have to provide the number of server processes on each machine in order
to connect to all the server processes in the cluster of machines correctly.
creates worker processes for distributed sampling.
Parameters
----------
ip_config: str
File path of ip_config file
num_servers : int
The number of server processes on each machine. This argument is deprecated in DGL 0.7.0.
num_workers: int
Number of worker process on each machine. The worker processes are used
for distributed sampling. This argument is deprecated in DGL 0.7.0.
max_queue_size : int
Maximal size (bytes) of client queue buffer (~20 GB on default).
......@@ -205,7 +196,7 @@ def initialize(ip_config, num_servers=1, num_workers=0,
Default: ``'socket'``
num_worker_threads: int
The number of threads in a worker process.
The number of OMP threads in each sampler process.
Note
----
......@@ -240,14 +231,8 @@ def initialize(ip_config, num_servers=1, num_workers=0,
serv.start()
sys.exit()
else:
if os.environ.get('DGL_NUM_SAMPLER') is not None:
num_workers = int(os.environ.get('DGL_NUM_SAMPLER'))
else:
num_workers = 0
if os.environ.get('DGL_NUM_SERVER') is not None:
num_servers = int(os.environ.get('DGL_NUM_SERVER'))
else:
num_servers = 1
num_workers = int(os.environ.get('DGL_NUM_SAMPLER', 0))
num_servers = int(os.environ.get('DGL_NUM_SERVER', 1))
group_id = int(os.environ.get('DGL_GROUP_ID', 0))
rpc.reset()
global SAMPLER_POOL
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment