OpenDAS / TransformerEngine

commit f796eb80 (parent 04afba37)
Author: wenjh
Date:   Dec 31, 2025

    Fix new gemm

    Signed-off-by: wenjh <wenjh@sugon.com>
Showing 2 changed files with 2 additions and 4 deletions:

    tests/pytorch/test_batched_linear.py  +0 -1
    tests/pytorch/test_gemm_autotune.py   +2 -3
diff --git a/tests/pytorch/test_batched_linear.py b/tests/pytorch/test_batched_linear.py
--- a/tests/pytorch/test_batched_linear.py
+++ b/tests/pytorch/test_batched_linear.py
@@ -45,7 +45,6 @@ from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
-from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
 from transformer_engine.pytorch.utils import get_device_compute_capability
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
diff --git a/tests/pytorch/test_gemm_autotune.py b/tests/pytorch/test_gemm_autotune.py
--- a/tests/pytorch/test_gemm_autotune.py
+++ b/tests/pytorch/test_gemm_autotune.py
@@ -13,8 +13,7 @@ import warnings
 import torch
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
-from transformer_engine.pytorch.cpp_extensions import gemm
-from transformer_engine.pytorch.module.base import get_workspace
+from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm

 def use_hipblaslt():
@@ -142,7 +141,7 @@ def run_gemm():
     N = 32
     datatype = torch.float16
     inp = torch.randn((N, N), device="cuda", dtype=datatype)
-    _, _, _ = gemm(A=inp, B=inp, dtype=datatype, workspace=get_workspace())
+    _, _, _ = general_gemm(A=inp, B=inp, dtype=datatype)

 if __name__ == "__main__":
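For context on what the updated test exercises: the call computes a plain N x N half-precision matrix product on the GPU. A minimal NumPy sketch of the same computation follows (NumPy stands in for the CUDA/HIP GEMM here, since `general_gemm` requires a GPU and the transformer_engine build; the float32-accumulation step is an assumption modeled on typical GEMM library behavior):

```python
import numpy as np

# Mirror the test's shapes: a 32x32 half-precision input multiplied by itself.
N = 32
rng = np.random.default_rng(0)
inp = rng.standard_normal((N, N)).astype(np.float16)

# C = A @ B, accumulated in float32 (as GEMM libraries commonly do),
# then cast back to the half-precision output dtype.
out = (inp.astype(np.float32) @ inp.astype(np.float32)).astype(np.float16)

print(out.shape)  # (32, 32)
print(out.dtype)  # float16
```

This only illustrates the numerics of the test case; the commit itself swaps the removed `gemm(..., workspace=get_workspace())` call for `general_gemm`, which in this diff no longer takes an explicit workspace argument.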