Commit cbff8d34 authored by zhuwenwen's avatar zhuwenwen
Browse files

解决custom allreduce在dp情况下的其服务错误问题

parent 47067fcc
...@@ -142,7 +142,11 @@ class CustomAllreduce: ...@@ -142,7 +142,11 @@ class CustomAllreduce:
else: else:
device_ids = list(range(cuda_device_count_stateless())) device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index] if (world_size == len(device_ids)):
physical_device_id = device_ids[device.index % world_size]
else:
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
gather_list = [ gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size) torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment