OpenDAS / TransformerEngine

Commit 68d6c506, authored Aug 07, 2025 by yuguo
Parent: 4a013bd5

[DCU] fix channelwise train accumulate bug

Showing 2 changed files with 6 additions and 6 deletions (+6 -6)
transformer_engine/pytorch/cpp_extensions/gemm.py           +4 -4
transformer_engine/pytorch/triton/per_token_group_quant.py  +2 -2
transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -267,12 +267,12 @@ def general_gemm(
     )[
         0
     ]
     if out_dtype is torch.bfloat16:
         if accumulate:
-            out = channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
+            channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
         else:
             out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
     else:
         if accumulate:
-            out = channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
+            channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
         else:
             out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
     return out, None, None, None
@@ -572,14 +572,14 @@ def general_grouped_gemm(
     if out_dtype is torch.bfloat16:
         if accumulate:
             for i in num_gemms:
-                out[i] = channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
+                channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
         else:
             for i in num_gemms:
                 out[i] = channelwise_dequantize_transA(scales_dout_list[i], scales_x_list[i], dw_int32[i])
     else:
         if accumulate:
             for i in num_gemms:
-                out[i] = channelwise_dequantize_transA_float_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
+                channelwise_dequantize_transA_float_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
         else:
             for i in num_gemms:
                 out[i] = channelwise_dequantize_transA_float(scales_dout_list[i], scales_x_list[i], dw_int32[i])
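Both hunks make the same fix: when accumulate is true, the call sites stop reassigning out and let the *_add helpers (changed below) fold the dequantized result into the caller-supplied buffer in place. The old helpers built a fresh tensor with "... + D" and the call sites rebound the local name, so a caller holding a reference to the original output buffer never saw the accumulated values. A minimal standalone sketch of that aliasing pitfall (main_grad and update are hypothetical names, not from this repo):

import torch

main_grad = torch.zeros(4, dtype=torch.bfloat16)  # buffer the caller keeps a reference to
update = torch.ones(4, dtype=torch.bfloat16)

# Out-of-place (old behavior): "out + update" allocates a new tensor and the
# assignment only rebinds the local name; main_grad is never modified.
out = main_grad
out = out + update
assert main_grad.eq(0).all()

# In-place (fixed behavior): Tensor.add_ writes through the alias.
out = main_grad
out.add_(update)
assert main_grad.eq(1).all()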
transformer_engine/pytorch/triton/per_token_group_quant.py
@@ -331,12 +331,12 @@ def channelwise_dequantize_transA_float(A, B, C):
 @torch.compile(mode="max-autotune-no-cudagraphs")
 def channelwise_dequantize_transA_add(A, B, C, D):
     out_scales = A.T * B
-    return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16) + D
+    D.add_((out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16))

 @torch.compile(mode="max-autotune-no-cudagraphs")
 def channelwise_dequantize_transA_float_add(A, B, C, D):
     out_scales = A.T * B
-    return out_scales * C.to(dtype=torch.float32) + D
+    D.add_(out_scales * C.to(dtype=torch.float32))

 @torch.compile(mode="max-autotune-no-cudagraphs")
 def channelwise_dequantize_transB(A, B, C):
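Each helper dequantizes an int32 GEMM accumulator C using the outer product of two channelwise scale vectors, and the *_add variants now accumulate into D in place instead of returning a new tensor. A rough standalone sketch of the arithmetic under assumed shapes (A as per-row scales, B as per-column scales; the real helpers are additionally wrapped in torch.compile(mode="max-autotune-no-cudagraphs"), omitted here):

import torch

M, N = 8, 16
A = torch.rand(1, M)                                     # assumed: per-row (output-channel) scales
B = torch.rand(1, N)                                     # assumed: per-column scales
C = torch.randint(-128, 128, (M, N), dtype=torch.int32)  # int32 GEMM accumulator
D = torch.zeros(M, N, dtype=torch.bfloat16)              # caller's accumulation buffer

out_scales = A.T * B  # (M, 1) * (1, N) broadcasts to an (M, N) scale grid
D.add_((out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16))  # bfloat16 variant of the fix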