Unverified commit 5f9b2c62, authored by yiakwy-xpu-ml-framework-team, committed by GitHub
Browse files

[ROCm] fix dtype (#4510)

parent 5493c334
@@ -108,10 +108,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
                 layer.weight, layer.weight.shape[-1]
             )
             weight_scale = weight_scale.t().contiguous()
+            if _is_hip:
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight, weight_scale=weight_scale
+                )
         else:
             # if cutlass not supported, we fall back to use torch._scaled_mm
             # which requires per tensor quantization on weight
-            qweight, weight_scale = input_to_float8(layer.weight)
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+            qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
         # Update the layer with the new values.
         layer.weight = Parameter(qweight.t(), requires_grad=False)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.