Commit cbff8d34 authored by zhuwenwen
Browse files

解决custom allreduce在dp情况下的起服务错误问题

parent 47067fcc
......@@ -142,7 +142,11 @@ class CustomAllreduce:
else:
device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index]
if (world_size == len(device_ids)):
physical_device_id = device_ids[device.index % world_size]
else:
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment