Unverified Commit de6afe24 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Fix high-precision dtype for MXFP8 AG (#2058)



* Fix high-precision dtype for MXFP8 AG
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Comment
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 077e26c3
...@@ -1225,14 +1225,12 @@ def _all_gather_mxfp8( ...@@ -1225,14 +1225,12 @@ def _all_gather_mxfp8(
if inp._rowwise_data is not None: if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.size() in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device device = inp._rowwise_data.device
dtype = inp._rowwise_data.dtype
elif inp._columnwise_data is not None: elif inp._columnwise_data is not None:
in_shape = inp._columnwise_data.size() in_shape = inp._columnwise_data.size()
device = inp._columnwise_data.device device = inp._columnwise_data.device
dtype = inp._columnwise_data.dtype
else: else:
raise ValueError("Got MXFP8 input tensor without any data") raise ValueError("Got MXFP8 input tensor without any data")
dtype = torch.bfloat16 dtype = torch.bfloat16 # Guess high-precision dtype.
else: else:
raise ValueError( raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, " "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment