Commit 04ef76dd authored by yuguo

[DCU] remove channelwise int8 group gemm

parent 650cb815
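
For context on the path being removed: it simulated the fp8 grouped GEMM in int8 by requantizing each operand per token to int8, running an integer batched GEMM, and rescaling the integer result with the per-row and per-column scales (the channelwise dequantize). Below is a minimal PyTorch sketch of that idea; the helpers are illustrative stand-ins, not the per_token_quant_fp8_to_int8 / channelwise_dequantize kernels referenced in the diff.

import torch

def per_token_quant_to_int8(x: torch.Tensor):
    # Symmetric per-row quantization: one scale per token (row) of x.
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp((x / scales).round(), -127, 127).to(torch.int8)
    return q, scales.squeeze(-1)

def channelwise_dequant(acc, row_scales, col_scales):
    # Rescale the integer accumulator by the outer product of row and column scales.
    return acc.to(torch.float32) * row_scales[:, None] * col_scales[None, :]

# Simulate y = x @ w.T through int8 with per-token / per-channel scales.
x = torch.randn(16, 64)
w = torch.randn(32, 64)
qx, sx = per_token_quant_to_int8(x)
qw, sw = per_token_quant_to_int8(w)
# The DCU kernel accumulates in int32; int64 keeps this CPU reference simple.
acc = qx.to(torch.int64) @ qw.to(torch.int64).T
y = channelwise_dequant(acc, sx, sw).to(torch.bfloat16)
err = (y.float() - x @ w.T).abs().max().item()
print(f"max abs error vs. fp32 reference: {err:.3f}")
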
@@ -643,166 +643,7 @@ def general_grouped_gemm(
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)):
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm only supports the same token padding as batchgemm for now."
assert False, "Unsupported channelwise in int8 simulation fp8"
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
if layout == "TN":
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
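# TN path: requantize the fp8 activations (B) and weights (A) per token to int8,
# run one batched integer GEMM, then dequantize the int32 result channelwise into out.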
qx_data_list = []
w_data_list = []
scales_x_list = []
scales_w_list = []
for b in B:
x_int8, x_scales = per_token_quant_fp8_to_int8(b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False)
qx_data_list.append(x_int8)
scales_x_list.append(x_scales)
for a in A:
w_int8, w_scales = per_token_quant_fp8_to_int8(a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False)
w_data_list.append(w_int8)
scales_w_list.append(w_scales)
num_gemms = len(A)
seq_len = sum(m_splits) // num_gemms
qx_data = torch.stack(qx_data_list).contiguous()
w_data = torch.stack(w_data_list).contiguous()
y_int32 = torch.empty((num_gemms, seq_len, out[0].size(-1)), dtype=torch.int32, device="cuda")
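# Batched int8 GEMM across all groups, writing int32 results into y_int32.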
y_int32 = tex.generic_batchgemm(
w_data.view(-1, w_data.size(-1)),
transa,
qx_data.view(-1, qx_data.size(-1)),
transb,
y_int32.view(-1, y_int32.size(-1)),
num_gemms,
None,
TE_DType[torch.int32],
None,
bias_dtype,
gelu,
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
False,
use_split_accumulator,
)[0]
out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
for i in range(num_gemms):
out[0][i] = channelwise_dequantize_transB(scales_x_list[i], scales_w_list[i], y_int32[i])
out[0] = out[0].view(-1, out[0].size(-1))
return out, bias, gelu_input
elif layout == "NN":
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
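# NN path (dgrad): requantize the fp8 grad output (B) and weights (A) per token to int8;
# the int32 batched GEMM result is dequantized channelwise into dx.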
qdout_data_list = []
w_data_list = []
scales_dout_list = []
scales_w_list = []
for b in B:
dy_int8, dy_scales = per_token_quant_fp8_to_int8(b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False)
qdout_data_list.append(dy_int8)
scales_dout_list.append(dy_scales)
for a in A:
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False)
w_data_list.append(w_int8)
scales_w_list.append(w_scales)
num_gemms = len(A)
seq_len = sum(m_splits) // num_gemms
qdout_data = torch.stack(qdout_data_list).contiguous()
w_data = torch.stack(w_data_list).contiguous()
dx_int32 = torch.empty((num_gemms, seq_len, out[0].size(-1)), dtype=torch.int32, device="cuda")
dx_int32 = tex.generic_batchgemm(
w_data.view(-1, w_data.size(-1)),
transa,
qdout_data.view(-1, qdout_data.size(-1)),
transb,
dx_int32.view(-1, dx_int32.size(-1)),
num_gemms,
None,
TE_DType[torch.int32],
None,
bias_dtype,
gelu,
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
False,
use_split_accumulator,
)[0]
out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
for i in range(num_gemms):
out[0][i] = channelwise_dequantize(scales_dout_list[i], scales_w_list[i], dx_int32[i])
out[0] = out[0].view(-1, out[0].size(-1))
return out, bias, gelu_input
elif layout == "NT":
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32
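# NT path (wgrad): requantize the fp8 grad output (B) and activations (A) per token to int8;
# the int32 result is dequantized (and optionally accumulated) into the weight gradient.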
qdout_data_list = []
qx_data_list = []
scales_dout_list = []
scales_x_list = []
for b in B:
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False)
qdout_data_list.append(dy_int8)
scales_dout_list.append(dy_scales)
for a in A:
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False)
qx_data_list.append(x_int8)
scales_x_list.append(x_scales)
num_gemms = len(A)
qdout_data = torch.stack(qdout_data_list).contiguous()
qx_data = torch.stack(qx_data_list).contiguous()
dw_int32 = torch.empty((num_gemms, qdout_data.size(-1), qx_data.size(-1)), dtype=torch.int32, device="cuda")
dw_int32 = tex.generic_batchgemm(
qx_data.view(-1, qx_data.size(-1)),
transa,
qdout_data.view(-1, qdout_data.size(-1)),
transb,
dw_int32.view(-1, dw_int32.size(-1)),
num_gemms,
None,
TE_DType[torch.int32],
None,
bias_dtype,
gelu,
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
False,
use_split_accumulator,
)[0]
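# Dequantize each group's int32 result; when accumulate is set, add into the existing out[i].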
if TE_DType_To_Torch[out_dtype] is torch.bfloat16:
if accumulate:
for i in range(num_gemms):
channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
else:
for i in range(num_gemms):
out[i] = channelwise_dequantize_transA(scales_dout_list[i], scales_x_list[i], dw_int32[i])
else:
if accumulate:
for i in range(num_gemms):
channelwise_dequantize_transA_float_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
else:
for i in range(num_gemms):
out[i] = channelwise_dequantize_transA_float(scales_dout_list[i], scales_x_list[i], dw_int32[i])
return out, bias, gelu_input
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
bias = tex.te_general_grouped_gemm(
A,
...
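
The removed branches all follow the same pattern: because every group has the same token count (the len(set(m_splits)) == 1 assert above), the per-group int8 operands can be stacked along a batch dimension and handed to one batched GEMM instead of num_gemms separate GEMMs. A rough sketch of that stacking, with torch.bmm as an illustrative stand-in for tex.generic_batchgemm and assumed shapes:

import torch

# Assumed shapes for illustration only.
num_gemms, seq_len, hidden, out_features = 4, 128, 64, 256
xs = [torch.randint(-127, 128, (seq_len, hidden), dtype=torch.int8) for _ in range(num_gemms)]
ws = [torch.randint(-127, 128, (out_features, hidden), dtype=torch.int8) for _ in range(num_gemms)]

# Stack per-group operands into contiguous batches, as the removed code did with
# torch.stack(...).contiguous() before calling tex.generic_batchgemm.
qx = torch.stack(xs)    # [num_gemms, seq_len, hidden]
qw = torch.stack(ws)    # [num_gemms, out_features, hidden]

# One batched GEMM covers all groups. torch.bmm in float stands in for the int8
# kernel, which consumes int8 operands and accumulates into int32.
y = torch.bmm(qx.float(), qw.float().transpose(1, 2))
print(y.shape)          # [num_gemms, seq_len, out_features]
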