Unverified Commit 43a013c3 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[Bugfix] Fix Dtypes for Pynccl Wrapper (#33030)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent c25dbee4
......@@ -72,7 +72,8 @@ class ncclDataTypeEnum:
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
ncclFloat8e4m3 = 10
ncclNumTypes = 11
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
......@@ -92,9 +93,12 @@ class ncclDataTypeEnum:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
if dtype == torch.float8_e4m3fn:
return cls.ncclFloat8e4m3
raise ValueError(
f"Unsupported dtype {dtype}: should be one of "
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16."
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
" float8e4m3."
)
......
......@@ -224,6 +224,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
a1q_scale = None
if is_nvfp4 and a1q_scale is not None:
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment