Unverified Commit eb1771ef authored by Prajwal Kailas's avatar Prajwal Kailas Committed by GitHub
Browse files

Check for mapping/dict in distributed_concat function (#21500)

check for mapping/dict in distributed_concat function

Co-authored-by: prajwal967 <user.email>
parent 7e51a441
......@@ -189,6 +189,8 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
if isinstance(tensor, Mapping):
return type(tensor)({k: distributed_concat(t, num_total_examples) for k, t in tensor.items()})
tensor = atleast_1d(tensor).contiguous()
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment