Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
e2cc2fc4
Commit
e2cc2fc4
authored
Jan 21, 2026
by
wenjh
Browse files
Merge branch 'develop_v2.10' into release_v2.10
parents
96a104d5
59b49b47
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
22 deletions
+33
-22
transformer_engine/pytorch/cpp_extensions/gemm.py
transformer_engine/pytorch/cpp_extensions/gemm.py
+7
-5
transformer_engine/pytorch/module/_common.py
transformer_engine/pytorch/module/_common.py
+9
-6
transformer_engine/pytorch/module/layernorm_linear.py
transformer_engine/pytorch/module/layernorm_linear.py
+10
-6
transformer_engine/pytorch/module/layernorm_mlp.py
transformer_engine/pytorch/module/layernorm_mlp.py
+7
-5
No files found.
transformer_engine/pytorch/cpp_extensions/gemm.py
View file @
e2cc2fc4
...
@@ -10,11 +10,13 @@ import functools
...
@@ -10,11 +10,13 @@ import functools
import
torch
import
torch
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
import
warnings
import
warnings
try
:
enable_lightop
=
os
.
getenv
(
"NVTE_USE_LIGHTOP"
,
"false"
).
strip
().
lower
()
in
[
"true"
,
"1"
]
import
lightop
if
enable_lightop
:
enable_lightop
=
True
try
:
except
ImportError
:
import
lightop
enable_lightop
=
False
enable_lightop
=
True
except
ImportError
:
enable_lightop
=
False
from
..constants
import
TE_DType
,
TE_DType_To_Torch
from
..constants
import
TE_DType
,
TE_DType_To_Torch
from
..utils
import
get_sm_count
,
_empty_tensor
from
..utils
import
get_sm_count
,
_empty_tensor
from
transformer_engine.pytorch.triton.blockwise_int8_gemm_nt
import
w8a8_block_int8_matmul
,
w8a8_block_int8_matmul_batched
from
transformer_engine.pytorch.triton.blockwise_int8_gemm_nt
import
w8a8_block_int8_matmul
,
w8a8_block_int8_matmul_batched
...
...
transformer_engine/pytorch/module/_common.py
View file @
e2cc2fc4
...
@@ -16,12 +16,15 @@ from ..export import is_in_onnx_export_mode
...
@@ -16,12 +16,15 @@ from ..export import is_in_onnx_export_mode
from
..utils
import
get_default_init_method
from
..utils
import
get_default_init_method
import
warnings
import
warnings
try
:
import
os
from
lightop
import
rmsnorm_forward
,
rmsnorm_backward
enable_lightop
=
os
.
getenv
(
"NVTE_USE_LIGHTOP"
,
"false"
).
strip
().
lower
()
in
[
"true"
,
"1"
]
enable_lightop
=
True
if
enable_lightop
:
except
ImportError
:
try
:
enable_lightop
=
False
from
lightop
import
rmsnorm_forward
,
rmsnorm_backward
warnings
.
warn
(
"Failed to import lightop module. Falling back to alternative implementation."
,
UserWarning
)
enable_lightop
=
True
except
ImportError
:
enable_lightop
=
False
warnings
.
warn
(
"Failed to import lightop module. Falling back to alternative implementation."
,
UserWarning
)
def
_get_normalization_func
(
normalization
:
str
,
forward
:
bool
):
def
_get_normalization_func
(
normalization
:
str
,
forward
:
bool
):
fwd_normalization_funcs
=
{
fwd_normalization_funcs
=
{
...
...
transformer_engine/pytorch/module/layernorm_linear.py
View file @
e2cc2fc4
...
@@ -80,12 +80,16 @@ from ..cpp_extensions import (
...
@@ -80,12 +80,16 @@ from ..cpp_extensions import (
general_gemm
,
general_gemm
,
)
)
import
warnings
import
warnings
try
:
from
lightop
import
rmsnorm_forward
,
rmsnorm_backward
enable_lightop
=
os
.
getenv
(
"NVTE_USE_LIGHTOP"
,
"false"
).
strip
().
lower
()
in
[
"true"
,
"1"
]
enable_lightop
=
True
except
ImportError
:
if
enable_lightop
:
enable_lightop
=
False
try
:
warnings
.
warn
(
"Failed to import lightop module. Falling back to alternative implementation."
,
UserWarning
)
from
lightop
import
rmsnorm_forward
,
rmsnorm_backward
enable_lightop
=
True
except
ImportError
:
enable_lightop
=
False
warnings
.
warn
(
"Failed to import lightop module. Falling back to alternative implementation."
,
UserWarning
)
__all__
=
[
"LayerNormLinear"
]
__all__
=
[
"LayerNormLinear"
]
...
...
transformer_engine/pytorch/module/layernorm_mlp.py
View file @
e2cc2fc4
...
@@ -88,11 +88,13 @@ from ..cpp_extensions import (
...
@@ -88,11 +88,13 @@ from ..cpp_extensions import (
from
..export
import
is_in_onnx_export_mode
,
assert_warmed_up
from
..export
import
is_in_onnx_export_mode
,
assert_warmed_up
from
...debug.pytorch.debug_state
import
TEDebugState
from
...debug.pytorch.debug_state
import
TEDebugState
try
:
enable_lightop
=
os
.
getenv
(
"NVTE_USE_LIGHTOP"
,
"false"
).
strip
().
lower
()
in
[
"true"
,
"1"
]
from
lightop
import
rmsnorm_forward
,
rmsnorm_backward
if
enable_lightop
:
enable_lightop
=
True
try
:
except
ImportError
:
from
lightop
import
rmsnorm_forward
,
rmsnorm_backward
enable_lightop
=
False
enable_lightop
=
True
except
ImportError
:
enable_lightop
=
False
__all__
=
[
"LayerNormMLP"
]
__all__
=
[
"LayerNormMLP"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment