OpenDAS / TransformerEngine · Commits

Commit 0e886dab, authored Jul 01, 2025 by wenjh

Merge develop_v2.4

Signed-off-by: wenjh <wenjh@sugon.com>
Parents: e56de127, b944277c

Changes: 23 in the merge; this page shows 3 changed files with 13 additions and 10 deletions (+13 -10).
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py        +4 -3
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py  +4 -3
transformer_engine/pytorch/triton/per_token_group_quant.py         +5 -4
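All three diffs below make the same substitution: the hard-coded [128, 128] block-size defaults become [blockwise_fp8_block_len, blockwise_fp8_block_len], imported from transformer_engine.pytorch.fp8, so the INT8 GEMM helpers and benchmarks follow one shared block-length setting. The definition of blockwise_fp8_block_len is not part of this commit; below is a hypothetical sketch of what such a module-level constant could look like, assuming a plain int overridable via an environment variable (the variable name is invented for illustration):

```python
# Hypothetical sketch only -- the real definition lives in
# transformer_engine/pytorch/fp8.py and is not shown in this commit.
import os

# One block length shared by every blockwise-quantized GEMM path;
# 128 would preserve the previous hard-coded [128, 128] defaults.
blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
```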
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py (view file @ 0e886dab)

@@ -8,6 +8,7 @@ import triton
 import triton.language as tl
 import pandas as pd
 from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_helper
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging

@@ -557,7 +558,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int] = [128, 128],
+block_size: List[int] = [blockwise_fp8_block_len, blockwise_fp8_block_len],
 bias: Optional[torch.Tensor] = None,
 best_config: Optional[dict] = None):

@@ -596,7 +597,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int] = [128, 128],
+block_size: List[int] = [blockwise_fp8_block_len, blockwise_fp8_block_len],
 bias: Optional[torch.Tensor] = None,
 best_config: Optional[dict] = None):

@@ -771,7 +772,7 @@ def main():
 n_list = [7168]
 k_list = [1152]
-block_size = [128, 128]
+block_size = [blockwise_fp8_block_len, blockwise_fp8_block_len]
 out_dtype = torch.bfloat16
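One behavioral detail of the new defaults worth noting (an observation about Python semantics, not something stated in the commit): a default argument like [blockwise_fp8_block_len, blockwise_fp8_block_len] is evaluated once, at function-definition time, so it captures the constant's value at import and reuses a single list object across calls:

```python
# Illustration of Python default-argument semantics; f() stands in for the
# changed helper signatures and is not code from this repository.
def f(block_size=[128, 128]):
    return block_size

a = f()
b = f()
assert a is b                 # one shared list object for every default call
assert f([64, 64]) is not a   # explicitly passed arguments are unaffected
```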
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py (view file @ 0e886dab)

@@ -8,6 +8,7 @@ import triton
 import triton.language as tl
 import pandas as pd
 from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_helper_b
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging

@@ -596,7 +597,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int] = [128, 128],
+block_size: List[int] = [blockwise_fp8_block_len, blockwise_fp8_block_len],
 bias: Optional[torch.Tensor] = None,
 best_config: Optional[dict] = None):

@@ -639,7 +640,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int] = [128, 128],
+block_size: List[int] = [blockwise_fp8_block_len, blockwise_fp8_block_len],
 bias: Optional[torch.Tensor] = None,
 best_config: Optional[dict] = None):

@@ -821,7 +822,7 @@ def main():
 n_list = [7168]
 k_list = [1152]
-block_size = [128, 128]
+block_size = [blockwise_fp8_block_len, blockwise_fp8_block_len]
 out_dtype = torch.bfloat16
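For context on what the block_size parameter controls, here is a minimal reference sketch of blockwise INT8 weight quantization, where each block x block tile of a 2-D tensor gets one symmetric scale. This is an illustrative assumption about the general scheme behind these Triton kernels, not the repository's kernel, and it assumes the tensor dimensions divide evenly by the block length:

```python
import torch

def quantize_blockwise_int8(w: torch.Tensor, block: int = 128):
    """Quantize a 2-D tensor with one symmetric INT8 scale per block x block tile."""
    n, k = w.shape
    assert n % block == 0 and k % block == 0, "sketch assumes even tiling"
    scales = torch.empty(n // block, k // block, dtype=w.dtype)
    q = torch.empty_like(w, dtype=torch.int8)
    for i in range(0, n, block):
        for j in range(0, k, block):
            tile = w[i:i + block, j:j + block]
            s = tile.abs().amax().clamp(min=1e-8) / 127.0  # per-tile scale
            scales[i // block, j // block] = s
            q[i:i + block, j:j + block] = torch.round(tile / s).clamp(-127, 127).to(torch.int8)
    return q, scales
```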
transformer_engine/pytorch/triton/per_token_group_quant.py (view file @ 0e886dab)

@@ -9,6 +9,7 @@ import triton.language as tl
 import pandas as pd
 import logging
 import math
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 def to_int8(tensor: torch.Tensor):

@@ -118,7 +119,7 @@ def _int8_gemm_helper(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int] = [128, 128],
+block_size: List[int] = [blockwise_fp8_block_len, blockwise_fp8_block_len],
 best_config: Optional[list] = None):
 # Test for a cutlass kernel with per-token activation quantization
 # and per-output channel weight quantization.

@@ -143,7 +144,7 @@ def _int8_gemm_helper_b(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int] = [128, 128],
+block_size: List[int] = [blockwise_fp8_block_len, blockwise_fp8_block_len],
 best_config: Optional[list] = None):
 # Test for a cutlass kernel with per-token activation quantization
 # and per-output channel weight quantization.

@@ -168,7 +169,7 @@ def _int8_gemm_helper_test(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int] = [128, 128],
+block_size: List[int] = [blockwise_fp8_block_len, blockwise_fp8_block_len],
 best_config: Optional[list] = None):
 # Test for a cutlass kernel with per-token activation quantization
 # and per-output channel weight quantization.

@@ -197,7 +198,7 @@ def main():
 m_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
 n_list = [576, 2048, 7168, 256, 7168, 1536, 1536]
 k_list = [7168, 512, 1024, 7168, 128, 7168, 1536]
-block_size = [128, 128]
+block_size = [blockwise_fp8_block_len, blockwise_fp8_block_len]
 out_dtype = torch.bfloat16
 _n = []
 _k = []
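Similarly, a compact sketch of what per-token group quantization typically means for activations: each row (token) is split into groups of group_len contiguous elements along the hidden dimension, with one INT8 scale per group. Again this is a hedged illustration of the general technique, not the Triton kernel in this file:

```python
import torch

def per_token_group_quant_int8(x: torch.Tensor, group_len: int = 128):
    """One symmetric INT8 scale per contiguous group of group_len elements per row."""
    m, k = x.shape
    assert k % group_len == 0, "sketch assumes the hidden dim divides evenly"
    g = x.reshape(m, k // group_len, group_len)
    scales = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(g / scales).clamp(-127, 127).to(torch.int8)
    return q.reshape(m, k), scales.squeeze(-1)  # int8 data plus per-group scales
```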