OpenDAS / TransformerEngine · Commits

Commit bb8cf71b, authored Jul 18, 2025 by wenjh
Merge branch 'develop_v2.4'
Parents: 429226fb, f5349823
Changes: 4 changed files, with 33 additions and 13 deletions (+33 −13)

  tests/pytorch/test_int8_blockwise_gemm_exact.py                    +9 −4
  transformer_engine/pytorch/cpp_extensions/gemm.py                  +9 −4
  transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py        +7 −2
  transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py  +8 −3
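All four files make the same two-part change: the hard dependency on lightop becomes optional, and each GEMM dispatch condition gains an "or not enable_lightop" clause so the Triton kernels also serve machines where lightop is absent. A minimal sketch of the shared pattern, with the flag name taken from the diff:

    # Optional-dependency guard adopted by every file in this commit:
    # remember whether the lightop accelerator imported successfully.
    try:
        import lightop
        enable_lightop = True
    except ImportError:
        enable_lightop = False

Keeping the flag at module scope means the import is attempted exactly once; call sites only test a boolean.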
tests/pytorch/test_int8_blockwise_gemm_exact.py

@@ -2,7 +2,12 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
-import lightop
+import warnings
+try:
+    import lightop
+    enable_lightop = True
+except ImportError:
+    enable_lightop = False
 from transformer_engine.pytorch import get_device_compute_capability
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)

@@ -197,7 +202,7 @@ def cublas_gemm_fp8_blockwise_case_fw(
     ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    if get_device_compute_capability() < (9, 3) or block_len != 128:
+    if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
         y, _ = w8a8_block_int8_matmul(
             qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
             output_dtype=out_dtype

@@ -380,7 +385,7 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
     ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    if get_device_compute_capability() < (9, 3) or block_len != 128:
+    if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
         y, _ = w8a8_block_int8_matmul(
             qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
             output_dtype=dx_dtype

@@ -564,7 +569,7 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
     # print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
     # print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
-    if get_device_compute_capability() < (9, 3) or block_len != 128:
+    if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
         y, _ = w8a8_block_int8_matmul_wgrad(
             qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
             accumulate, [block_len, block_len],
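The condition repeated in all three test cases decides between lightop and the Triton fallback. Factored out as a standalone predicate (a hypothetical helper for illustration, not part of the diff), the logic reads:

    def use_triton_fallback(compute_capability, block_len, enable_lightop):
        # Take the Triton path when the device is older than compute
        # capability (9, 3), the quantization block length is not 128,
        # or the optional lightop extension failed to import.
        return compute_capability < (9, 3) or block_len != 128 or not enable_lightop

    # The three independent ways the fallback can trigger:
    assert use_triton_fallback((9, 0), 128, True)    # older device
    assert use_triton_fallback((9, 3), 64, True)     # unsupported block length
    assert use_triton_fallback((9, 3), 128, False)   # lightop not installed
    assert not use_triton_fallback((9, 3), 128, True)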
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -8,7 +8,12 @@ from typing import Iterable, Optional, Tuple, Union, List
 import os
 import torch
 import transformer_engine_torch as tex
-import lightop
+import warnings
+try:
+    import lightop
+    enable_lightop = True
+except ImportError:
+    enable_lightop = False
 from ..constants import TE_DType, TE_DType_To_Torch
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched

@@ -86,7 +91,7 @@ def general_gemm(
         )
         ref_scales_x = B._rowwise_scale_inv
         ref_scales_w = A._rowwise_scale_inv
-        if get_device_compute_capability() < (9, 3) or block_len != 128:
+        if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
             y, _ = w8a8_block_int8_matmul(
                 qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype

@@ -106,7 +111,7 @@ def general_gemm(
         )
         ref_scales_dout = B._rowwise_scale_inv
         ref_scales_w = A._columnwise_scale_inv
-        if get_device_compute_capability() < (9, 3) or block_len != 128:
+        if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
             y, _ = w8a8_block_int8_matmul(
                 qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype

@@ -126,7 +131,7 @@ def general_gemm(
         )
         ref_scales_dout = B._columnwise_scale_inv
         ref_scales_x = A._columnwise_scale_inv
-        if get_device_compute_capability() < (9, 3) or block_len != 128:
+        if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
             out, _ = w8a8_block_int8_matmul_wgrad(
                 qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
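Each guarded file also adds "import warnings", though none of the visible hunks use it. One plausible intent (an assumption on our part, not shown anywhere in this commit) is a one-time notice when the fallback is forced:

    import warnings

    try:
        import lightop
        enable_lightop = True
    except ImportError:
        enable_lightop = False
        # Hypothetical: surface the otherwise-silent fallback at import time.
        warnings.warn(
            "lightop is not available; blockwise INT8 GEMMs will use the "
            "Triton kernels instead.",
            stacklevel=2,
        )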
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py

@@ -11,7 +11,12 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-import lightop
+import warnings
+try:
+    import lightop
+    enable_lightop = True
+except ImportError:
+    enable_lightop = False
 from transformer_engine.pytorch.utils import get_device_compute_capability

 logger = logging.getLogger(__name__)

@@ -610,7 +615,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
     x_scale = x_scale.permute(1, 0).contiguous()
-    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
         output, config = w8a8_block_int8_matmul(
             q_input, weight, x_scale, weight_scale, block_size,
             output_dtype=out_dtype,
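The helper in this file transposes x_scale with permute(1, 0).contiguous() before entering the Triton kernel. permute alone only swaps strides, while Triton kernels compute addresses from raw strides, so materializing a row-major copy matters. A small standalone illustration (the shapes are made up):

    import torch

    x_scale = torch.rand(4, 16)        # hypothetical (k_blocks, m_blocks) scale grid
    x_scale_t = x_scale.permute(1, 0)  # a view: shape (16, 4), strides swapped
    assert not x_scale_t.is_contiguous()

    x_scale_t = x_scale_t.contiguous() # copy into row-major memory
    assert x_scale_t.is_contiguous()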
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py

@@ -11,7 +11,12 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-import lightop
+import warnings
+try:
+    import lightop
+    enable_lightop = True
+except ImportError:
+    enable_lightop = False
 from transformer_engine.pytorch.utils import get_device_compute_capability

 logger = logging.getLogger(__name__)

@@ -464,7 +469,7 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
 ):
     for i in range(len(C_list)):
         assert C_list[i] is not None
-        if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+        if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
             C_list[i], config = w8a8_block_int8_matmul_wgrad(
                 A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
                 output_dtype=output_dtype,

@@ -663,7 +668,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     C_shape = q_input.shape[:-1] + (N,)
     output = q_input.new_empty(C_shape, dtype=out_dtype)
     print(f"zhenggf passing to triton kernel after transposition, q_input: {q_input.shape}, x_scale: {x_scale.shape}, weight: {weight.shape}, weight_scale: {weight_scale.shape}")
-    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
         output, config = w8a8_block_int8_matmul_wgrad(
             q_input, weight, x_scale, weight_scale, output, False, block_size,
             output_dtype=out_dtype,
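A context line in the second hunk is a leftover developer debug print. Since the module already creates logger = logging.getLogger(__name__), an equivalent that stays silent unless DEBUG logging is enabled would look like the sketch below (a suggestion, not part of the commit; the tensors are hypothetical stand-ins for the locals of apply_w8a8_block_int8_linear_helper):

    import logging
    import torch

    logger = logging.getLogger(__name__)

    # Hypothetical stand-ins for the helper's tensors.
    q_input = torch.empty(32, 128, dtype=torch.int8)
    weight = torch.empty(64, 128, dtype=torch.int8)
    x_scale = torch.empty(32, 1)
    weight_scale = torch.empty(1, 1)

    # Lazily formatted: no string work happens unless DEBUG is enabled.
    logger.debug(
        "passing to triton kernel after transposition: q_input=%s, x_scale=%s, "
        "weight=%s, weight_scale=%s",
        q_input.shape, x_scale.shape, weight.shape, weight_scale.shape,
    )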