OpenDAS / TransformerEngine

Commit 7640a8d4
Authored Jun 20, 2025 by yuguo
Parent: d6c32078

[DCU] fix megatron MOE int train issues
3 changed files with 42 additions and 35 deletions:

    transformer_engine/pytorch/cpp_extensions/gemm.py                  +34 / -33
    transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py         +7 /  -2
    transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py   +1 /  -0
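For context, the kernels these files wrap appear to implement a blockwise W8A8 int8 GEMM: activations and weights are quantized to int8 with per-block scales (the [128, 128] arguments at the call sites below are the block size), partial products are accumulated per K-block, and each output tile is dequantized by the product of the matching activation and weight block scales. The following is a minimal PyTorch reference of that idea, written for this summary only; the function name, scale shapes, and NT layout are illustrative assumptions, not the Triton kernels changed in this commit.

import torch

def blockwise_int8_matmul_ref(qx, qw, x_scale, w_scale, block=128):
    # Hypothetical reference, not the repository kernel.
    # qx: (M, K) int8 activations; qw: (N, K) int8 weights (NT layout).
    # x_scale: (M, K // block) per-row, per-K-block activation scales.
    # w_scale: (N // block, K // block) per-block weight scales.
    M, K = qx.shape
    N = qw.shape[0]
    out = torch.zeros(M, N, dtype=torch.float32)
    for kb in range(K // block):
        ks = slice(kb * block, (kb + 1) * block)
        # Emulate int32 accumulation of the int8 x int8 partial product in fp32
        # (exact here, since each 128-wide partial sum stays far below 2**24).
        acc = qx[:, ks].float() @ qw[:, ks].float().t()
        for nb in range(N // block):
            ns = slice(nb * block, (nb + 1) * block)
            # Dequantize this (M, block) output tile with its block scales.
            out[:, ns] += acc[:, ns] * x_scale[:, kb:kb + 1] * w_scale[nb, kb]
    return out.to(torch.bfloat16)

The helper changes further down compare the batched Triton kernels against a pure-PyTorch reference of this kind (native_w8a8_block_int8_matmul_batched), with rtol=1e-2 and atol=5e-2.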
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -198,9 +198,36 @@ def general_grouped_gemm(
     transa = layout[0] == "T"
     transb = layout[1] == "T"

+    empty_tensor = _empty_tensor()
+    empty_tensors = [empty_tensor] * num_gemms
+
+    # Use bfloat16 as default bias_dtype
+    gelu_input = empty_tensors
+    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
+    sm_count = get_sm_count()
+
+    if grad and use_bias:
+        grad_bias = [
+            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda")
+            for i in range(num_gemms)
+        ]
+    else:
+        grad_bias = empty_tensors
+    bias = bias if use_bias else empty_tensors
+    if use_bias:
+        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
+    else:
+        bias_dtype = TE_DType[torch.bfloat16]
+
+    if gelu:
+        gelu_input = [
+            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
+            for o in out
+        ]
+        # this should differ with respect to single output
+
     if int8_simulation_fp8 and (isinstance(A[0], Float8BlockwiseQTensorBase) or isinstance(B[0], Float8BlockwiseQTensorBase)):
         assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
         assert not gelu, "GELU not supported with int8 simulation groupgemm."
+        assert not use_bias, "Bias not supported with int8 simulation groupgemm."
         if layout == "TN":
             qx_data = [
...
@@ -215,11 +242,11 @@ def general_grouped_gemm(
             num_gemms = len(A)
             seq_len = sum(m_splits) // num_gemms
-            out[0], _ = w8a8_block_int8_matmul_batched(
+            out[0] = w8a8_block_int8_matmul_batched(
                 qx_data, qw_data, ref_scales_x, ref_scales_w,
                 out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
                 output_dtype=out_dtype
             )
-            return out, None, None
+            return out, bias, gelu_input
         elif layout == "NN":
             qdout_data = [
...
@@ -234,11 +261,11 @@ def general_grouped_gemm(
             num_gemms = len(A)
             seq_len = sum(m_splits) // num_gemms
-            out[0], _ = w8a8_block_int8_matmul_batched(
+            out[0] = w8a8_block_int8_matmul_batched(
                 qdout_data, qw_data, ref_scales_dout, ref_scales_w,
                 out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
                 output_dtype=out_dtype
             )
-            return out, None, None
+            return out, bias, gelu_input
         elif layout == "NT":
             qdout_data = [
...
@@ -250,41 +277,15 @@ def general_grouped_gemm(
             ref_scales_dout = [b._columnwise_scale_inv for b in B]
             ref_scales_x = [a._columnwise_scale_inv for a in A]
-            out, _ = w8a8_block_int8_matmul_wgrad_batched(
-                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate [128, 128],
+            out = w8a8_block_int8_matmul_wgrad_batched(
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
                 output_dtype=out_dtype
             )
-            return out, None, None
+            return out, bias, gelu_input
         else:
             raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")

-    empty_tensor = _empty_tensor()
-    empty_tensors = [empty_tensor] * num_gemms
-
-    # Use bfloat16 as default bias_dtype
-    gelu_input = empty_tensors
-    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
-    sm_count = get_sm_count()
-
-    if grad and use_bias:
-        grad_bias = [
-            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda")
-            for i in range(num_gemms)
-        ]
-    else:
-        grad_bias = empty_tensors
-    bias = bias if use_bias else empty_tensors
-    if use_bias:
-        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
-    else:
-        bias_dtype = TE_DType[torch.bfloat16]
-
-    if gelu:
-        gelu_input = [
-            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
-            for o in out
-        ]
-        # this should differ with respect to single output
-
     bias = tex.te_general_grouped_gemm(
         A,
         transa,
...
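Taken together, the gemm.py changes do three things: the bias/gelu_input setup is hoisted above the int8-simulation branch so that branch can return out, bias, gelu_input (the same tuple shape as the regular path) instead of out, None, None; the batched int8 matmuls now write into the caller-provided out[0] buffer through a (num_gemms, seq_len, hidden) view instead of rebinding a freshly allocated result; and a missing comma between accumulate and the [128, 128] block size is added in the wgrad call. A minimal sketch of the preallocated-output view pattern follows; the stub kernel and shapes are hypothetical, standing in for w8a8_block_int8_matmul_batched.

import torch

def batched_kernel_stub(x, out):
    # Stand-in for an in-place batched kernel: fills `out` with its result.
    out.copy_(x * 2.0)

num_gemms, seq_len, hidden = 4, 8, 16
out0 = torch.empty(num_gemms * seq_len, hidden)   # flat grouped output buffer
x = torch.randn(num_gemms, seq_len, hidden)

# The kernel sees a (num_gemms, seq_len, hidden) view; the flat buffer is filled
# in place, so the caller keeps one contiguous grouped output with no extra copy.
batched_kernel_stub(x, out0.view(num_gemms, seq_len, out0.size(-1)))
assert torch.equal(out0.view(num_gemms, seq_len, hidden), x * 2.0)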
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py

@@ -524,13 +524,18 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
     weight_b = [weight.clone().contiguous() for i in range(batch)]
     weight_scale_b = [weight_scale.clone().contiguous() for i in range(batch)]
     # print(f"zhenggf, q_input_b:{q_input_b.shape}, x_scale_b:{x_scale_b.shape}, weight_b:{weight_b.shape}, weight_scale_b:{weight_scale_b.shape}")
+    N, K = weight.shape
+    C_shape = q_input.shape[:-1] + (N,)
+    output = [q_input.new_empty(C_shape, dtype=out_dtype) for i in range(batch)]
+    output = torch.stack(output).contiguous()
     torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
+    torch_output = torch_output.view(-1, torch_output.size(-1))
     # print(f"zhenggf, torch_output:{torch_output.shape}")
     x_scale_b = [xs.permute(1, 0).contiguous() for xs in x_scale_b]
     output = w8a8_block_int8_matmul_batched(
-        q_input_b, weight_b, x_scale_b, weight_scale_b, block_size,
+        q_input_b, weight_b, x_scale_b, weight_scale_b, output.view(batch, *C_shape), block_size,
         output_dtype=out_dtype,
         best_config=best_config
     )
...
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py

@@ -568,6 +568,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
         output_dtype=out_dtype,
         best_config=best_config
     )
+    output = torch.stack(output).contiguous()
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("triton 精度检查不合格!!!")  # translation: "Triton accuracy check failed!!!"
...
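Both Triton helper changes follow the same validation pattern: stack the per-batch kernel outputs into one contiguous tensor and compare against the plain-PyTorch reference with tolerances loose enough for int8 quantization error (rtol=1e-2, atol=5e-2). Below is a minimal sketch of that check with dummy tensors standing in for the kernel and reference outputs; the helper name is illustrative, not the repository's.

import torch

def check_against_reference(kernel_outputs, reference, rtol=1e-2, atol=5e-2):
    # kernel_outputs: list of per-batch results from the custom kernel.
    # reference:      result computed with plain PyTorch ops.
    output = torch.stack(kernel_outputs).contiguous()
    if not torch.allclose(output, reference.view_as(output), rtol=rtol, atol=atol):
        print("Triton accuracy check failed!")
        return False
    return True

reference = torch.randn(2, 4, 8)
kernel_outputs = [reference[i] + 1e-3 * torch.randn(4, 8) for i in range(2)]
assert check_against_reference(kernel_outputs, reference)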