Unverified Commit bfbbefa7 authored by Da Zheng, committed by GitHub

[Distributed] Fix a bug in multiprocessing sampling. (#2826)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-73-81.ec2.internal>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 48b9ecd2
@@ -113,8 +113,6 @@ DGL provides a script to launch the training job in the cluster. `part_config` a
specify relative paths to the path of the workspace.
The command below launches one training process on each machine and each training process has 4 sampling processes.
-**Note**: There is a known bug in Python 3.8. The training process hangs when running multiple sampling processes for each training process.
-Please set the number of sampling processes to 0 if you are using Python 3.8.
```bash
python3 ~/workspace/dgl/tools/launch.py \
......
@@ -109,8 +109,6 @@ DGL provides a script to launch the training job in the cluster. `part_config` a
specify relative paths to the path of the workspace.
The command below launches one training process on each machine and each training process has 4 sampling processes.
-**Note**: There is a known bug in Python 3.8. The training process hangs when running multiple sampling processes for each training process.
-Please set the number of sampling processes to 0 if you are using Python 3.8.
```bash
python3 ~/workspace/dgl/tools/launch.py \
......
@@ -118,9 +118,8 @@ class DistDataLoader:
self.collate_fn = collate_fn
self.current_pos = 0
if self.pool is not None:
-            self.m = mp.Manager()
-            self.barrier = self.m.Barrier(self.num_workers)
-            self.queue = self.m.Queue(maxsize=queue_size)
+            m = mp.Manager()
+            self.queue = m.Queue(maxsize=queue_size)
else:
self.queue = Queue(maxsize=queue_size)
self.drop_last = drop_last
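
A side note on the hunk above: when a sampler pool is in use, the queue now comes from a `multiprocessing.Manager` rather than being a plain `multiprocessing.Queue`. Below is a minimal standalone sketch of that pattern; the `produce` worker and the queue size are made up for illustration and are not part of DGL. A manager-backed queue is a picklable proxy, so it can be handed to pool workers through `apply_async`, while a bare `mp.Queue` can only be shared by inheritance.

```python
import multiprocessing as mp

def produce(q):
    # Hypothetical worker: push one item into the shared queue.
    q.put('minibatch')

if __name__ == '__main__':
    pool = mp.Pool(2)
    m = mp.Manager()
    # A Manager queue is a proxy object that can travel through
    # apply_async; passing a plain mp.Queue here would raise a
    # RuntimeError about sharing queues only through inheritance.
    queue = m.Queue(maxsize=8)
    pool.apply_async(produce, args=(queue,)).get()
    print(queue.get())  # -> 'minibatch'
    pool.close()
    pool.join()
```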
@@ -141,9 +140,10 @@ class DistDataLoader:
if self.pool is not None:
results = []
+            barrier = m.Barrier(self.num_workers)
for _ in range(self.num_workers):
results.append(self.pool.apply_async(
-                    init_fn, args=(self.barrier, self.name, self.collate_fn, self.queue)))
+                    init_fn, args=(barrier, self.name, self.collate_fn, self.queue)))
for res in results:
res.get()
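
The core of the fix is visible above and in the next hunk: the barrier is no longer cached as `self.barrier` but built from a local manager right before the `init_fn` calls, and built again in `__del__` before `cleanup_fn`, since nothing is kept on the loader. A rough standalone sketch of that synchronization pattern follows; the `init_fn` body and the loader name are illustrative stand-ins, not DGL's.

```python
import multiprocessing as mp

def init_fn(barrier, name):
    # Hypothetical stand-in for DGL's init_fn: each sampler process sets
    # up its state for `name`, then waits so that no worker starts
    # serving requests before all of them are initialized.
    barrier.wait()

if __name__ == '__main__':
    num_workers = 4
    pool = mp.Pool(num_workers)
    m = mp.Manager()
    # The barrier is a Manager proxy created just before the calls
    # instead of an attribute on the dataloader, so it can be pickled
    # into the pool tasks and discarded once initialization finishes.
    barrier = m.Barrier(num_workers)
    results = [pool.apply_async(init_fn, args=(barrier, 'loader-0'))
               for _ in range(num_workers)]
    for res in results:
        res.get()  # returns only after every worker passed the barrier
    pool.close()
    pool.join()
```

Because each blocked task occupies one pool worker, exactly `num_workers` tasks with a `num_workers`-party barrier release together. The same pattern reappears in the `__del__` hunk below, where the manager and barrier are recreated on the spot because they are no longer stored on `self`.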
@@ -153,8 +153,11 @@ class DistDataLoader:
self.pool, self.num_workers = get_sampler_pool()
if self.pool is not None:
results = []
+            # Here we need to create the manager and barrier again.
+            m = mp.Manager()
+            barrier = m.Barrier(self.num_workers)
for _ in range(self.num_workers):
-            results.append(self.pool.apply_async(cleanup_fn, args=(self.barrier, self.name,)))
+            results.append(self.pool.apply_async(cleanup_fn, args=(barrier, self.name,)))
for res in results:
res.get()
......
@@ -151,11 +151,6 @@ def main():
udf_command = str(udf_command[0])
if 'python' not in udf_command:
raise RuntimeError("DGL launching script can only support Python executable file.")
-    if sys.version_info.major and sys.version_info.minor >= 8:
-        if args.num_samplers > 0:
-            print('WARNING! DGL does not support multiple sampler processes in Python>=3.8. '
-                  + 'Set the number of sampler processes to 0.')
-            args.num_samplers = 0
submit_jobs(args, udf_command)
def signal_handler(signal, frame):
......
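
One aside on the guard removed from `launch.py` above: since `sys.version_info.major` is always non-zero, the deleted condition effectively tested only the minor version. If a version gate were ever needed again, the usual idiom is a tuple comparison; a minimal sketch (the body is a placeholder):

```python
import sys

# Compare the whole version tuple instead of the minor field alone,
# so e.g. a hypothetical Python 4.0 would still count as >= 3.8.
if sys.version_info >= (3, 8):
    pass  # apply a 3.8-specific workaround here
```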