"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9ebaea545ffddf2e9079994f2ea657a7fa5f358c"
Unverified commit bfbbefa7, authored by Da Zheng, committed by GitHub
Browse files

[Distributed] Fix a bug in multiprocessing sampling. (#2826)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-73-81.ec2.internal>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 48b9ecd2
...@@ -113,8 +113,6 @@ DGL provides a script to launch the training job in the cluster. `part_config` a ...@@ -113,8 +113,6 @@ DGL provides a script to launch the training job in the cluster. `part_config` a
specify relative paths to the path of the workspace. specify relative paths to the path of the workspace.
The command below launches one training process on each machine and each training process has 4 sampling processes. The command below launches one training process on each machine and each training process has 4 sampling processes.
**Note**: There is a known bug in Python 3.8. The training process hangs when running multiple sampling processes for each training process.
Please set the number of sampling processes to 0 if you are using Python 3.8.
```bash ```bash
python3 ~/workspace/dgl/tools/launch.py \ python3 ~/workspace/dgl/tools/launch.py \
......
...@@ -109,8 +109,6 @@ DGL provides a script to launch the training job in the cluster. `part_config` a ...@@ -109,8 +109,6 @@ DGL provides a script to launch the training job in the cluster. `part_config` a
specify relative paths to the path of the workspace. specify relative paths to the path of the workspace.
The command below launches one training process on each machine and each training process has 4 sampling processes. The command below launches one training process on each machine and each training process has 4 sampling processes.
**Note**: There is a known bug in Python 3.8. The training process hangs when running multiple sampling processes for each training process.
Please set the number of sampling processes to 0 if you are using Python 3.8.
```bash ```bash
python3 ~/workspace/dgl/tools/launch.py \ python3 ~/workspace/dgl/tools/launch.py \
......
...@@ -118,9 +118,8 @@ class DistDataLoader: ...@@ -118,9 +118,8 @@ class DistDataLoader:
self.collate_fn = collate_fn self.collate_fn = collate_fn
self.current_pos = 0 self.current_pos = 0
if self.pool is not None: if self.pool is not None:
self.m = mp.Manager() m = mp.Manager()
self.barrier = self.m.Barrier(self.num_workers) self.queue = m.Queue(maxsize=queue_size)
self.queue = self.m.Queue(maxsize=queue_size)
else: else:
self.queue = Queue(maxsize=queue_size) self.queue = Queue(maxsize=queue_size)
self.drop_last = drop_last self.drop_last = drop_last
...@@ -141,9 +140,10 @@ class DistDataLoader: ...@@ -141,9 +140,10 @@ class DistDataLoader:
if self.pool is not None: if self.pool is not None:
results = [] results = []
barrier = m.Barrier(self.num_workers)
for _ in range(self.num_workers): for _ in range(self.num_workers):
results.append(self.pool.apply_async( results.append(self.pool.apply_async(
init_fn, args=(self.barrier, self.name, self.collate_fn, self.queue))) init_fn, args=(barrier, self.name, self.collate_fn, self.queue)))
for res in results: for res in results:
res.get() res.get()
...@@ -153,8 +153,11 @@ class DistDataLoader: ...@@ -153,8 +153,11 @@ class DistDataLoader:
self.pool, self.num_workers = get_sampler_pool() self.pool, self.num_workers = get_sampler_pool()
if self.pool is not None: if self.pool is not None:
results = [] results = []
# Here we need to create the manager and barrier again.
m = mp.Manager()
barrier = m.Barrier(self.num_workers)
for _ in range(self.num_workers): for _ in range(self.num_workers):
results.append(self.pool.apply_async(cleanup_fn, args=(self.barrier, self.name,))) results.append(self.pool.apply_async(cleanup_fn, args=(barrier, self.name,)))
for res in results: for res in results:
res.get() res.get()
......
...@@ -151,11 +151,6 @@ def main(): ...@@ -151,11 +151,6 @@ def main():
udf_command = str(udf_command[0]) udf_command = str(udf_command[0])
if 'python' not in udf_command: if 'python' not in udf_command:
raise RuntimeError("DGL launching script can only support Python executable file.") raise RuntimeError("DGL launching script can only support Python executable file.")
if sys.version_info.major and sys.version_info.minor >= 8:
if args.num_samplers > 0:
print('WARNING! DGL does not support multiple sampler processes in Python>=3.8. '
+ 'Set the number of sampler processes to 0.')
args.num_samplers = 0
submit_jobs(args, udf_command) submit_jobs(args, udf_command)
def signal_handler(signal, frame): def signal_handler(signal, frame):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment