Commit 2e870ed9 authored by yuguo's avatar yuguo
Browse files

[DCU] fix

parent 059d92e2
......@@ -570,7 +570,7 @@ def general_grouped_gemm(
dw.view(-1, dw.size(-1)),
num_gemms,
None,
TE_DType[out_dtype],
out_dtype,
None,
bias_dtype,
gelu,
......
......@@ -1182,7 +1182,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
else:
if isinstance(quantizer, Float8BlockQuantizer):
if isinstance(quantizer, Float8BlockQuantizer) or (isinstance(quantizer, Float8CurrentScalingQuantizer) and IS_HIP_EXTENSION):
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment