"vscode:/vscode.git/clone" did not exist on "48ddb02b79d7e22e2eefbf5294bf70de50afd1b2"
Unverified commit 43a013c3, authored by Robert Shaw, committed by GitHub
Browse files

[Bugfix] Fix Dtypes for Pynccl Wrapper (#33030)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
parent c25dbee4
...@@ -72,7 +72,8 @@ class ncclDataTypeEnum: ...@@ -72,7 +72,8 @@ class ncclDataTypeEnum:
ncclFloat64 = 8 ncclFloat64 = 8
ncclDouble = 8 ncclDouble = 8
ncclBfloat16 = 9 ncclBfloat16 = 9
ncclNumTypes = 10 ncclFloat8e4m3 = 10
ncclNumTypes = 11
@classmethod @classmethod
def from_torch(cls, dtype: torch.dtype) -> int: def from_torch(cls, dtype: torch.dtype) -> int:
...@@ -92,9 +93,12 @@ class ncclDataTypeEnum: ...@@ -92,9 +93,12 @@ class ncclDataTypeEnum:
return cls.ncclFloat64 return cls.ncclFloat64
if dtype == torch.bfloat16: if dtype == torch.bfloat16:
return cls.ncclBfloat16 return cls.ncclBfloat16
if dtype == torch.float8_e4m3fn:
return cls.ncclFloat8e4m3
raise ValueError( raise ValueError(
f"Unsupported dtype {dtype}: should be one of " f"Unsupported dtype {dtype}: should be one of "
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16." f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
" float8e4m3."
) )
......
...@@ -224,6 +224,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin ...@@ -224,6 +224,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
a1q_scale = None a1q_scale = None
if is_nvfp4 and a1q_scale is not None: if is_nvfp4 and a1q_scale is not None:
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale) a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights return a1q, a1q_scale, None, topk_ids, topk_weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment