Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
148b5bea
Commit
148b5bea
authored
Jul 15, 2025
by
wenjh
Browse files
Fix pytorch module import error
Signed-off-by: wenjh <wenjh@sugon.com>
parent
793e0103
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
5 deletions
+4
-5
transformer_engine/__init__.py
transformer_engine/__init__.py
+1
-4
transformer_engine/pytorch/utils.py
transformer_engine/pytorch/utils.py
+3
-1
No files found.
transformer_engine/__init__.py
View file @
148b5bea
...
@@ -12,10 +12,7 @@ import transformer_engine.common
...
@@ -12,10 +12,7 @@ import transformer_engine.common
try
:
try
:
from
.
import
pytorch
from
.
import
pytorch
except
ImportError
as
e
:
except
ImportError
as
e
:
try
:
pass
from
.
import
pytorch
except
ImportError
as
e
:
pass
try
:
try
:
from
.
import
jax
from
.
import
jax
...
...
transformer_engine/pytorch/utils.py
View file @
148b5bea
...
@@ -13,7 +13,6 @@ import torch
...
@@ -13,7 +13,6 @@ import torch
import
transformer_engine.pytorch.cpp_extensions
as
ext
import
transformer_engine.pytorch.cpp_extensions
as
ext
from
.
import
torch_version
from
.
import
torch_version
from
..debug.pytorch.debug_quantization
import
DebugQuantizedTensor
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
def
requires_grad
(
*
tensors
:
Tuple
[
Optional
[
torch
.
Tensor
],
...])
->
None
:
def
requires_grad
(
*
tensors
:
Tuple
[
Optional
[
torch
.
Tensor
],
...])
->
None
:
...
@@ -549,6 +548,7 @@ def round_up_to_nearest_multiple(value, multiple):
...
@@ -549,6 +548,7 @@ def round_up_to_nearest_multiple(value, multiple):
def
needs_quantized_gemm
(
obj
,
rowwise
=
True
):
def
needs_quantized_gemm
(
obj
,
rowwise
=
True
):
"""Used to check if obj will need quantized gemm or normal gemm."""
"""Used to check if obj will need quantized gemm or normal gemm."""
from
..debug.pytorch.debug_quantization
import
DebugQuantizedTensor
if
isinstance
(
obj
,
DebugQuantizedTensor
):
if
isinstance
(
obj
,
DebugQuantizedTensor
):
return
type
(
obj
.
get_tensor
(
not
rowwise
))
not
in
[
# pylint: disable=unidiomatic-typecheck
return
type
(
obj
.
get_tensor
(
not
rowwise
))
not
in
[
# pylint: disable=unidiomatic-typecheck
torch
.
Tensor
,
torch
.
Tensor
,
...
@@ -643,3 +643,5 @@ if torch_version() >= (2, 4, 0):
...
@@ -643,3 +643,5 @@ if torch_version() >= (2, 4, 0):
gpu_autocast_ctx
=
functools
.
partial
(
torch
.
amp
.
autocast
,
device_type
=
"cuda"
)
gpu_autocast_ctx
=
functools
.
partial
(
torch
.
amp
.
autocast
,
device_type
=
"cuda"
)
else
:
else
:
gpu_autocast_ctx
=
torch
.
cuda
.
amp
.
autocast
gpu_autocast_ctx
=
torch
.
cuda
.
amp
.
autocast
from
..debug.pytorch.debug_quantization
import
DebugQuantizedTensor
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment