Commit 04ef76dd authored by yuguo

[DCU] remove channelwise int8 group gemm

parent 650cb815
@@ -643,166 +643,7 @@ def general_grouped_gemm(
    if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)):
        assert len(set(m_splits)) == 1, "Int8 simulation group GEMM only supports the same token padding as batch GEMM for now."
        assert not gelu, "GELU is not supported with int8 simulation group GEMM."
        assert not use_bias, "Bias is not supported with int8 simulation group GEMM."
        assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "out_dtype must be bfloat16 or float32 for int8 simulation."
        if layout == "TN":
            # Forward pass: quantize FP8 activations (B) and weights (A) to per-token int8.
            assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
            qx_data_list = []
            w_data_list = []
            scales_x_list = []
            scales_w_list = []
            for b in B:
                x_int8, x_scales = per_token_quant_fp8_to_int8(b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False)
                qx_data_list.append(x_int8)
                scales_x_list.append(x_scales)
            for a in A:
                w_int8, w_scales = per_token_quant_fp8_to_int8(a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False)
                w_data_list.append(w_int8)
                scales_w_list.append(w_scales)
            num_gemms = len(A)
            seq_len = sum(m_splits) // num_gemms
            qx_data = torch.stack(qx_data_list).contiguous()
            w_data = torch.stack(w_data_list).contiguous()
            y_int32 = torch.empty((num_gemms, seq_len, out[0].size(-1)), dtype=torch.int32, device="cuda")
            y_int32 = tex.generic_batchgemm(
                w_data.view(-1, w_data.size(-1)),
                transa,
                qx_data.view(-1, qx_data.size(-1)),
                transb,
                y_int32.view(-1, y_int32.size(-1)),
                num_gemms,
                None,
                TE_DType[torch.int32],
                None,
                bias_dtype,
                gelu,
                None,
                grad,  # grad
                workspaces[0],
                workspaces[0].shape[0],
                False,
                use_split_accumulator,
            )[0]
            # Dequantize each int32 GEMM result back to bf16 with the per-token and per-channel scales.
            out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
            for i in range(num_gemms):
                out[0][i] = channelwise_dequantize_transB(scales_x_list[i], scales_w_list[i], y_int32[i])
            out[0] = out[0].view(-1, out[0].size(-1))
            return out, bias, gelu_input
elif layout == "NN":
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
qdout_data_list = []
w_data_list = []
scales_dout_list = []
scales_w_list = []
for b in B:
dy_int8, dy_scales = per_token_quant_fp8_to_int8(b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False)
qdout_data_list.append(dy_int8)
scales_dout_list.append(dy_scales)
for a in A:
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False)
w_data_list.append(w_int8)
scales_w_list.append(w_scales)
num_gemms = len(A)
seq_len = sum(m_splits) // num_gemms
qdout_data = torch.stack(qdout_data_list).contiguous()
w_data = torch.stack(w_data_list).contiguous()
dx_int32 = torch.empty((num_gemms, seq_len, out[0].size(-1)), dtype=torch.int32, device="cuda")
dx_int32 = tex.generic_batchgemm(
w_data.view(-1, w_data.size(-1)),
transa,
qdout_data.view(-1, qdout_data.size(-1)),
transb,
dx_int32.view(-1, dx_int32.size(-1)),
num_gemms,
None,
TE_DType[torch.int32],
None,
bias_dtype,
gelu,
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
False,
use_split_accumulator,
)[0]
out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
for i in range(num_gemms):
out[0][i] = channelwise_dequantize(scales_dout_list[i], scales_w_list[i], dx_int32[i])
out[0] = out[0].view(-1, out[0].size(-1))
return out, bias, gelu_input
elif layout == "NT":
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32
qdout_data_list = []
qx_data_list = []
scales_dout_list = []
scales_x_list = []
for b in B:
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False)
qdout_data_list.append(dy_int8)
scales_dout_list.append(dy_scales)
for a in A:
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False)
qx_data_list.append(x_int8)
scales_x_list.append(x_scales)
num_gemms = len(A)
qdout_data = torch.stack(qdout_data_list).contiguous()
qx_data = torch.stack(qx_data_list).contiguous()
dw_int32 = torch.empty((num_gemms, qdout_data.size(-1), qx_data.size(-1)), dtype=torch.int32, device="cuda")
dw_int32 = tex.generic_batchgemm(
qx_data.view(-1, qx_data.size(-1)),
transa,
qdout_data.view(-1, qdout_data.size(-1)),
transb,
dw_int32.view(-1, dw_int32.size(-1)),
num_gemms,
None,
TE_DType[torch.int32],
None,
bias_dtype,
gelu,
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
False,
use_split_accumulator,
)[0]
if TE_DType_To_Torch[out_dtype] is torch.bfloat16:
if accumulate:
for i in num_gemms:
channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
else:
for i in num_gemms:
out[i] = channelwise_dequantize_transA(scales_dout_list[i], scales_x_list[i], dw_int32[i])
else:
if accumulate:
for i in num_gemms:
channelwise_dequantize_transA_float_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
else:
for i in num_gemms:
out[i] = channelwise_dequantize_transA_float(scales_dout_list[i], scales_x_list[i], dw_int32[i])
return out, bias, gelu_input
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
assert False, "Unsupported channelwise in int8 simulation fp8"
    bias = tex.te_general_grouped_gemm(
        A,
......
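
For context, the removed path simulates FP8 grouped GEMM by quantizing both operands to per-token int8, running one int32 batch GEMM, and rescaling the accumulator with the per-token and per-channel scales. The sketch below illustrates that idea only; the function names and scale layout are assumptions for illustration, not the channelwise_dequantize* kernels this commit deletes.

import torch

def per_token_int8_quant_sketch(x_fp32):
    # Illustrative only: one symmetric int8 scale per row (token).
    scales = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = torch.clamp(torch.round(x_fp32 / scales), -127, 127).to(torch.int8)
    return x_int8, scales.squeeze(-1)

def channelwise_dequantize_sketch(scales_x, scales_w, y_int32):
    # y_int32 is the (M, N) int32 accumulator from the int8 GEMM; scales_x is
    # one scale per row (token), scales_w one scale per column (channel).
    return (y_int32.to(torch.float32) * scales_x.view(-1, 1) * scales_w.view(1, -1)).to(torch.bfloat16)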