Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
b32741e2
Commit
b32741e2
authored
Jul 15, 2025
by
wenjh
Browse files
Merge branch 'develop_v2.4'
parents
a5892578
148b5bea
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
5 deletions
+4
-5
transformer_engine/__init__.py
transformer_engine/__init__.py
+1
-4
transformer_engine/pytorch/utils.py
transformer_engine/pytorch/utils.py
+3
-1
No files found.
transformer_engine/__init__.py
View file @
b32741e2
...
@@ -13,10 +13,7 @@ import transformer_engine.common
...
@@ -13,10 +13,7 @@ import transformer_engine.common
try
:
try
:
from
.
import
pytorch
from
.
import
pytorch
except
ImportError
:
except
ImportError
:
try
:
pass
from
.
import
pytorch
except
ImportError
:
pass
except
FileNotFoundError
as
e
:
except
FileNotFoundError
as
e
:
if
"Could not find shared object file"
not
in
str
(
e
):
if
"Could not find shared object file"
not
in
str
(
e
):
raise
e
# Unexpected error
raise
e
# Unexpected error
...
...
transformer_engine/pytorch/utils.py
View file @
b32741e2
...
@@ -13,7 +13,6 @@ import torch
...
@@ -13,7 +13,6 @@ import torch
import
transformer_engine.pytorch.cpp_extensions
as
ext
import
transformer_engine.pytorch.cpp_extensions
as
ext
from
.
import
torch_version
from
.
import
torch_version
from
..debug.pytorch.debug_quantization
import
DebugQuantizedTensor
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
def
requires_grad
(
*
tensors
:
Tuple
[
Optional
[
torch
.
Tensor
],
...])
->
None
:
def
requires_grad
(
*
tensors
:
Tuple
[
Optional
[
torch
.
Tensor
],
...])
->
None
:
...
@@ -558,6 +557,7 @@ def round_up_to_nearest_multiple(value, multiple):
...
@@ -558,6 +557,7 @@ def round_up_to_nearest_multiple(value, multiple):
def
needs_quantized_gemm
(
obj
,
rowwise
=
True
):
def
needs_quantized_gemm
(
obj
,
rowwise
=
True
):
"""Used to check if obj will need quantized gemm or normal gemm."""
"""Used to check if obj will need quantized gemm or normal gemm."""
from
..debug.pytorch.debug_quantization
import
DebugQuantizedTensor
if
isinstance
(
obj
,
DebugQuantizedTensor
):
if
isinstance
(
obj
,
DebugQuantizedTensor
):
return
type
(
obj
.
get_tensor
(
not
rowwise
))
not
in
[
# pylint: disable=unidiomatic-typecheck
return
type
(
obj
.
get_tensor
(
not
rowwise
))
not
in
[
# pylint: disable=unidiomatic-typecheck
torch
.
Tensor
,
torch
.
Tensor
,
...
@@ -652,3 +652,5 @@ if torch_version() >= (2, 4, 0):
...
@@ -652,3 +652,5 @@ if torch_version() >= (2, 4, 0):
gpu_autocast_ctx
=
functools
.
partial
(
torch
.
amp
.
autocast
,
device_type
=
"cuda"
)
gpu_autocast_ctx
=
functools
.
partial
(
torch
.
amp
.
autocast
,
device_type
=
"cuda"
)
else
:
else
:
gpu_autocast_ctx
=
torch
.
cuda
.
amp
.
autocast
gpu_autocast_ctx
=
torch
.
cuda
.
amp
.
autocast
from
..debug.pytorch.debug_quantization
import
DebugQuantizedTensor
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment