Commit ef5ebdbf authored by zhuwenwen

Fix the service error in custom allreduce in the DP case

parent 5e19613e
@@ -128,7 +128,11 @@ class CustomAllreduce:
         else:
             device_ids = list(range(cuda_device_count_stateless()))
-        physical_device_id = device_ids[device.index]
+        if (world_size == len(device_ids)):
+            physical_device_id = device_ids[device.index % world_size]
+        else:
+            physical_device_id = device_ids[device.index]
         tensor = torch.tensor([physical_device_id],
                               dtype=torch.int,
                               device="cpu")
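Note: the selection logic in this hunk can be tried in isolation with the minimal sketch below. The function name `select_physical_device_id` and the sample values are hypothetical and not part of the commit. The commit does not state which `device.index` values occur under DP; the sketch assumes the index can exceed the number of visible devices, in which case the modulo keeps the lookup in range.

```python
from typing import List


def select_physical_device_id(device_ids: List[int], device_index: int,
                              world_size: int) -> int:
    """Hypothetical standalone copy of the branch added in this hunk."""
    if world_size == len(device_ids):
        # Wrap the index so it always lands inside device_ids.
        return device_ids[device_index % world_size]
    # Otherwise keep the original direct lookup.
    return device_ids[device_index]


if __name__ == "__main__":
    # Assumed DP-style case: two visible devices, world size 2, index 3.
    print(select_physical_device_id([4, 5], 3, 2))  # prints 5
```

When `world_size` differs from `len(device_ids)`, the behavior is unchanged from before the commit, so an out-of-range index would still raise `IndexError` on that path.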