Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
6c9dc19d
Commit
6c9dc19d
authored
Jan 23, 2026
by
wenjh
Browse files
Refine the constraints while using lightop in gemm.py
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
59b49b47
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
transformer_engine/pytorch/cpp_extensions/gemm.py
transformer_engine/pytorch/cpp_extensions/gemm.py
+4
-4
No files found.
transformer_engine/pytorch/cpp_extensions/gemm.py
View file @
6c9dc19d
...
@@ -53,7 +53,7 @@ __all__ = [
...
@@ -53,7 +53,7 @@ __all__ = [
def
w8a8_block_int8_matmul_wgrad_batched_native
(
A_list
,
B_list
,
As_list
,
Bs_list
,
C_list
,
accumulate
,
out_dtype
=
torch
.
float16
):
def
w8a8_block_int8_matmul_wgrad_batched_native
(
A_list
,
B_list
,
As_list
,
Bs_list
,
C_list
,
accumulate
,
out_dtype
=
torch
.
float16
):
for
i
in
range
(
len
(
C_list
)):
for
i
in
range
(
len
(
C_list
)):
assert
C_list
[
i
]
is
not
None
assert
C_list
[
i
]
is
not
None
if
get_device_compute_capability
()
>=
(
9
,
3
)
and
blockwise_fp8_block_len
==
128
and
((
out_dtype
is
torch
.
bfloat16
)
or
(
out_dtype
is
torch
.
float16
)):
if
enable_lightop
and
get_device_compute_capability
()
>=
(
9
,
3
)
and
blockwise_fp8_block_len
==
128
and
((
out_dtype
is
torch
.
bfloat16
)
or
(
out_dtype
is
torch
.
float16
)):
C_list
[
i
]
=
lightop
.
gemm_w8a8_wgrad_asm
(
C_list
[
i
]
=
lightop
.
gemm_w8a8_wgrad_asm
(
A_list
[
i
],
B_list
[
i
],
As_list
[
i
],
Bs_list
[
i
],
C_list
[
i
],
accumulate
,
blockwise_fp8_block_len
,
out_dtype
,
"TN"
A_list
[
i
],
B_list
[
i
],
As_list
[
i
],
Bs_list
[
i
],
C_list
[
i
],
accumulate
,
blockwise_fp8_block_len
,
out_dtype
,
"TN"
)
)
...
@@ -80,7 +80,7 @@ def w8a8_int8_general_gemm(
...
@@ -80,7 +80,7 @@ def w8a8_int8_general_gemm(
qw_data
=
(
A
.
_rowwise_data
.
view
(
dtype
=
torch
.
int8
))
qw_data
=
(
A
.
_rowwise_data
.
view
(
dtype
=
torch
.
int8
))
ref_scales_x
=
B
.
_rowwise_scale_inv
ref_scales_x
=
B
.
_rowwise_scale_inv
ref_scales_w
=
A
.
_rowwise_scale_inv
ref_scales_w
=
A
.
_rowwise_scale_inv
if
get_device_compute_capability
()
>=
(
9
,
3
)
and
blockwise_fp8_block_len
==
128
and
((
out_dtype
is
torch
.
bfloat16
)
or
(
out_dtype
is
torch
.
float16
)):
if
enable_lightop
and
get_device_compute_capability
()
>=
(
9
,
3
)
and
blockwise_fp8_block_len
==
128
and
((
out_dtype
is
torch
.
bfloat16
)
or
(
out_dtype
is
torch
.
float16
)):
y
=
lightop
.
gemm_w8a8_asm
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
out_dtype
,
'TN'
)
y
=
lightop
.
gemm_w8a8_asm
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
out_dtype
,
'TN'
)
else
:
else
:
warnings
.
warn
(
"Lightop is not available. Using default implementation for w8a8."
)
warnings
.
warn
(
"Lightop is not available. Using default implementation for w8a8."
)
...
@@ -89,7 +89,7 @@ def w8a8_int8_general_gemm(
...
@@ -89,7 +89,7 @@ def w8a8_int8_general_gemm(
elif
layout
==
"NN"
:
elif
layout
==
"NN"
:
assert
accumulate
is
False
,
"Accumulate not supported in w8a8_general_gemm with NN layout"
assert
accumulate
is
False
,
"Accumulate not supported in w8a8_general_gemm with NN layout"
assert
out
is
None
,
"Output tensor not supported in w8a8_general_gemm with NN layout"
assert
out
is
None
,
"Output tensor not supported in w8a8_general_gemm with NN layout"
if
get_device_compute_capability
()
>=
(
9
,
3
)
and
blockwise_fp8_block_len
==
128
and
((
out_dtype
is
torch
.
bfloat16
)
or
(
out_dtype
is
torch
.
float16
)):
if
enable_lightop
and
get_device_compute_capability
()
>=
(
9
,
3
)
and
blockwise_fp8_block_len
==
128
and
((
out_dtype
is
torch
.
bfloat16
)
or
(
out_dtype
is
torch
.
float16
)):
qdout_data
=
(
B
.
_rowwise_data
.
view
(
dtype
=
torch
.
int8
))
qdout_data
=
(
B
.
_rowwise_data
.
view
(
dtype
=
torch
.
int8
))
qw_data
=
(
A
.
_rowwise_data
.
view
(
dtype
=
torch
.
int8
))
qw_data
=
(
A
.
_rowwise_data
.
view
(
dtype
=
torch
.
int8
))
ref_scales_dout
=
B
.
_rowwise_scale_inv
ref_scales_dout
=
B
.
_rowwise_scale_inv
...
@@ -108,7 +108,7 @@ def w8a8_int8_general_gemm(
...
@@ -108,7 +108,7 @@ def w8a8_int8_general_gemm(
qx_data
=
(
A
.
_columnwise_data
.
view
(
dtype
=
torch
.
int8
))
qx_data
=
(
A
.
_columnwise_data
.
view
(
dtype
=
torch
.
int8
))
ref_scales_dout
=
B
.
_columnwise_scale_inv
ref_scales_dout
=
B
.
_columnwise_scale_inv
ref_scales_x
=
A
.
_columnwise_scale_inv
ref_scales_x
=
A
.
_columnwise_scale_inv
if
get_device_compute_capability
()
>=
(
9
,
3
)
and
blockwise_fp8_block_len
==
128
and
((
out_dtype
is
torch
.
bfloat16
)
or
(
out_dtype
is
torch
.
float16
)):
if
enable_lightop
and
get_device_compute_capability
()
>=
(
9
,
3
)
and
blockwise_fp8_block_len
==
128
and
((
out_dtype
is
torch
.
bfloat16
)
or
(
out_dtype
is
torch
.
float16
)):
out
=
lightop
.
gemm_w8a8_wgrad_asm
(
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
out
,
accumulate
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
out_dtype
,
'TN'
)
out
=
lightop
.
gemm_w8a8_wgrad_asm
(
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
out
,
accumulate
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
out_dtype
,
'TN'
)
else
:
else
:
warnings
.
warn
(
"Lightop is not available. Using default implementation for w8a8."
)
warnings
.
warn
(
"Lightop is not available. Using default implementation for w8a8."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment