OpenDAS / TransformerEngine

Commit 6d461a10
Authored Jun 09, 2025 by yuguo

[DCU] fix

Parent: 0a8072fa
Showing 1 changed file with 12 additions and 8 deletions:

tests/pytorch/references/blockwise_fp8_gemm_reference.py (+12, -8)
tests/pytorch/references/blockwise_fp8_gemm_reference.py

@@ -7,6 +7,7 @@ from typing import Tuple
 import torch
 import triton
 import triton.language as tl
+from torch.utils.cpp_extension import IS_HIP_EXTENSION


 @triton.jit
@@ -195,14 +196,17 @@ class CuBLASRefBlockwiseGemm:
         # Perform qgemm with scaling factors fused in the GEMM
         # Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM
         one = torch.tensor(1.0, dtype=torch.float32, device=qx.device)
-        y_partial = torch._scaled_mm(
-            qx_block,
-            qw_block.t(),
-            scale_a=one,
-            scale_b=one,
-            out_dtype=torch.float32,
-            use_fast_accum=not use_split_accumulator,
-        )
+        if IS_HIP_EXTENSION:
+            y_partial = torch.mm(qx_block.to(torch.float), qw_block.t().to(torch.float))
+        else:
+            y_partial = torch._scaled_mm(
+                qx_block,
+                qw_block.t(),
+                scale_a=one,
+                scale_b=one,
+                out_dtype=torch.float32,
+                use_fast_accum=not use_split_accumulator,
+            )
         # Accumulate the partial result
         if is_a_1d_scaled and is_b_1d_scaled:
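The commit message is terse, but the diff's intent can be read off directly: on HIP builds (the [DCU] tag presumably refers to DCU accelerators, which use a ROCm/HIP-compatible stack where torch._scaled_mm is apparently unavailable or unsupported), the FP8 reference GEMM falls back to upcasting the quantized blocks to float32 and using a plain torch.mm. Below is a minimal, self-contained sketch of that dispatch, not the repository's test harness; the function name reference_partial_gemm and the block shapes are illustrative assumptions:

    import torch
    from torch.utils.cpp_extension import IS_HIP_EXTENSION

    def reference_partial_gemm(qx_block, qw_block, use_split_accumulator=True):
        """Partial product of one scaling block (hypothetical helper).

        Assumed shapes: qx_block is an (M, K) FP8 tensor, qw_block is an
        (N, K) FP8 tensor; the result is an (M, N) float32 tensor.
        """
        one = torch.tensor(1.0, dtype=torch.float32, device=qx_block.device)
        if IS_HIP_EXTENSION:
            # torch._scaled_mm is not usable on this build, so upcast the FP8
            # blocks to float32 and use a plain matmul instead.
            return torch.mm(qx_block.to(torch.float), qw_block.t().to(torch.float))
        # CUDA path: FP8 GEMM with the scales fused in, accumulating in float32.
        return torch._scaled_mm(
            qx_block,
            qw_block.t(),
            scale_a=one,
            scale_b=one,
            out_dtype=torch.float32,
            use_fast_accum=not use_split_accumulator,
        )

With scale_a = scale_b = 1.0 the two branches compute the same float32 product up to accumulation-order differences, so the HIP fallback should preserve the reference semantics of the test.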