Commit af196942 authored by yuguo
parents fb6798f1 68d6c506
@@ -267,12 +267,12 @@ def general_gemm(
     )[0]
     if out_dtype is torch.bfloat16:
         if accumulate:
-            out = channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
+            channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
         else:
             out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
     else:
         if accumulate:
-            out = channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
+            channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
         else:
             out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
     return out, None, None, None
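Note on this hunk: the `*_add` helpers (changed in the last hunk below) now accumulate into their final argument in place and return `None`, so the call sites must drop the `out = ...` rebinding; keeping it would overwrite `out` with `None` before the `return`. A minimal sketch of the pattern, with toy shapes and a stand-in helper name (`dequant_add` is not in the source):

import torch

# Stand-in for channelwise_dequantize_transA_add: mutates `out` via
# Tensor.add_ and returns None, so the caller keeps its own reference.
def dequant_add(dy_scales, x_scales, dw_int32, out):
    out_scales = dy_scales.T * x_scales  # (M,1) * (1,N) -> (M,N) scale grid
    out.add_((out_scales * dw_int32.to(torch.float32)).to(torch.bfloat16))

out = torch.zeros(4, 8, dtype=torch.bfloat16)
dequant_add(torch.rand(1, 4), torch.rand(1, 8),
            torch.randint(-128, 128, (4, 8), dtype=torch.int32), out)
# `out` now holds the accumulated result; no `out = ...` rebinding needed.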
@@ -572,14 +572,14 @@ def general_grouped_gemm(
     if out_dtype is torch.bfloat16:
         if accumulate:
             for i in range(num_gemms):
-                out[i] = channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
+                channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
         else:
             for i in range(num_gemms):
                 out[i] = channelwise_dequantize_transA(scales_dout_list[i], scales_x_list[i], dw_int32[i])
     else:
         if accumulate:
             for i in range(num_gemms):
-                out[i] = channelwise_dequantize_transA_float_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
+                channelwise_dequantize_transA_float_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
         else:
             for i in range(num_gemms):
                 out[i] = channelwise_dequantize_transA_float(scales_dout_list[i], scales_x_list[i], dw_int32[i])
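The same fix applies per-GEMM in the grouped path: each `out[i]` is a pre-allocated tensor, so the in-place `add_` inside the helper updates it without rebinding the list slot. A small self-contained illustration, assuming list-of-tensors inputs and toy shapes:

import torch

num_gemms = 3
out = [torch.zeros(4, 8, dtype=torch.bfloat16) for _ in range(num_gemms)]
scales_dout_list = [torch.rand(1, 4) for _ in range(num_gemms)]
scales_x_list = [torch.rand(1, 8) for _ in range(num_gemms)]
dw_int32 = [torch.randint(-128, 128, (4, 8), dtype=torch.int32)
            for _ in range(num_gemms)]

for i in range(num_gemms):
    # Inlined body of the bf16 accumulate helper, applied per GEMM.
    out_scales = scales_dout_list[i].T * scales_x_list[i]
    out[i].add_((out_scales * dw_int32[i].to(torch.float32)).to(torch.bfloat16))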
@@ -331,12 +331,12 @@ def channelwise_dequantize_transA_float(A, B, C):
 @torch.compile(mode="max-autotune-no-cudagraphs")
 def channelwise_dequantize_transA_add(A, B, C, D):
     out_scales = A.T * B
-    return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16) + D
+    D.add_((out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16))

 @torch.compile(mode="max-autotune-no-cudagraphs")
 def channelwise_dequantize_transA_float_add(A, B, C, D):
     out_scales = A.T * B
-    return out_scales * C.to(dtype=torch.float32) + D
+    D.add_(out_scales * C.to(dtype=torch.float32))

 @torch.compile(mode="max-autotune-no-cudagraphs")
 def channelwise_dequantize_transB(A, B, C):
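Previously the `*_add` kernels returned a freshly allocated `result + D`; they now fuse the accumulation into `D`'s existing storage with `Tensor.add_`, presumably to avoid an extra MxN allocation per call when repeatedly accumulating into the same buffer. A quick equivalence sketch with toy shapes (the intermediate name `addend` is mine): the two forms should match numerically, since both round the addend to `D`'s dtype before the add.

import torch

A = torch.rand(1, 4)                                     # per-row scales
B = torch.rand(1, 8)                                     # per-column scales
C = torch.randint(-128, 128, (4, 8), dtype=torch.int32)  # int32 GEMM output
D = torch.rand(4, 8, dtype=torch.bfloat16)               # accumulator

out_scales = A.T * B
addend = (out_scales * C.to(torch.float32)).to(torch.bfloat16)
old_style = addend + D   # out-of-place: allocates a new tensor
D.add_(addend)           # in-place: reuses D's storage
assert torch.equal(old_style, D)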