OpenDAS / TransformerEngine

Commit 803be71d
authored Sep 18, 2025 by wenjh
Fix w8a8 lightop restriction
Signed-off-by: wenjh <wenjh@sugon.com>
parent d81f8119
Showing 1 changed file with 8 additions and 7 deletions.

transformer_engine/pytorch/cpp_extensions/gemm.py (+8, -7)
@@ -46,17 +46,18 @@ __all__ = [
     "batchgemm",
 ]
 
-def w8a8_block_int8_matmul_wgrad_batched_native(A_list, B_list, As_list, Bs_list, C_list, accumulate, output_dtype=torch.float16):
+def w8a8_block_int8_matmul_wgrad_batched_native(A_list, B_list, As_list, Bs_list, C_list, accumulate, out_dtype=torch.float16):
     for i in range(len(C_list)):
         assert C_list[i] is not None
-        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
             C_list[i] = lightop.gemm_w8a8_wgrad_asm(
-                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, blockwise_fp8_block_len, output_dtype, "TN"
+                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, blockwise_fp8_block_len, out_dtype, "TN"
             )
         else:
+            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
             C_list[i], _ = w8a8_block_int8_matmul_wgrad(
-                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, blockwise_fp8_block_len, output_dtype,
+                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, blockwise_fp8_block_len, out_dtype,
                 None
             )
     return C_list
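The restriction being fixed is the same at all four call sites: the lightop path is only valid on compute capability (9, 3) or newer, with a 128-wide quantization block, and, after this commit, only when the requested output dtype is fp16 or bf16. Below is a minimal standalone sketch of that predicate; the helper name _lightop_w8a8_supported is illustrative and not part of the source, and the `in` test is equivalent to the diff's chained `is` checks because torch.dtype objects are singletons.

import torch

def _lightop_w8a8_supported(compute_capability, block_len, out_dtype):
    # Hypothetical helper mirroring the condition this commit adds at each
    # lightop call site: new-enough device, 128-wide quantization block,
    # and an output dtype the assembly kernel can actually produce.
    return (
        compute_capability >= (9, 3)
        and block_len == 128
        and out_dtype in (torch.float16, torch.bfloat16)
    )

assert _lightop_w8a8_supported((9, 3), 128, torch.float16)
assert not _lightop_w8a8_supported((9, 3), 128, torch.float32)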
@@ -75,7 +76,7 @@ def w8a8_int8_general_gemm(
         qw_data = (A._rowwise_data.view(dtype=torch.int8))
         ref_scales_x = B._rowwise_scale_inv
         ref_scales_w = A._rowwise_scale_inv
-        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
             y = lightop.gemm_w8a8_asm(qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
         else:
             warnings.warn("Lightop is not available. Using default implementation for w8a8.")
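The new dtype check uses identity (`is`) rather than equality. That is safe because torch.dtype objects are interned singletons, as a quick check confirms:

import torch

t = torch.zeros(2, 2, dtype=torch.float16)
# torch.dtype objects are singletons, so `is` and `==` agree for them.
assert t.dtype is torch.float16
assert t.dtype == torch.float16
# A float32 output request now fails the restriction and takes the
# warned fallback path instead of the lightop kernel.
assert torch.zeros(1).dtype is torch.float32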
@@ -84,7 +85,7 @@ def w8a8_int8_general_gemm(
     elif layout == "NN":
         assert accumulate is False, "Accumulate not supported in w8a8_general_gemm with NN layout"
         assert out is None, "Output tensor not supported in w8a8_general_gemm with NN layout"
-        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
             qdout_data = (B._rowwise_data.view(dtype=torch.int8))
             qw_data = (A._rowwise_data.view(dtype=torch.int8))
             ref_scales_dout = B._rowwise_scale_inv
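Note that the unchanged half of the condition compares `get_device_compute_capability()` against a (major, minor) tuple, so Python's lexicographic tuple ordering performs the version check:

# Lexicographic tuple comparison orders (major, minor) capability pairs
# the way version numbers should.
assert (9, 3) >= (9, 3)
assert (9, 10) >= (9, 3)
assert (10, 0) >= (9, 3)
assert not ((9, 2) >= (9, 3))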
@@ -103,7 +104,7 @@ def w8a8_int8_general_gemm(
         qx_data = (A._columnwise_data.view(dtype=torch.int8))
         ref_scales_dout = B._columnwise_scale_inv
         ref_scales_x = A._columnwise_scale_inv
-        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
             out = lightop.gemm_w8a8_wgrad_asm(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
         else:
             warnings.warn("Lightop is not available. Using default implementation for w8a8.")
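Taken together, the effect of the commit is that a caller requesting an unsupported output dtype now falls through to the default w8a8 implementation with a warning, instead of reaching a lightop kernel that cannot produce that dtype. A small sketch of the resulting dispatch behavior; pick_path is illustrative only and not part of the source:

import warnings
import torch

def pick_path(out_dtype, capability=(9, 3), block_len=128):
    # Mirrors the branch structure in gemm.py after this commit.
    if capability >= (9, 3) and block_len == 128 and (
        (out_dtype is torch.bfloat16) or (out_dtype is torch.float16)
    ):
        return "lightop"
    warnings.warn("Lightop is not available. Using default implementation for w8a8.")
    return "default"

assert pick_path(torch.float16) == "lightop"
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert pick_path(torch.float32) == "default"
    assert len(caught) == 1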