OpenDAS / TransformerEngine · Commits

Commit 251dcc7e, authored Jun 20, 2025 by yuguo
Parent: 7640a8d4

[DCU] fix megatron MOE int8 train bugs

Showing 3 changed files with 118 additions and 18 deletions
transformer_engine/pytorch/cpp_extensions/gemm.py                   +2  -2
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py         +52 -8
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py   +64 -8
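In short, the fix has two parts: general_grouped_gemm now routes the grouped INT8 wgrad GEMM through a new per-expert wrapper, w8a8_block_int8_matmul_wgrad_batched_native, instead of the stacked batched kernel; and both blockwise INT8 Triton modules replace their fixed launch config with an M-dependent default that is used whenever no autotuned best_config is supplied.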
transformer_engine/pytorch/cpp_extensions/gemm.py (view file @ 251dcc7e)

@@ -11,7 +11,7 @@ import transformer_engine_torch as tex
 from ..constants import TE_DType
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
+from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched, w8a8_block_int8_matmul_wgrad_batched_native
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...debug.pytorch.debug_quantization import DebugQuantizer

@@ -277,7 +277,7 @@ def general_grouped_gemm(
         ref_scales_dout = [b._columnwise_scale_inv for b in B]
         ref_scales_x = [a._columnwise_scale_inv for a in A]
-        out = w8a8_block_int8_matmul_wgrad_batched(
+        out = w8a8_block_int8_matmul_wgrad_batched_native(
             qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128], output_dtype=out_dtype
         )
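For orientation, here is a minimal pure-PyTorch reference, my own sketch rather than code from this repository, of what a block-wise W8A8 INT8 NT GEMM computes: each int8 operand is dequantized with its per-block inverse scales and the result is an ordinary matmul, C = (A * As) @ (B * Bs)^T. The scale layouts below are assumptions.

import torch

def blockwise_int8_matmul_ref(A, B, As, Bs, block_size, output_dtype=torch.float16):
    # Sketch only. Assumed layouts: A is [M, K] int8 with per-row, per-K-block
    # inverse scales As [M, K // block_k]; B is [N, K] int8 (NT layout) with 2-D
    # per-block inverse scales Bs [N // block_n, K // block_k]. K and N are
    # assumed divisible by the block shape, as blockwise quantization requires.
    block_n, block_k = block_size
    A_deq = A.float() * As.repeat_interleave(block_k, dim=1)
    B_scales = Bs.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
    B_deq = B.float() * B_scales
    # NT layout: B is stored row-major as [N, K], so the product is A @ B.T -> [M, N].
    return (A_deq @ B_deq.T).to(output_dtype)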
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py (view file @ 251dcc7e)

@@ -462,14 +462,58 @@ def w8a8_block_int8_matmul_batched(
     assert C.size(-1) == N

-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": block_n,
-        "BLOCK_SIZE_K": block_k,
-        "GROUP_SIZE_M": 8,
-        "num_warps": 4,
-        "num_stages": 1,
-    }
     if best_config:
         config = best_config
+    else:
+        # print("best config has not found!")
+        # config = {
+        #     "BLOCK_SIZE_M": 32,  # 64
+        #     "BLOCK_SIZE_N": block_size[0],
+        #     "BLOCK_SIZE_K": block_size[1],
+        #     "GROUP_SIZE_M": 32,
+        #     "num_warps": 4,
+        #     "num_stages": 3,
+        # }
+        # Default config
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
+        # print("block_size[0]:{},block_size[1]:{}".format(block_size[0], block_size[1]))
+        if M <= 64:
+            config = {
+                "BLOCK_SIZE_M": 16,  # 64
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        elif M < 128:
+            config = {
+                "BLOCK_SIZE_M": 32,  # 64
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        elif M <= 256:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        else:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 8,
+                "num_warps": 8,
+                "num_stages": 0,
+            }

     def grid(META):
         return (
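The new default-config branch is a size heuristic: when no autotuned best_config is supplied, small-M GEMMs get a small BLOCK_SIZE_M so the launch still produces enough thread blocks to fill the device, and only large-M GEMMs use the wide 64-row tile with more warps. Restated as a standalone helper (my own refactoring of the logic above, not code from the commit):

def pick_default_int8_gemm_config(M, block_size):
    # block_size = [block_n, block_k] is the quantization block shape; the diff's
    # own comment notes BLOCK_SIZE_K must be divisible by block_size[1], so the
    # tile dimensions reuse the block shape directly.
    if M <= 64:
        block_m, group_m, num_warps = 16, 2, 4
    elif M < 128:
        block_m, group_m, num_warps = 32, 2, 4
    elif M <= 256:
        block_m, group_m, num_warps = 64, 2, 4
    else:
        block_m, group_m, num_warps = 64, 8, 8
    return {
        "BLOCK_SIZE_M": block_m,
        "BLOCK_SIZE_N": block_size[0],
        "BLOCK_SIZE_K": block_size[1],
        "GROUP_SIZE_M": group_m,
        "num_warps": num_warps,
        "num_stages": 0,
    }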
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py (view file @ 251dcc7e)

@@ -456,6 +456,18 @@ def w8a8_block_int8_matmul_wgrad(
     return C, config


+def w8a8_block_int8_matmul_wgrad_batched_native(
+    A_list, B_list, As_list, Bs_list, C_list, accumulate, block_size, output_dtype=torch.float16, best_config=None
+):
+    for i in range(len(C_list)):
+        assert C_list[i] is not None
+        C_list[i], config = w8a8_block_int8_matmul_wgrad(
+            A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i],
+            accumulate, block_size, output_dtype=output_dtype, best_config=best_config,
+        )
+    return C_list
+
+
 def w8a8_block_int8_matmul_wgrad_batched(
     A_list, B_list, As_list, Bs_list, C_list, accumulate,

@@ -487,14 +499,58 @@ def w8a8_block_int8_matmul_wgrad_batched(
     else:
         C = torch.stack(C_list).contiguous()

-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": block_n,
-        "BLOCK_SIZE_K": block_k,
-        "GROUP_SIZE_M": 8,
-        "num_warps": 4,
-        "num_stages": 1,
-    }
     if best_config:
         config = best_config
+    else:
+        # print("best config has not found!")
+        # config = {
+        #     "BLOCK_SIZE_M": 32,  # 64
+        #     "BLOCK_SIZE_N": block_size[0],
+        #     "BLOCK_SIZE_K": block_size[1],
+        #     "GROUP_SIZE_M": 32,
+        #     "num_warps": 4,
+        #     "num_stages": 3,
+        # }
+        # Default config
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
+        # print("block_size[0]:{},block_size[1]:{}".format(block_size[0], block_size[1]))
+        if M <= 64:
+            config = {
+                "BLOCK_SIZE_M": 16,  # 64
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        elif M < 128:
+            config = {
+                "BLOCK_SIZE_M": 32,  # 64
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        elif M <= 256:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        else:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 8,
+                "num_warps": 8,
+                "num_stages": 0,
+            }

     def grid(META):
         return (
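The _native wrapper side-steps torch.stack entirely: it loops over the experts and launches the single-matrix wgrad kernel once per expert, writing into the preallocated per-expert outputs rather than one stacked tensor. A hedged usage sketch follows; the shapes, scale layouts, dtypes, and device are illustrative assumptions, not taken from the repository.

import torch
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import (
    w8a8_block_int8_matmul_wgrad_batched_native,
)

experts, M, N, K = 4, 512, 4096, 1024   # assumed per-expert GEMM shapes
block = [128, 128]                      # quantization block shape used in gemm.py
dev = "cuda"

# Assumed layouts: int8 operands plus per-block inverse scales (standing in for
# the _columnwise_scale_inv lists that general_grouped_gemm passes in).
A = [torch.randint(-127, 127, (M, K), dtype=torch.int8, device=dev) for _ in range(experts)]
B = [torch.randint(-127, 127, (N, K), dtype=torch.int8, device=dev) for _ in range(experts)]
As = [torch.rand(M, K // block[1], device=dev) for _ in range(experts)]
Bs = [torch.rand(N // block[0], K // block[1], device=dev) for _ in range(experts)]
C = [torch.empty(M, N, dtype=torch.bfloat16, device=dev) for _ in range(experts)]

C = w8a8_block_int8_matmul_wgrad_batched_native(
    A, B, As, Bs, C, accumulate=False, block_size=block, output_dtype=torch.bfloat16
)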