OpenDAS / TransformerEngine · Commit b15412aa

[DCU] fix

Authored Sep 19, 2025 by yuguo
Parent: 803be71d

1 changed file with 17 additions and 8 deletions:
transformer_engine/pytorch/module/layernorm_mlp.py (+17, -8)
@@ -80,6 +80,12 @@ from ..cpp_extensions import (
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ...debug.pytorch.debug_state import TEDebugState

+try:
+    from lightop import rmsnorm_forward, rmsnorm_backward
+    enable_lightop = True
+except ImportError:
+    enable_lightop = False
+
 __all__ = ["LayerNormMLP"]
@@ -1264,14 +1270,17 @@ class _LayerNormMLP(torch.autograd.Function):
                 ctx.zero_centered_gamma,
             )
         elif ctx.normalization == "RMSNorm":
-            dgrad, dgamma = tex.rmsnorm_bwd(
-                dgrad,
-                inputmat,
-                rsigma,
-                ln_weight,
-                ctx.bwd_ln_sm_margin,
-                ctx.zero_centered_gamma,
-            )
+            if enable_lightop and (rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16):
+                dgrad, dgamma = rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)
+            else:
+                dgrad, dgamma = tex.rmsnorm_bwd(
+                    dgrad,
+                    inputmat,
+                    rsigma,
+                    ln_weight,
+                    ctx.bwd_ln_sm_margin,
+                    ctx.zero_centered_gamma,
+                )
         dbeta = None
         clear_tensor_data(mu, rsigma)
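The commit follows a common optional-accelerator pattern: try to import a fast kernel library (`lightop`) at module load, record whether it succeeded in a flag, and at call time route to the fast path only when the flag is set and the dtype is supported, otherwise fall back to the reference kernel. The sketch below isolates that pattern with hypothetical stand-in functions (`rmsnorm_bwd_ref`, `dispatch_rmsnorm_bwd`, plain-float math instead of tensors); it is an illustration of the guard structure, not the TransformerEngine implementation.

```python
# Optional fast path: import an accelerated kernel if present, else set a flag
# so every call falls back to the reference implementation. "lightop" mirrors
# the commit's optional dependency; it is typically absent outside DCU builds.
try:
    from lightop import rmsnorm_backward  # accelerated kernel (optional)
    enable_lightop = True
except ImportError:
    enable_lightop = False

# Mirrors the commit's dtype guard (fast path only for bf16/fp16), expressed
# here as strings instead of torch dtypes to keep the sketch dependency-free.
SUPPORTED_DTYPES = ("bfloat16", "float16")


def rmsnorm_bwd_ref(dgrad, x, rsigma, weight):
    """Hypothetical reference path, standing in for tex.rmsnorm_bwd.

    Simplified per-element math on plain floats: dgamma accumulates
    dgrad * x * rsigma; dx is dgrad * weight * rsigma.
    """
    dgamma = sum(g * xi * rsigma for g, xi in zip(dgrad, x))
    dx = [g * w * rsigma for g, w in zip(dgrad, weight)]
    return dx, dgamma


def dispatch_rmsnorm_bwd(dgrad, x, rsigma, weight, dtype="float32"):
    # Same guard structure the commit adds: fast path only when the library
    # imported successfully AND the dtype is one the kernel supports.
    if enable_lightop and dtype in SUPPORTED_DTYPES:
        return rmsnorm_backward(dgrad, x, rsigma, weight)
    return rmsnorm_bwd_ref(dgrad, x, rsigma, weight)
```

Keeping the fallback branch identical to the pre-commit code means environments without `lightop`, or calls with unsupported dtypes such as fp32, behave exactly as before the change.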