OpenDAS / TransformerEngine · Commits

Commit 1036ccfe, authored Jul 17, 2025 by wenjh (parent 00738a42)

    Use lightop to replace w8a8_matmul_extension

    Signed-off-by: wenjh <wenjh@sugon.com>

Showing 4 changed files with 23 additions and 32 deletions (+23 / -32)
tests/pytorch/test_int8_blockwise_gemm_exact.py                     +7  -10
transformer_engine/pytorch/cpp_extensions/gemm.py                   +7  -10
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py         +3   -4
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py   +6   -8
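Every hunk below makes the same substitution: keyword-style calls into w8a8_matmul_extension (w8a8_block_int8_matmul / w8a8_block_int8_matmul_wgrad, with output_dtype= passed by keyword) become positional calls into lightop (gemm_w8a8_asm / gemm_w8a8_wgrad_asm, with the dtype as the last positional argument). Both packages are out-of-tree; the sketch below only illustrates the call-site change and is not code from this commit. The fallback wrapper is hypothetical; only the two backend signatures are taken from the diff.

import torch

try:
    import lightop                      # new backend adopted by this commit
    _HAVE_LIGHTOP = True
except ImportError:                     # fall back to the old extension
    _HAVE_LIGHTOP = False

def blockwise_int8_gemm(qx, qw, scales_x, scales_w, block_len, out_dtype=torch.bfloat16):
    # Blockwise INT8 GEMM through whichever backend is installed.
    if _HAVE_LIGHTOP:
        # New API: output dtype is the last positional argument.
        return lightop.gemm_w8a8_asm(qx, qw, scales_x, scales_w,
                                     [block_len, block_len], out_dtype)
    import w8a8_matmul_extension
    # Old API: output dtype was a keyword argument.
    return w8a8_matmul_extension.w8a8_block_int8_matmul(
        qx, qw, scales_x, scales_w, [block_len, block_len],
        output_dtype=out_dtype)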
tests/pytorch/test_int8_blockwise_gemm_exact.py

@@ -2,7 +2,7 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
-import w8a8_matmul_extension
+import lightop
 from transformer_engine.pytorch import get_device_compute_capability
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
@@ -203,9 +203,8 @@ def cublas_gemm_fp8_blockwise_case_fw(
             output_dtype=out_dtype
         )
     else:
-        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
-            qx_data, qw_data, ref_scales_x, ref_scales_w,
-            [block_len, block_len], output_dtype=out_dtype
+        y = lightop.gemm_w8a8_asm(
+            qx_data, qw_data, ref_scales_x, ref_scales_w,
+            [block_len, block_len], out_dtype
         )
     # print("int8 gemm output: ", y)
@@ -387,9 +386,8 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
             output_dtype=dx_dtype
         )
     else:
-        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
-            qdout_data, qw_data, ref_scales_dout, ref_scales_w,
-            [block_len, block_len], output_dtype=dx_dtype
+        y = lightop.gemm_w8a8_asm(
+            qdout_data, qw_data, ref_scales_dout, ref_scales_w,
+            [block_len, block_len], dx_dtype
         )
     # print("int8 gemm dx: ", y)
@@ -573,10 +571,9 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
             output_dtype=dw_dtype
         )
     else:
-        y = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
+        y = lightop.gemm_w8a8_wgrad_asm(
             qdout_data, qx_data, ref_scales_dout, ref_scales_x,
             dw.clone() if accumulate else None,
-            accumulate, [block_len, block_len], output_dtype=dw_dtype
+            accumulate, [block_len, block_len], dw_dtype
         )
     # print("int8 gemm dw: ", y)
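The test drives these kernels with pre-quantized data (qx_data, qw_data) and per-block scales (ref_scales_x, ref_scales_w). For readers without either extension installed, here is a minimal, self-contained sketch of how such inputs are typically produced; the exact scale layout is an assumption (per-(1, block_len) groups along the reduction dim), not taken from the repo.

import torch

def quantize_blockwise_int8(x: torch.Tensor, block_len: int = 128):
    # One symmetric INT8 scale per (1, block_len) group along the last dim.
    # Illustrative reference only; not this repo's quantizer.
    m, k = x.shape
    assert k % block_len == 0
    xb = x.float().reshape(m, k // block_len, block_len)
    scales = xb.abs().amax(dim=-1).clamp(min=1e-8) / 127.0   # (m, k // block_len)
    q = torch.round(xb / scales.unsqueeze(-1)).clamp(-127, 127).to(torch.int8)
    return q.reshape(m, k), scales

qx_data, ref_scales_x = quantize_blockwise_int8(torch.randn(256, 512), block_len=128)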
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -8,7 +8,7 @@ from typing import Iterable, Optional, Tuple, Union, List
 import os
 import torch
 import transformer_engine_torch as tex
-import w8a8_matmul_extension
+import lightop
 from ..constants import TE_DType, TE_DType_To_Torch
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
@@ -92,9 +92,8 @@ def general_gemm(
             output_dtype=out_dtype
         )
     else:
-        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
-            qx_data, qw_data, ref_scales_x, ref_scales_w,
-            [blockwise_fp8_block_len, blockwise_fp8_block_len], output_dtype=out_dtype
+        y = lightop.gemm_w8a8_asm(
+            qx_data, qw_data, ref_scales_x, ref_scales_w,
+            [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
         )
     return y, None, None, None
@@ -113,9 +112,8 @@ def general_gemm(
             output_dtype=out_dtype
         )
     else:
-        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
-            qdout_data, qw_data, ref_scales_dout, ref_scales_w,
-            [blockwise_fp8_block_len, blockwise_fp8_block_len], output_dtype=out_dtype
+        y = lightop.gemm_w8a8_asm(
+            qdout_data, qw_data, ref_scales_dout, ref_scales_w,
+            [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
         )
     return y, None, None, None
@@ -134,9 +132,8 @@ def general_gemm(
             output_dtype=out_dtype
         )
     else:
-        out = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
-            qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate,
-            [blockwise_fp8_block_len, blockwise_fp8_block_len], output_dtype=out_dtype
+        out = lightop.gemm_w8a8_wgrad_asm(
+            qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate,
+            [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
         )
     return out, None, None, None
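In the wgrad branches, the kernel receives an optional output buffer and an accumulate flag (the test above passes dw.clone() if accumulate else None; general_gemm passes out, accumulate). A pure-PyTorch reference of what that flag means, under the assumption that wgrad computes dW = dY^T @ X as in a standard linear-layer backward; scale handling is omitted and the function name is illustrative.

import torch

def wgrad_reference(dout, x, dw_prev=None, accumulate=False, out_dtype=torch.float32):
    # dW = dY^T @ X, optionally accumulated into a prior gradient buffer.
    # Reference semantics only; the real kernel runs on INT8 data with
    # blockwise scales and an asm inner loop.
    dw = (dout.float().T @ x.float()).to(out_dtype)
    if accumulate:
        assert dw_prev is not None, "accumulate=True requires a buffer"
        dw = dw + dw_prev.to(out_dtype)
    return dw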
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py

@@ -11,7 +11,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-import w8a8_matmul_extension
+import lightop
 from transformer_engine.pytorch.utils import get_device_compute_capability

 logger = logging.getLogger(__name__)
@@ -617,9 +617,8 @@ def apply_w8a8_block_int8_linear_helper(m: int,
             best_config=best_config
         )
     else:
-        output = w8a8_matmul_extension.w8a8_block_int8_matmul(
-            q_input, weight, x_scale, weight_scale,
-            block_size, output_dtype=out_dtype
+        output = lightop.gemm_w8a8_asm(
+            q_input, weight, x_scale, weight_scale, block_size, out_dtype
         )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
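Both Triton helper files verify the replacement kernel against a torch reference with torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2), tolerances loose enough to absorb INT8 rounding error. A hedged dequantize-then-matmul reference for the NT layout, assuming per-(1, block) activation scales and per-(block, block) weight scales; the repo's actual scale layout may differ.

import torch

def reference_nt_int8_gemm(q_x, q_w, s_x, s_w, block, out_dtype=torch.float16):
    # Assumed shapes: q_x (M, K) int8 with s_x (M, K//block);
    # q_w (N, K) int8 with s_w (N//block, K//block). NT layout: out = x @ w^T.
    x = q_x.float() * s_x.repeat_interleave(block, dim=1)
    w = q_w.float() * s_w.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return (x @ w.T).to(out_dtype)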
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py

@@ -11,7 +11,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-import w8a8_matmul_extension
+import lightop
 from transformer_engine.pytorch.utils import get_device_compute_capability

 logger = logging.getLogger(__name__)
@@ -471,9 +471,8 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
             best_config=best_config
         )
     else:
-        C_list[i] = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
-            A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i],
-            accumulate, block_size, output_dtype=output_dtype
+        C_list[i] = lightop.gemm_w8a8_wgrad_asm(
+            A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i],
+            accumulate, block_size, output_dtype
         )
     return C_list
@@ -671,9 +670,8 @@ def apply_w8a8_block_int8_linear_helper(m: int,
             best_config=best_config
         )
     else:
-        output = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
-            q_input, weight, x_scale, weight_scale,
-            output, False, block_size, output_dtype=out_dtype
+        output = lightop.gemm_w8a8_wgrad_asm(
+            q_input, weight, x_scale, weight_scale, output, False, block_size, out_dtype
         )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
@@ -836,7 +834,7 @@ def main():
     block_size = [blockwise_fp8_block_len, blockwise_fp8_block_len]
-    out_dtype = torch.bfloat16
+    out_dtype = torch.float16
     _n = []
     _k = []
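The last hunk also flips the benchmark's default output dtype from torch.bfloat16 to torch.float16, presumably to match what the asm kernel supports on the target device. If portability matters, a runtime guard (illustrative, not part of this commit) can pick the dtype instead of hard-coding it:

import torch

# Prefer bf16 where the device supports it, otherwise fall back to fp16.
out_dtype = (torch.bfloat16
             if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
             else torch.float16)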