Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
5fcf30ba
Commit
5fcf30ba
authored
Jul 09, 2025
by
wenjh
Browse files
Fix int8 gemm nt and wgrad
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
9fe13a33
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
4 deletions
+4
-4
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py
+2
-2
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py
...mer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py
+2
-2
No files found.
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py
View file @
5fcf30ba
...
@@ -50,8 +50,8 @@ def get_full_tuning_space():
...
@@ -50,8 +50,8 @@ def get_full_tuning_space():
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
get_full_tuning_space
()
if
tuning_full_space
else
[
configs
=
get_full_tuning_space
()
if
tuning_full_space
else
[
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 2, 'kpack':2}, num_stages=2, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 2, 'kpack':2}, num_stages=2, num_warps=8),
triton
.
Config
({
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
32
,
'GROUP_SIZE_M'
:
2
,},
num_stages
=
1
,
num_warps
=
4
,
enable_mmacfuse
=
2
),
#
triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 2,}, num_stages=1, num_warps=4, enable_mmacfuse=2),
triton
.
Config
({
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
128
,
'BLOCK_SIZE_K'
:
128
,
'GROUP_SIZE_M'
:
8
,},
num_stages
=
1
,
num_warps
=
4
,
enable_mmacfuse
=
2
),
triton
.
Config
({
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
blockwise_fp8_block_len
,
'BLOCK_SIZE_K'
:
blockwise_fp8_block_len
,
'GROUP_SIZE_M'
:
8
,},
num_stages
=
1
,
num_warps
=
4
,
enable_mmacfuse
=
2
),
],
],
key
=
[
'M'
,
'N'
,
'K'
],
key
=
[
'M'
,
'N'
,
'K'
],
# reset_to_zero=['c_ptr']
# reset_to_zero=['c_ptr']
...
...
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py
View file @
5fcf30ba
...
@@ -50,8 +50,8 @@ def get_full_tuning_space():
...
@@ -50,8 +50,8 @@ def get_full_tuning_space():
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
get_full_tuning_space
()
if
tuning_full_space
else
[
configs
=
get_full_tuning_space
()
if
tuning_full_space
else
[
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 2, 'kpack':2}, num_stages=2, num_warps=8),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 2, 'kpack':2}, num_stages=2, num_warps=8),
triton
.
Config
({
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
32
,
'GROUP_SIZE_M'
:
2
,},
num_stages
=
1
,
num_warps
=
4
,
enable_mmacfuse
=
2
),
#
triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 2,}, num_stages=1, num_warps=4, enable_mmacfuse=2),
triton
.
Config
({
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
128
,
'BLOCK_SIZE_K'
:
128
,
'GROUP_SIZE_M'
:
8
,},
num_stages
=
1
,
num_warps
=
4
,
enable_mmacfuse
=
2
),
triton
.
Config
({
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
blockwise_fp8_block_len
,
'BLOCK_SIZE_K'
:
blockwise_fp8_block_len
,
'GROUP_SIZE_M'
:
8
,},
num_stages
=
1
,
num_warps
=
4
,
enable_mmacfuse
=
2
),
],
],
key
=
[
'M'
,
'N'
,
'K'
],
key
=
[
'M'
,
'N'
,
'K'
],
# reset_to_zero=['c_ptr']
# reset_to_zero=['c_ptr']
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment