OpenDAS / TransformerEngine
Commit 6d461a10, authored Jun 09, 2025 by yuguo
[DCU] fix
Parent: 0a8072fa
Showing 1 changed file with 12 additions and 8 deletions
tests/pytorch/references/blockwise_fp8_gemm_reference.py (+12, -8)
...
...
@@ -7,6 +7,7 @@ from typing import Tuple
 import torch
 import triton
 import triton.language as tl
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

 @triton.jit
...
...
@@ -195,6 +196,9 @@ class CuBLASRefBlockwiseGemm:
 # Perform qgemm with scaling factors fused in the GEMM
 # Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM
 one = torch.tensor(1.0, dtype=torch.float32, device=qx.device)
+if IS_HIP_EXTENSION:
+    y_partial = torch.mm(qx_block.to(torch.float), qw_block.t().to(torch.float))
+else:
     y_partial = torch._scaled_mm(
         qx_block,
         qw_block.t(),
...
...