Commit 1c965000 authored by Sylvain Gugger
Browse files

Fix gather for SageMaker model parallel

parent 4e0410e9
......@@ -1021,6 +1021,7 @@ if is_sagemaker_mp_enabled():
f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
)
all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors]
return torch.cat([t.cpu() for t in all_tensors], dim=0)
def smp_nested_concat(tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment