Commit 32edae18 authored by yuguo

[DCU] fix tensorwise int8 moe bugs

parent 0cf10d1c
@@ -473,13 +473,13 @@ def general_grouped_gemm(
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
- if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
+ if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise:
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm only supports token padding the same as batchgemm for now."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
- assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
+ assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
if layout == "TN":
- assert out_dtype is torch.bfloat16
+ assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
qx_data_list, scales_x_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
w_data_list, scales_w_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
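In grouped GEMM, A and B are lists of per-expert operands rather than single tensors, so the old `isinstance(A, Float8TensorBase)` check was always false and the tensorwise int8-simulation path never triggered; checking `A[0]`/`B[0]` fixes the dispatch. A minimal sketch of the corrected check (`Float8TensorBase` is stubbed here; the harness is hypothetical):

```python
# Minimal sketch: the isinstance check must look at a list element,
# not the list object. Float8TensorBase stands in for the
# transformer_engine class; this stub is illustrative only.
class Float8TensorBase:
    pass

def is_fp8_grouped(A, B):
    # A and B are lists of per-expert operands in grouped GEMM,
    # so inspect an element, not the list itself.
    return isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)

A = [Float8TensorBase() for _ in range(4)]  # e.g. 4 experts
B = [Float8TensorBase() for _ in range(4)]
assert is_fp8_grouped(A, B)                 # element check: True
assert not isinstance(A, Float8TensorBase)  # old check on the list: always False
```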
@@ -501,7 +501,7 @@ def general_grouped_gemm(
num_gemms,
None,
TE_DType[out_dtype],
- bias,
+ bias[0],
bias_dtype,
gelu,
gelu_input[0],
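The bias change follows the same list-vs-element pattern: `bias` arrives as a per-GEMM list (note `gelu_input` was already indexed with `[0]`), while this call site takes a single tensor. A hedged sketch of the pattern, with `fused_gemm_call` as a hypothetical stand-in for the extension call:

```python
# Hedged sketch of the bias fix: general_grouped_gemm receives per-GEMM
# lists, but this underlying call takes a single tensor per argument.
# fused_gemm_call is a hypothetical stand-in for the extension call.
def fused_gemm_call(out_dtype, bias, gelu_input):
    assert not isinstance(bias, list), "expects a single tensor, not a list"

bias = [None]        # per-GEMM list (bias itself is unused on this path)
gelu_input = [None]  # was already indexed with [0] before this commit
fused_gemm_call("bf16", bias[0], gelu_input[0])  # index in, as the fix does
```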
@@ -514,7 +514,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
if layout == "NN":
- assert out_dtype is torch.bfloat16
+ assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
qdout_data_list, scales_dout_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
w_data_list, scales_w_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
@@ -536,7 +536,7 @@ def general_grouped_gemm(
num_gemms,
None,
TE_DType[out_dtype],
- bias,
+ bias[0],
bias_dtype,
gelu,
gelu_input[0],
@@ -549,7 +549,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
elif layout == "NT":
- assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
+ assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32
qdout_data_list, scales_dout_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
qx_data_list, scales_x_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
@@ -572,7 +572,7 @@ def general_grouped_gemm(
num_gemms,
None,
TE_DType[out_dtype],
- bias,
+ bias[0],
bias_dtype,
gelu,
gelu_input[0],
@@ -591,10 +591,10 @@ def general_grouped_gemm(
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm only supports token padding the same as batchgemm for now."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
- assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
+ assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
if layout == "TN":
- assert out_dtype is torch.bfloat16
+ assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
qx_data_list = []
w_data_list = []
scales_x_list = []
@@ -642,7 +642,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
elif layout == "NN":
- assert out_dtype is torch.bfloat16
+ assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
qdout_data_list = []
w_data_list = []
scales_dout_list = []
@@ -690,7 +690,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
elif layout == "NT":
- assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
+ assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32
qdout_data_list = []
qx_data_list = []
scales_dout_list = []
@@ -730,7 +730,7 @@ def general_grouped_gemm(
use_split_accumulator,
)[0]
- if out_dtype is torch.bfloat16:
+ if TE_DType_To_Torch[out_dtype] is torch.bfloat16:
if accumulate:
for i in range(num_gemms):
channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
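All of the assert rewrites above fix one bug: the change implies that `out_dtype` reaching this code is a TransformerEngine `DType` enum, not a `torch.dtype`, so `out_dtype is torch.bfloat16` could never be true and the asserts always failed. Mapping through `TE_DType_To_Torch` first compares like with like. A self-contained sketch with stand-ins for the enum and mapping (the real ones live in transformer_engine):

```python
# Sketch of the dtype-assert fix. DType and TE_DType_To_Torch below are
# simplified stand-ins for transformer_engine's enum and mapping.
import enum
import torch

class DType(enum.Enum):
    kBFloat16 = 0
    kFloat32 = 1

TE_DType_To_Torch = {
    DType.kBFloat16: torch.bfloat16,
    DType.kFloat32: torch.float32,
}

out_dtype = DType.kBFloat16
# Old check compared an enum member against a torch.dtype: always False.
assert out_dtype is not torch.bfloat16
# Fixed check maps the enum to its torch dtype first: True as intended.
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16
```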
@@ -49,8 +49,10 @@ std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool tran
const size_t B1 = B_shape.data[B_shape.ndim - 1];
// Check matrix dims
-  NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(",
-             A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")");
+  if (transa || transb) {
+    NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(",
+               A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")");
+  }
// Construct output dims
std::vector<size_t> ret;
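On the C++ side, `getGemmOutputShape` previously rejected any GEMM whose contracted dimensions disagree; after this commit the check only runs when at least one operand is transposed, presumably so the int8-simulation grouped path, whose stored shapes do not line up the standard way, is not rejected. A Python rendering of the relaxed check (the logic mirrors the C++; the example calls are hypothetical):

```python
# Python rendering of the relaxed dimension check in getGemmOutputShape.
# A is (A0, A1), B is (B0, B1); the contracted dim is validated only when
# a transpose is involved -- the plain NN case is now left unchecked.
def check_gemm_dims(A0, A1, transa, B0, B1, transb):
    if transa or transb:
        inner_a = A1 if transa else A0
        inner_b = B0 if transb else B1
        if inner_a != inner_b:
            raise ValueError(
                f"Invalid matrix dimensions for GEMM (A=({A0},{A1}), "
                f"transa={transa}, B=({B0},{B1}), transb={transb})"
            )

check_gemm_dims(128, 64, True, 32, 64, False)   # TN: 64 == 64, check passes
check_gemm_dims(128, 64, False, 32, 64, False)  # NN: check now skipped entirely
```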