Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
8113d9e0
"tests/vscode:/vscode.git/clone" did not exist on "490a5f41ada5788bc6dd94ba54ab024e465e0ec6"
Commit
8113d9e0
authored
Jul 18, 2025
by
yuguo
Browse files
Merge branch 'develop_v2.4' of
http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine
parents
93ecbc82
d9847b6d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
6 deletions
+9
-6
transformer_engine/pytorch/cpp_extensions/gemm.py
transformer_engine/pytorch/cpp_extensions/gemm.py
+9
-6
No files found.
transformer_engine/pytorch/cpp_extensions/gemm.py
View file @
8113d9e0
...
...
@@ -32,6 +32,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
channelwise_dequantize_transB
,
channelwise_dequantize_transA_add
,
channelwise_dequantize_transA_float_add
)
from
transformer_engine.pytorch.utils
import
get_device_compute_capability
int8_simulation_fp8
=
bool
(
int
(
os
.
getenv
(
"NVTE_INT8_SIM_FP8"
,
"0"
)))
__all__
=
[
...
...
@@ -91,7 +92,7 @@ def general_gemm(
)
ref_scales_x
=
B
.
_rowwise_scale_inv
ref_scales_w
=
A
.
_rowwise_scale_inv
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_len
!=
128
or
not
enable_lightop
:
if
get_device_compute_capability
()
<
(
9
,
3
)
or
blockwise_fp8_
block_len
!=
128
or
not
enable_lightop
:
y
,
_
=
w8a8_block_int8_matmul
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
output_dtype
=
out_dtype
...
...
@@ -111,7 +112,7 @@ def general_gemm(
)
ref_scales_dout
=
B
.
_rowwise_scale_inv
ref_scales_w
=
A
.
_columnwise_scale_inv
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_len
!=
128
or
not
enable_lightop
:
if
get_device_compute_capability
()
<
(
9
,
3
)
or
blockwise_fp8_
block_len
!=
128
or
not
enable_lightop
:
y
,
_
=
w8a8_block_int8_matmul
(
qdout_data
,
qw_data
,
ref_scales_dout
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
output_dtype
=
out_dtype
...
...
@@ -131,7 +132,7 @@ def general_gemm(
)
ref_scales_dout
=
B
.
_columnwise_scale_inv
ref_scales_x
=
A
.
_columnwise_scale_inv
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_len
!=
128
or
not
enable_lightop
:
if
get_device_compute_capability
()
<
(
9
,
3
)
or
blockwise_fp8_
block_len
!=
128
or
not
enable_lightop
:
out
,
_
=
w8a8_block_int8_matmul_wgrad
(
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
out
,
accumulate
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
output_dtype
=
out_dtype
...
...
@@ -199,7 +200,7 @@ def general_gemm(
x_int8
,
transb
,
None
,
quantization_params
,
None
,
TE_DType
[
torch
.
int32
],
bias
,
bias_dtype
,
...
...
@@ -225,7 +226,7 @@ def general_gemm(
dy_int8
,
transb
,
None
,
quantization_params
,
None
,
TE_DType
[
torch
.
int32
],
bias
,
bias_dtype
,
...
...
@@ -241,6 +242,7 @@ def general_gemm(
return
dx
,
None
,
None
,
None
elif
layout
==
"NT"
:
assert
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float32
dy_int8
,
dy_scales
=
per_token_quant_fp8_to_int8_opt
(
B
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
B
.
_fp8_dtype
]),
B
.
_scale_inv
,
False
)
x_int8
,
x_scales
=
per_token_quant_fp8_to_int8_opt
(
A
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
A
.
_fp8_dtype
]),
A
.
_scale_inv
,
False
)
...
...
@@ -250,7 +252,7 @@ def general_gemm(
dy_int8
,
transb
,
None
,
quantization_params
,
None
,
TE_DType
[
torch
.
int32
],
bias
,
bias_dtype
,
...
...
@@ -524,6 +526,7 @@ def general_grouped_gemm(
return
out
,
bias
,
gelu_input
elif
layout
==
"NT"
:
assert
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float32
qdout_data_list
=
[]
qx_data_list
=
[]
scales_dout_list
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment