Commit e722c4a9 authored by mshoeybi's avatar mshoeybi
Browse files

tested and woking

parent 107c29e8
......@@ -38,12 +38,25 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
return tensor
def broadcast_list(size, dtype, list_values=None, rank=0):
"""Broadcast a list of values with a given type."""
tensor = None
if torch.distributed.get_rank() == rank:
tensor = torch.tensor(list_values, dtype=dtype,
device=torch.cuda.current_device())
return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
def broadcast_int_list(size, int_list=None, rank=0):
"""Broadcast a list of interger values."""
long_tensor = None
if torch.distributed.get_rank() == rank:
long_tensor = torch.tensor(int_list, dtype=torch.int64,
device=torch.cuda.current_device())
return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
def broadcast_float_list(size, float_list=None, rank=0):
"""Broadcast a list of float values."""
return broadcast_tensor(size, torch.int64, tensor=long_tensor, rank=rank)
return broadcast_list(size, torch.float32, list_values=float_list,
rank=rank)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment