Unverified commit de6afe24, authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Fix high-precision dtype for MXFP8 AG (#2058)



* Fix high-precision dtype for MXFP8 AG
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Comment
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 077e26c3
@@ -1225,14 +1225,12 @@ def _all_gather_mxfp8(
         if inp._rowwise_data is not None:
             in_shape = inp._rowwise_data.size()
             device = inp._rowwise_data.device
-            dtype = inp._rowwise_data.dtype
         elif inp._columnwise_data is not None:
             in_shape = inp._columnwise_data.size()
             device = inp._columnwise_data.device
-            dtype = inp._columnwise_data.dtype
         else:
             raise ValueError("Got MXFP8 input tensor without any data")
-        dtype = torch.bfloat16
+        dtype = torch.bfloat16  # Guess high-precision dtype.
     else:
         raise ValueError(
             "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
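The fix above stops reading `dtype` from the MXFP8 tensor's raw data buffers: those buffers hold packed FP8 bytes, so their `.dtype` does not reflect the tensor's original high-precision dtype, and bfloat16 is used as the guess instead. The sketch below illustrates that logic in isolation; `FakeMXFP8Tensor` and `infer_gather_metadata` are hypothetical names standing in for `MXFP8TensorBase` and the relevant part of `_all_gather_mxfp8`, not the library's actual API.

```python
import torch


class FakeMXFP8Tensor:
    # Hypothetical stand-in for MXFP8TensorBase: holds optional
    # row-wise / column-wise byte buffers of packed FP8 data.
    def __init__(self, rowwise_data=None, columnwise_data=None):
        self._rowwise_data = rowwise_data
        self._columnwise_data = columnwise_data


def infer_gather_metadata(inp):
    """Mirror the fixed logic: take shape and device from whichever
    buffer exists, but never take dtype from the FP8 byte buffers
    (their dtype is the storage type, e.g. uint8, not the original
    high-precision dtype). Guess bfloat16 instead."""
    if inp._rowwise_data is not None:
        in_shape = inp._rowwise_data.size()
        device = inp._rowwise_data.device
    elif inp._columnwise_data is not None:
        in_shape = inp._columnwise_data.size()
        device = inp._columnwise_data.device
    else:
        raise ValueError("Got MXFP8 input tensor without any data")
    dtype = torch.bfloat16  # Guess high-precision dtype.
    return in_shape, device, dtype
```

With the buggy version, a uint8-backed buffer would have produced `dtype=torch.uint8` for the gathered output; the fixed version always reports bfloat16 for MXFP8 inputs.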