Commit 7640a8d4 authored by yuguo

[DCU] fix megatron MOE int train issues
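
Summary of the change: on the `int8_simulation_fp8` path of `general_grouped_gemm`, the setup of `out_dtype`, `bias`, and `gelu_input` moves above the int8-simulation branch (the branch returns early, and now returns `(out, bias, gelu_input)` instead of `(out, None, None)`); the Triton wrappers `w8a8_block_int8_matmul_batched` / `w8a8_block_int8_matmul_wgrad_batched` take a preallocated output buffer and return a single tensor rather than a tuple; and a missing comma before `[128, 128]` in the wgrad call is fixed. A minimal sketch of the new calling convention, with a toy stand-in for the real Triton kernel (the stub body is illustrative only):

import torch

# Hedged sketch, not the repo's code: this stub mimics the post-fix contract of
# w8a8_block_int8_matmul_batched (preallocated output passed in, single tensor
# returned; the old call site unpacked `out, _`).
def w8a8_block_int8_matmul_batched_stub(qx, qw, sx, sw, out, block_size, output_dtype=None):
    for i, (x, w) in enumerate(zip(qx, qw)):
        # toy fp32 matmul standing in for the blockwise-int8 Triton kernel
        out[i] = (x.float() @ w.float().t()).to(out.dtype)
    return out

num_gemms, seq_len, k, n = 4, 16, 32, 64
qx = [torch.randn(seq_len, k) for _ in range(num_gemms)]
qw = [torch.randn(n, k) for _ in range(num_gemms)]
out = torch.empty(num_gemms, seq_len, n)
# old: out, _ = w8a8_block_int8_matmul_batched(...)  -- tuple unpack, now wrong
out = w8a8_block_int8_matmul_batched_stub(qx, qw, None, None, out, [128, 128])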

parent d6c32078
......@@ -198,9 +198,36 @@ def general_grouped_gemm(
    transa = layout[0] == "T"
    transb = layout[1] == "T"
+    empty_tensor = _empty_tensor()
+    empty_tensors = [empty_tensor] * num_gemms
+    # Use bfloat16 as default bias_dtype
+    gelu_input = empty_tensors
+    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
+    sm_count = get_sm_count()
+    if grad and use_bias:
+        grad_bias = [
+            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
+        ]
+    else:
+        grad_bias = empty_tensors
+    bias = bias if use_bias else empty_tensors
+    if use_bias:
+        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
+    else:
+        bias_dtype = TE_DType[torch.bfloat16]
+    if gelu:
+        gelu_input = [
+            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
+            for o in out
+        ]  # this should differ with respect to single output
    if int8_simulation_fp8 and (isinstance(A[0], Float8BlockwiseQTensorBase) or isinstance(B[0], Float8BlockwiseQTensorBase)):
        assert len(set(m_splits)) == 1, "Int8 simulation grouped GEMM only supports token padding identical to batched GEMM for now."
        assert not gelu, "GELU not supported with int8 simulation groupgemm."
        assert not use_bias, "Bias not supported with int8 simulation groupgemm."
        if layout == "TN":
            qx_data = [
......@@ -215,11 +242,11 @@ def general_grouped_gemm(
            num_gemms = len(A)
            seq_len = sum(m_splits) // num_gemms
-            out[0], _ = w8a8_block_int8_matmul_batched(
+            out[0] = w8a8_block_int8_matmul_batched(
                qx_data, qw_data, ref_scales_x, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
                output_dtype=out_dtype
            )
-            return out, None, None
+            return out, bias, gelu_input
        elif layout == "NN":
            qdout_data = [
......@@ -234,11 +261,11 @@ def general_grouped_gemm(
            num_gemms = len(A)
            seq_len = sum(m_splits) // num_gemms
-            out[0], _ = w8a8_block_int8_matmul_batched(
+            out[0] = w8a8_block_int8_matmul_batched(
                qdout_data, qw_data, ref_scales_dout, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
                output_dtype=out_dtype
            )
-            return out, None, None
+            return out, bias, gelu_input
        elif layout == "NT":
            qdout_data = [
......@@ -250,41 +277,15 @@ def general_grouped_gemm(
            ref_scales_dout = [b._columnwise_scale_inv for b in B]
            ref_scales_x = [a._columnwise_scale_inv for a in A]
-            out, _ = w8a8_block_int8_matmul_wgrad_batched(
-                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate [128, 128],
+            out = w8a8_block_int8_matmul_wgrad_batched(
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
                output_dtype=out_dtype
            )
-            return out, None, None
+            return out, bias, gelu_input
        else:
            raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
-    empty_tensor = _empty_tensor()
-    empty_tensors = [empty_tensor] * num_gemms
-    # Use bfloat16 as default bias_dtype
-    gelu_input = empty_tensors
-    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
-    sm_count = get_sm_count()
-    if grad and use_bias:
-        grad_bias = [
-            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
-        ]
-    else:
-        grad_bias = empty_tensors
-    bias = bias if use_bias else empty_tensors
-    if use_bias:
-        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
-    else:
-        bias_dtype = TE_DType[torch.bfloat16]
-    if gelu:
-        gelu_input = [
-            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
-            for o in out
-        ]  # this should differ with respect to single output
    bias = tex.te_general_grouped_gemm(
        A,
        transa,
......
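
Why the first hunk moves the setup block: in the old layout, `empty_tensor` / `out_dtype` / `bias` / `gelu_input` were defined below the `int8_simulation_fp8` branch, but that branch returns early and references them, so int8 training hit locals before their assignment. A toy sketch of the failure mode, with hypothetical names (not the repo's code):

import torch

def grouped_op(out, simulate_int8):
    # setup now runs before the early-return branch, as in this commit
    out_dtype = out.dtype
    gelu_input = torch.empty(0)
    if simulate_int8:
        # pre-fix, out_dtype/gelu_input were assigned only *after* this
        # branch, so reaching it raised UnboundLocalError
        return out.to(out_dtype), gelu_input
    return out, gelu_input

result, gelu_input = grouped_op(torch.ones(2, 2), simulate_int8=True)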
......@@ -524,13 +524,18 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
    weight_b = [weight.clone().contiguous() for i in range(batch)]
    weight_scale_b = [weight_scale.clone().contiguous() for i in range(batch)]
    # print(f"zhenggf, q_input_b:{q_input_b.shape}, x_scale_b:{x_scale_b.shape}, weight_b:{weight_b.shape}, weight_scale_b:{weight_scale_b.shape}")
+    N, K = weight.shape
+    C_shape = q_input.shape[:-1] + (N,)
+    output = [q_input.new_empty(C_shape, dtype=out_dtype) for i in range(batch)]
+    output = torch.stack(output).contiguous()
    torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
    torch_output = torch_output.view(-1, torch_output.size(-1))
    # print(f"zhenggf, torch_output:{torch_output.shape}")
+    x_scale_b = [xs.permute(1, 0).contiguous() for xs in x_scale_b]
    output = w8a8_block_int8_matmul_batched(
-        q_input_b, weight_b, x_scale_b, weight_scale_b, block_size,
+        q_input_b, weight_b, x_scale_b, weight_scale_b, output.view(batch, *C_shape), block_size,
        output_dtype=out_dtype,
        best_config=best_config
    )
......
......@@ -568,6 +568,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
        output_dtype=out_dtype,
        best_config=best_config
    )
+    output = torch.stack(output).contiguous()
    if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
        print("triton precision check failed!!!")
......
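
For context on what the helper validates: a blockwise W8A8 matmul dequantizes each [block_n, block_k] tile of the int8 weight with its own scale before accumulating, and the Triton result passes if `torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2)` holds. A self-contained sketch of such a reference (hypothetical function, not `native_w8a8_block_int8_matmul_batched`; activation scaling is simplified to per-row here, whereas the helper uses blockwise activation scales):

import torch

# Hedged sketch: dequantize the int8 weight tile by tile, then matmul in fp32.
def ref_w8a8_block_matmul(q_x, x_scale, q_w, w_scale, block_size):
    block_n, block_k = block_size
    w = q_w.float()
    n, k = w.shape
    for i in range(0, n, block_n):
        for j in range(0, k, block_k):
            # one scale per [block_n, block_k] weight tile
            w[i:i + block_n, j:j + block_k] *= w_scale[i // block_n, j // block_k]
    x = q_x.float() * x_scale  # simplified per-row activation scale
    return x @ w.t()

q_x = torch.randint(-128, 127, (8, 256), dtype=torch.int8)
q_w = torch.randint(-128, 127, (512, 256), dtype=torch.int8)
x_scale = torch.rand(8, 1)
w_scale = torch.rand(512 // 128, 256 // 128)
out = ref_w8a8_block_matmul(q_x, x_scale, q_w, w_scale, [128, 128])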