Commit 68d6c506 authored by yuguo's avatar yuguo
Browse files

[DCU] fix channelwise train accumulate bug

parent 4a013bd5
...@@ -267,12 +267,12 @@ def general_gemm( ...@@ -267,12 +267,12 @@ def general_gemm(
)[0] )[0]
if out_dtype is torch.bfloat16: if out_dtype is torch.bfloat16:
if accumulate: if accumulate:
out = channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out) channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
else: else:
out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32) out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
else: else:
if accumulate: if accumulate:
out = channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out) channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
else: else:
out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32) out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
return out, None, None, None return out, None, None, None
...@@ -572,14 +572,14 @@ def general_grouped_gemm( ...@@ -572,14 +572,14 @@ def general_grouped_gemm(
if out_dtype is torch.bfloat16: if out_dtype is torch.bfloat16:
if accumulate: if accumulate:
for i in num_gemms: for i in num_gemms:
out[i] = channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i]) channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
else: else:
for i in num_gemms: for i in num_gemms:
out[i] = channelwise_dequantize_transA(scales_dout_list[i], scales_x_list[i], dw_int32[i]) out[i] = channelwise_dequantize_transA(scales_dout_list[i], scales_x_list[i], dw_int32[i])
else: else:
if accumulate: if accumulate:
for i in num_gemms: for i in num_gemms:
out[i] = channelwise_dequantize_transA_float_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i]) channelwise_dequantize_transA_float_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
else: else:
for i in num_gemms: for i in num_gemms:
out[i] = channelwise_dequantize_transA_float(scales_dout_list[i], scales_x_list[i], dw_int32[i]) out[i] = channelwise_dequantize_transA_float(scales_dout_list[i], scales_x_list[i], dw_int32[i])
......
...@@ -331,12 +331,12 @@ def channelwise_dequantize_transA_float(A, B, C): ...@@ -331,12 +331,12 @@ def channelwise_dequantize_transA_float(A, B, C):
@torch.compile(mode="max-autotune-no-cudagraphs") @torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transA_add(A, B, C, D): def channelwise_dequantize_transA_add(A, B, C, D):
out_scales = A.T * B out_scales = A.T * B
return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16) + D D.add_((out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16))
@torch.compile(mode="max-autotune-no-cudagraphs") @torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transA_float_add(A, B, C, D): def channelwise_dequantize_transA_float_add(A, B, C, D):
out_scales = A.T * B out_scales = A.T * B
return out_scales * C.to(dtype=torch.float32) + D D.add_(out_scales * C.to(dtype=torch.float32))
@torch.compile(mode="max-autotune-no-cudagraphs") @torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transB(A, B, C): def channelwise_dequantize_transB(A, B, C):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment