Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
4f79b7a9
Commit
4f79b7a9
authored
Apr 08, 2025
by
panning
Browse files
add lightop rmsnorm
parent
a207db1d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
18 deletions
+38
-18
transformer_engine/pytorch/module/_common.py
transformer_engine/pytorch/module/_common.py
+19
-10
transformer_engine/pytorch/module/layernorm_linear.py
transformer_engine/pytorch/module/layernorm_linear.py
+19
-8
No files found.
transformer_engine/pytorch/module/_common.py
View file @
4f79b7a9
...
@@ -15,6 +15,13 @@ from .. import cpp_extensions as tex
...
@@ -15,6 +15,13 @@ from .. import cpp_extensions as tex
from
..constants
import
TE_DType
from
..constants
import
TE_DType
from
..utils
import
get_default_init_method
from
..utils
import
get_default_init_method
from
..tensor.float8_tensor
import
Float8Tensor
from
..tensor.float8_tensor
import
Float8Tensor
import
warnings
try
:
from
lightop
import
rmsnorm_forward
,
rmsnorm_backward
enable_lightop
=
True
except
ImportError
:
enable_lightop
=
False
warnings
.
warn
(
"Failed to import lightop module. Falling back to alternative implementation."
,
UserWarning
)
def
_get_normalization_func
(
normalization
:
str
,
forward
:
bool
):
def
_get_normalization_func
(
normalization
:
str
,
forward
:
bool
):
...
@@ -81,7 +88,9 @@ def apply_normalization(
...
@@ -81,7 +88,9 @@ def apply_normalization(
normalization_func
=
_get_normalization_func
(
normalization
,
True
)
normalization_func
=
_get_normalization_func
(
normalization
,
True
)
inputs
=
(
inputmat
,
ln_weight
)
if
ln_bias
is
None
else
(
inputmat
,
ln_weight
,
ln_bias
)
inputs
=
(
inputmat
,
ln_weight
)
if
ln_bias
is
None
else
(
inputmat
,
ln_weight
,
ln_bias
)
if
enable_lightop
and
(
ln_bias
is
None
):
return
rmsnorm_forward
(
inputmat
,
ln_weight
,
ln_out
,
eps
,
True
)
else
:
return
normalization_func
(
return
normalization_func
(
*
inputs
,
*
inputs
,
eps
,
eps
,
...
...
transformer_engine/pytorch/module/layernorm_linear.py
View file @
4f79b7a9
...
@@ -61,6 +61,13 @@ from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
...
@@ -61,6 +61,13 @@ from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from
..cpp_extensions
import
(
from
..cpp_extensions
import
(
general_gemm
,
general_gemm
,
)
)
import
warnings
try
:
from
lightop
import
rmsnorm_forward
,
rmsnorm_backward
enable_lightop
=
True
except
ImportError
:
enable_lightop
=
False
warnings
.
warn
(
"Failed to import lightop module. Falling back to alternative implementation."
,
UserWarning
)
__all__
=
[
"LayerNormLinear"
]
__all__
=
[
"LayerNormLinear"
]
...
@@ -757,6 +764,10 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -757,6 +764,10 @@ class _LayerNormLinear(torch.autograd.Function):
)
)
dgrad
=
dgrad
.
reshape
(
inputmat
.
size
())
dgrad
=
dgrad
.
reshape
(
inputmat
.
size
())
elif
ctx
.
normalization
==
"RMSNorm"
:
elif
ctx
.
normalization
==
"RMSNorm"
:
if
enable_lightop
:
dgrad
,
dgamma
=
rmsnorm_backward
(
dgrad
,
inputmat
,
rsigma
,
ln_weight
)
else
:
dgrad
,
dgamma
=
tex
.
rmsnorm_bwd
(
dgrad
,
dgamma
=
tex
.
rmsnorm_bwd
(
dgrad
,
dgrad
,
inputmat
,
inputmat
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment