Commit 7640a8d4 authored by yuguo

[DCU] Fix Megatron MoE int8 training issues

parent d6c32078
@@ -198,9 +198,36 @@ def general_grouped_gemm(
     transa = layout[0] == "T"
     transb = layout[1] == "T"
+    empty_tensor = _empty_tensor()
+    empty_tensors = [empty_tensor] * num_gemms
+    # Use bfloat16 as default bias_dtype
+    gelu_input = empty_tensors
+    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
+    sm_count = get_sm_count()
+    if grad and use_bias:
+        grad_bias = [
+            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
+        ]
+    else:
+        grad_bias = empty_tensors
+    bias = bias if use_bias else empty_tensors
+    if use_bias:
+        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
+    else:
+        bias_dtype = TE_DType[torch.bfloat16]
+    if gelu:
+        gelu_input = [
+            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
+            for o in out
+        ]  # this should differ with respect to single output
     if int8_simulation_fp8 and (isinstance(A[0], Float8BlockwiseQTensorBase) or isinstance(B[0], Float8BlockwiseQTensorBase)):
         assert len(set(m_splits)) == 1, "Int8-simulation grouped GEMM currently only supports uniform token padding, as in batched GEMM."
         assert not gelu, "GELU not supported with int8 simulation groupgemm."
+        assert not use_bias, "Bias not supported with int8 simulation groupgemm."
if layout == "TN": if layout == "TN":
qx_data = [ qx_data = [
@@ -215,11 +242,11 @@ def general_grouped_gemm(
             num_gemms = len(A)
             seq_len = sum(m_splits) // num_gemms
-            out[0], _ = w8a8_block_int8_matmul_batched(
+            out[0] = w8a8_block_int8_matmul_batched(
                 qx_data, qw_data, ref_scales_x, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
                 output_dtype=out_dtype
             )
-            return out, None, None
+            return out, bias, gelu_input
         elif layout == "NN":
             qdout_data = [
@@ -234,11 +261,11 @@ def general_grouped_gemm(
             num_gemms = len(A)
             seq_len = sum(m_splits) // num_gemms
-            out[0], _ = w8a8_block_int8_matmul_batched(
+            out[0] = w8a8_block_int8_matmul_batched(
                 qdout_data, qw_data, ref_scales_dout, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
                 output_dtype=out_dtype
             )
-            return out, None, None
+            return out, bias, gelu_input
         elif layout == "NT":
             qdout_data = [
@@ -250,41 +277,15 @@ def general_grouped_gemm(
             ref_scales_dout = [b._columnwise_scale_inv for b in B]
             ref_scales_x = [a._columnwise_scale_inv for a in A]
-            out, _ = w8a8_block_int8_matmul_wgrad_batched(
-                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate [128, 128],
+            out = w8a8_block_int8_matmul_wgrad_batched(
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
                 output_dtype=out_dtype
             )
-            return out, None, None
+            return out, bias, gelu_input
         else:
             raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
-    empty_tensor = _empty_tensor()
-    empty_tensors = [empty_tensor] * num_gemms
-    # Use bfloat16 as default bias_dtype
-    gelu_input = empty_tensors
-    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
-    sm_count = get_sm_count()
-    if grad and use_bias:
-        grad_bias = [
-            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
-        ]
-    else:
-        grad_bias = empty_tensors
-    bias = bias if use_bias else empty_tensors
-    if use_bias:
-        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
-    else:
-        bias_dtype = TE_DType[torch.bfloat16]
-    if gelu:
-        gelu_input = [
-            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
-            for o in out
-        ]  # this should differ with respect to single output
     bias = tex.te_general_grouped_gemm(
         A,
         transa,
...
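The hunks above hoist the bias/gelu_input setup ahead of the int8-simulation branch so that branch can return `(out, bias, gelu_input)` like the regular path, instead of `(out, None, None)`, and fix the missing comma in the wgrad call (`accumulate [128, 128]` → `accumulate, [128, 128]`). A minimal, self-contained sketch of that return-contract idea, with illustrative names only (not the TE implementation):

```python
import torch

def grouped_gemm_sketch(A, B, use_bias=False, gelu=False):
    """Toy grouped GEMM: one matmul per (A[i], B[i]) pair."""
    num_gemms = len(A)
    # Placeholder tensors stand in for TE's _empty_tensor() list.
    empty_tensors = [torch.Tensor()] * num_gemms
    bias = [torch.zeros(b.shape[1]) for b in B] if use_bias else empty_tensors
    gelu_input = empty_tensors
    out = [a @ b for a, b in zip(A, B)]
    if use_bias:
        out = [o + bb for o, bb in zip(out, bias)]
    if gelu:
        gelu_input = [o.clone() for o in out]  # save pre-activation for backward
        out = [torch.nn.functional.gelu(o) for o in out]
    # Every path returns the same triple, so callers can index
    # bias[i] / gelu_input[i] without None checks.
    return out, bias, gelu_input

A = [torch.randn(4, 8) for _ in range(2)]
B = [torch.randn(8, 16) for _ in range(2)]
out, bias, gelu_input = grouped_gemm_sketch(A, B)
assert len(out) == len(bias) == len(gelu_input) == 2
```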
@@ -524,13 +524,18 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
     weight_b = [weight.clone().contiguous() for i in range(batch)]
     weight_scale_b = [weight_scale.clone().contiguous() for i in range(batch)]
     # print(f"zhenggf, q_input_b:{q_input_b.shape}, x_scale_b:{x_scale_b.shape}, weight_b:{weight_b.shape}, weight_scale_b:{weight_scale_b.shape}")
+    N, K = weight.shape
+    C_shape = q_input.shape[:-1] + (N,)
+    output = [q_input.new_empty(C_shape, dtype=out_dtype) for i in range(batch)]
+    output = torch.stack(output).contiguous()
     torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
+    torch_output = torch_output.view(-1, torch_output.size(-1))
     # print(f"zhenggf, torch_output:{torch_output.shape}")
     x_scale_b = [xs.permute(1, 0).contiguous() for xs in x_scale_b]
     output = w8a8_block_int8_matmul_batched(
-        q_input_b, weight_b, x_scale_b, weight_scale_b, block_size,
+        q_input_b, weight_b, x_scale_b, weight_scale_b, output.view(batch, *C_shape), block_size,
         output_dtype=out_dtype,
         best_config=best_config
     )
...
@@ -568,6 +568,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
         output_dtype=out_dtype,
         best_config=best_config
     )
+    output = torch.stack(output).contiguous()
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("triton accuracy check failed!!!")
...
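The test-helper hunks preallocate a stacked output buffer, hand it to the batched kernel, and flatten both results before the `allclose` check. A rough sketch of that shape handling, with plain `torch.matmul` standing in for the w8a8 kernels and assumed shapes:

```python
import torch

batch, m, k, n = 4, 32, 64, 16
x = torch.randn(batch, m, k)
w = torch.randn(n, k)  # weight stored as (N, K), as in the helper

# Preallocate the stacked output the kernel writes into, mirroring
# output = torch.stack([...]).contiguous() above.
output = x.new_empty(batch, m, n)
for i in range(batch):
    torch.matmul(x[i], w.t(), out=output[i])  # each batch writes into its slice

# Reference path, flattened the same way as torch_output above.
torch_output = torch.stack([x[i] @ w.t() for i in range(batch)])
torch_output = torch_output.view(-1, torch_output.size(-1))

assert torch.allclose(output.view(-1, n), torch_output, rtol=1e-2, atol=5e-2)
```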