Unverified commit 8311b07f, authored by mlmz and committed by GitHub
Browse files

Fix: Ensure tensors for dist.broadcast match NCCL backend device (#5322)

parent c1380257
...@@ -848,31 +848,34 @@ def broadcast_pyobj( ...@@ -848,31 +848,34 @@ def broadcast_pyobj(
src: int = 0, src: int = 0,
): ):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[broadcast_pyobj] rank={rank}, device={device}")
if rank == 0: if rank == 0:
if len(data) == 0: if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long) tensor_size = torch.tensor([0], dtype=torch.long, device=device)
dist.broadcast(tensor_size, src=src, group=dist_group) dist.broadcast(tensor_size, src=src, group=dist_group)
else: else:
serialized_data = pickle.dumps(data) serialized_data = pickle.dumps(data)
size = len(serialized_data) size = len(serialized_data)
tensor_data = torch.ByteTensor( tensor_data = torch.ByteTensor(
np.frombuffer(serialized_data, dtype=np.uint8) np.frombuffer(serialized_data, dtype=np.uint8)
) ).to(device)
tensor_size = torch.tensor([size], dtype=torch.long) tensor_size = torch.tensor([size], dtype=torch.long, device=device)
dist.broadcast(tensor_size, src=src, group=dist_group) dist.broadcast(tensor_size, src=src, group=dist_group)
dist.broadcast(tensor_data, src=src, group=dist_group) dist.broadcast(tensor_data, src=src, group=dist_group)
return data return data
else: else:
tensor_size = torch.tensor([0], dtype=torch.long) tensor_size = torch.tensor([0], dtype=torch.long, device=device)
dist.broadcast(tensor_size, src=src, group=dist_group) dist.broadcast(tensor_size, src=src, group=dist_group)
size = tensor_size.item() size = tensor_size.item()
if size == 0: if size == 0:
return [] return []
tensor_data = torch.empty(size, dtype=torch.uint8) tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
dist.broadcast(tensor_data, src=src, group=dist_group) dist.broadcast(tensor_data, src=src, group=dist_group)
serialized_data = bytes(tensor_data.cpu().numpy()) serialized_data = bytes(tensor_data.cpu().numpy())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment