Commit 2e870ed9 authored by yuguo's avatar yuguo
Browse files

[DCU] fix

parent 059d92e2
...@@ -570,7 +570,7 @@ def general_grouped_gemm( ...@@ -570,7 +570,7 @@ def general_grouped_gemm(
dw.view(-1, dw.size(-1)), dw.view(-1, dw.size(-1)),
num_gemms, num_gemms,
None, None,
TE_DType[out_dtype], out_dtype,
None, None,
bias_dtype, bias_dtype,
gelu, gelu,
......
...@@ -1182,7 +1182,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1182,7 +1182,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
): ):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
else: else:
if isinstance(quantizer, Float8BlockQuantizer): if isinstance(quantizer, Float8BlockQuantizer) or (isinstance(quantizer, Float8CurrentScalingQuantizer) and IS_HIP_EXTENSION):
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment