Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
ca1e98b6
Commit
ca1e98b6
authored Sep 09, 2025 by wenjh
Browse files
Fix float8 blockwise gemm tests with accumulator
Signed-off-by: wenjh <wenjh@sugon.com>
parent
065160ab
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing 1 changed file with 6 additions and 3 deletions
+6
-3
tests/pytorch/test_float8_blockwise_gemm_exact.py
tests/pytorch/test_float8_blockwise_gemm_exact.py
+6
-3
No files found.
tests/pytorch/test_float8_blockwise_gemm_exact.py
View file @
ca1e98b6
...
...
@@ -56,7 +56,10 @@ def cublas_gemm_fp8_blockwise_case(
if
use_bias
or
use_gelu
:
pytest
.
skip
(
"Bias and GELU not supported in int8 simulation mode on ROCm."
)
if
not
((
not
x_columnwise
and
not
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
)
or
(
not
x_columnwise
and
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
)
or
(
x_columnwise
and
w_columnwise
and
is_x_1d_scaled
and
is_w_1d_scaled
)):
pytest
.
skip
(
"Only 1Dx2D, 1Dx1D, and 2Dx1D block scaling supported in int8 simulation mode on ROCm."
)
pytest
.
skip
(
"Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
)
if
((
not
x_columnwise
and
not
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
)
or
(
not
x_columnwise
and
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
)):
if
accumulate
:
pytest
.
skip
(
"Accumulation not supported in fwd and xgrad block scaling in int8 simulation mode on ROCm."
)
if
x_dtype
==
torch
.
float8_e5m2
and
w_dtype
==
torch
.
float8_e5m2
:
pytest
.
skip
(
"FP8 GEMM doesn't support both a and b types being torch.float8_e5m2"
)
if
not
(
is_x_1d_scaled
or
is_w_1d_scaled
):
...
...
@@ -185,7 +188,7 @@ def cublas_gemm_fp8_blockwise_case(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
out
.
clone
()
if
accumulate
else
None
,
accumulate
,
[
block_len
,
block_len
],
out_dtype
,
'TN'
)
else
:
assert
False
,
"Only
1Dx2D, 1Dx1D
, and
2Dx1D
block scaling supported in int8 simulation mode on ROCm."
assert
False
,
"Only
fwd, xgrad
, and
wgrad
block scaling supported in int8 simulation mode on ROCm."
else
:
if
(
not
x_columnwise
and
not
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
):
y
,
_
=
w8a8_block_int8_matmul
(
...
...
@@ -204,7 +207,7 @@ def cublas_gemm_fp8_blockwise_case(
output_dtype
=
out_dtype
)
else
:
assert
False
,
"Only
1Dx2D, 1Dx1D
, and
2Dx1D
block scaling supported in int8 simulation mode on ROCm."
assert
False
,
"Only
fwd, xgrad
, and
wgrad
block scaling supported in int8 simulation mode on ROCm."
else
:
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
...
...
Write
Preview
Markdown is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment