Unverified Commit d693ec04 authored by Ke Wen's avatar Ke Wen Committed by GitHub
Browse files

Make torch TP composable with torch.compile (#2352)

parent 18ea841f
...@@ -54,11 +54,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel): ...@@ -54,11 +54,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
)._prepare_output_fn( )._prepare_output_fn(
output_layouts, use_local_output, mod, outputs, device_mesh output_layouts, use_local_output, mod, outputs, device_mesh
) )
# wait for the output to be ready return torch.distributed._functional_collectives.wait_tensor(outputs)
if isinstance(outputs, AsyncCollectiveTensor):
return outputs.wait()
else:
return outputs
def tensor_parallel( def tensor_parallel(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment