Commit 32edae18 authored by yuguo

[DCU] fix tensorwise int8 moe bugs

parent 0cf10d1c
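The changes below touch two spots. In `general_grouped_gemm`, the operands `A`, `B`, and `bias` arrive as per-expert lists, so the type checks and the kernel call must look at list elements (`A[0]`, `bias[0]`), and the asserts now map `out_dtype` through `TE_DType_To_Torch` before comparing it with torch dtypes, since it arrives as a TE enum rather than a torch dtype. A minimal sketch of the first pitfall (the `Float8TensorBase` stub below is a stand-in for the real class, not the actual import):

```python
# Hypothetical stand-in for transformer_engine's Float8TensorBase, used only
# to illustrate why the isinstance check had to move to the first element.
class Float8TensorBase:
    pass

# Grouped GEMM receives one operand per expert, i.e. a plain Python list.
A = [Float8TensorBase() for _ in range(4)]

print(isinstance(A, Float8TensorBase))     # False: A is a list, never matches
print(isinstance(A[0], Float8TensorBase))  # True: check an element instead
```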
@@ -473,13 +473,13 @@ def general_grouped_gemm(
         else:
             raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
-    if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
+    if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise:
         assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
         assert not gelu, "GELU not supported with int8 simulation groupgemm."
         assert not use_bias, "Bias not supported with int8 simulation groupgemm."
-        assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
+        assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
         if layout == "TN":
-            assert out_dtype is torch.bfloat16
+            assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
             qx_data_list, scales_x_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
             w_data_list, scales_w_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
@@ -501,7 +501,7 @@ def general_grouped_gemm(
                 num_gemms,
                 None,
                 TE_DType[out_dtype],
-                bias,
+                bias[0],
                 bias_dtype,
                 gelu,
                 gelu_input[0],
@@ -514,7 +514,7 @@ def general_grouped_gemm(
             return out, bias, gelu_input
         if layout == "NN":
-            assert out_dtype is torch.bfloat16
+            assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
             qdout_data_list, scales_dout_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
             w_data_list, scales_w_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
@@ -536,7 +536,7 @@ def general_grouped_gemm(
                 num_gemms,
                 None,
                 TE_DType[out_dtype],
-                bias,
+                bias[0],
                 bias_dtype,
                 gelu,
                 gelu_input[0],
@@ -549,7 +549,7 @@ def general_grouped_gemm(
             return out, bias, gelu_input
         elif layout == "NT":
-            assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
+            assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32
             qdout_data_list, scales_dout_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
             qx_data_list, scales_x_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
@@ -572,7 +572,7 @@ def general_grouped_gemm(
                 num_gemms,
                 None,
                 TE_DType[out_dtype],
-                bias,
+                bias[0],
                 bias_dtype,
                 gelu,
                 gelu_input[0],
@@ -591,10 +591,10 @@ def general_grouped_gemm(
         assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
         assert not gelu, "GELU not supported with int8 simulation groupgemm."
         assert not use_bias, "Bias not supported with int8 simulation groupgemm."
-        assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
+        assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
         if layout == "TN":
-            assert out_dtype is torch.bfloat16
+            assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
             qx_data_list = []
             w_data_list = []
             scales_x_list = []
@@ -642,7 +642,7 @@ def general_grouped_gemm(
             return out, bias, gelu_input
         elif layout == "NN":
-            assert out_dtype is torch.bfloat16
+            assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
             qdout_data_list = []
             w_data_list = []
             scales_dout_list = []
@@ -690,7 +690,7 @@ def general_grouped_gemm(
             return out, bias, gelu_input
         elif layout == "NT":
-            assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
+            assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32
             qdout_data_list = []
             qx_data_list = []
             scales_dout_list = []
@@ -730,7 +730,7 @@ def general_grouped_gemm(
                 use_split_accumulator,
             )[0]
-            if out_dtype is torch.bfloat16:
+            if TE_DType_To_Torch[out_dtype] is torch.bfloat16:
                 if accumulate:
                     for i in num_gemms:
                         channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
...
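For context on the tensorwise path above: as sketched below under assumed semantics, each quantized operand carries a single `_scale_inv` per tensor, the int8 GEMM accumulates at higher precision, and the result is rescaled once per expert. The helper name and the fp32 emulation of the int32 accumulator are illustrative, not the actual DCU kernel:

```python
import torch

def tensorwise_dequant_gemm(qa, scale_a, qb, scale_b, out_dtype=torch.bfloat16):
    # Real kernels accumulate int8 x int8 products in int32; an fp32 matmul is
    # an exact stand-in here because the partial sums stay well below 2**24.
    acc = qa.to(torch.float32) @ qb.to(torch.float32)
    # Tensorwise scheme: one multiplicative scale per operand tensor,
    # applied once after accumulation.
    return (acc * scale_a * scale_b).to(out_dtype)

qa = torch.randint(-128, 128, (16, 32), dtype=torch.int8)
qb = torch.randint(-128, 128, (32, 8), dtype=torch.int8)
print(tensorwise_dequant_gemm(qa, 0.02, qb, 0.01).shape)  # torch.Size([16, 8])
```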
@@ -49,8 +49,10 @@ std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool tran
   const size_t B1 = B_shape.data[B_shape.ndim - 1];
   // Check matrix dims
-  NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(",
-             A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")");
+  if (transa || transb) {
+    NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(",
+               A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")");
+  }
   // Construct output dims
   std::vector<size_t> ret;
...
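The C++ change makes the inner-dimension check conditional: after this commit it runs only when at least one operand is transposed, presumably so that the NN layouts used by the int8-simulation path are not rejected. A Python mirror of the new logic (the function name and shapes are illustrative):

```python
def check_gemm_dims(a_shape, b_shape, transa, transb):
    # Mirrors the patched NVTE_CHECK: shapes are (leading_dims, last_dim) pairs.
    A0, A1 = a_shape
    B0, B1 = b_shape
    if transa or transb:  # NN GEMMs now skip the check entirely
        inner_a = A1 if transa else A0
        inner_b = B0 if transb else B1
        assert inner_a == inner_b, (
            f"Invalid matrix dimensions for GEMM (A={a_shape}, transa={transa}, "
            f"B={b_shape}, transb={transb})"
        )

check_gemm_dims((473, 64), (128, 64), transa=True, transb=False)  # passes
```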