Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
065160ab
Commit
065160ab
authored
Sep 03, 2025
by
wenjh
Browse files
Add int8 blockwise gemm test to float8
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
0c461880
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
88 additions
and
37 deletions
+88
-37
qa/L0_pytorch_unittest/test.sh
qa/L0_pytorch_unittest/test.sh
+1
-1
tests/pytorch/test_float8_blockwise_gemm_exact.py
tests/pytorch/test_float8_blockwise_gemm_exact.py
+87
-36
No files found.
qa/L0_pytorch_unittest/test.sh
View file @
065160ab
...
@@ -40,7 +40,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetenso
...
@@ -40,7 +40,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetenso
NVTE_INT8_SIM_FP8
=
1 python3
-m
pytest
-v
-s
test_float8_current_scaling_exact.py
NVTE_INT8_SIM_FP8
=
1 python3
-m
pytest
-v
-s
test_float8_current_scaling_exact.py
NVTE_INT8_SIM_FP8
=
1
NVTE_INT8_SIM_FP8_TENSORWISE
=
1 python3
-m
pytest
-v
-s
test_float8_current_scaling_exact.py
NVTE_INT8_SIM_FP8
=
1
NVTE_INT8_SIM_FP8_TENSORWISE
=
1 python3
-m
pytest
-v
-s
test_float8_current_scaling_exact.py
python3
-m
pytest
-v
-s
--junitxml
=
$XML_LOG_DIR
/pytest_test_float8_blockwise_scaling_exact.xml
$TE_PATH
/tests/pytorch/test_float8_blockwise_scaling_exact.py
||
test_fail
"test_float8_blockwise_scaling_exact.py"
python3
-m
pytest
-v
-s
--junitxml
=
$XML_LOG_DIR
/pytest_test_float8_blockwise_scaling_exact.xml
$TE_PATH
/tests/pytorch/test_float8_blockwise_scaling_exact.py
||
test_fail
"test_float8_blockwise_scaling_exact.py"
python3
-m
pytest
-v
-s
--junitxml
=
$XML_LOG_DIR
/pytest_test_float8_blockwise_gemm_exact.xml
$TE_PATH
/tests/pytorch/test_float8_blockwise_gemm_exact.py
||
test_fail
"test_float8_blockwise_gemm_exact.py"
NVTE_INT8_SIM_FP8
=
1
python3
-m
pytest
-v
-s
--junitxml
=
$XML_LOG_DIR
/pytest_test_float8_blockwise_gemm_exact.xml
$TE_PATH
/tests/pytorch/test_float8_blockwise_gemm_exact.py
||
test_fail
"test_float8_blockwise_gemm_exact.py"
python3
$TE_PATH
/tests/pytorch/test_int8_blockwise_gemm_exact.py
python3
$TE_PATH
/tests/pytorch/test_int8_blockwise_gemm_exact.py
python3
-m
pytest
-v
-s
--junitxml
=
$XML_LOG_DIR
/pytest_test_gqa.xml
$TE_PATH
/tests/pytorch/test_gqa.py
||
test_fail
"test_gqa.py"
python3
-m
pytest
-v
-s
--junitxml
=
$XML_LOG_DIR
/pytest_test_gqa.xml
$TE_PATH
/tests/pytorch/test_gqa.py
||
test_fail
"test_gqa.py"
python3
-m
pytest
-v
-s
--junitxml
=
$XML_LOG_DIR
/pytest_test_fused_optimizer.xml
$TE_PATH
/tests/pytorch/test_fused_optimizer.py
||
test_fail
"test_fused_optimizer.py"
python3
-m
pytest
-v
-s
--junitxml
=
$XML_LOG_DIR
/pytest_test_fused_optimizer.xml
$TE_PATH
/tests/pytorch/test_fused_optimizer.py
||
test_fail
"test_fused_optimizer.py"
...
...
tests/pytorch/test_float8_blockwise_gemm_exact.py
View file @
065160ab
...
@@ -6,16 +6,23 @@ import pytest
...
@@ -6,16 +6,23 @@ import pytest
import
torch
import
torch
import
transformer_engine
as
te
import
transformer_engine
as
te
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
try
:
import
lightop
except
ImportError
:
pass
from
transformer_engine.pytorch.utils
import
use_lightop_w8a8
from
transformer_engine.pytorch.constants
import
TE_DType
from
transformer_engine.pytorch.constants
import
TE_DType
from
transformer_engine.pytorch.fp8
import
(
FP8GlobalStateManager
,
blockwise_fp8_block_len
)
from
transformer_engine.pytorch.fp8
import
(
FP8GlobalStateManager
,
blockwise_fp8_block_len
,
int8_simulation_fp8
)
from
transformer_engine.pytorch.tensor.float8_blockwise_tensor
import
(
from
transformer_engine.pytorch.tensor.float8_blockwise_tensor
import
(
Float8BlockQuantizer
,
Float8BlockQuantizer
,
Float8BlockwiseQTensor
,
Float8BlockwiseQTensor
,
)
)
from
transformer_engine.pytorch.triton.blockwise_int8_gemm_nt
import
w8a8_block_int8_matmul
from
transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad
import
w8a8_block_int8_matmul_wgrad
from
references.blockwise_quantizer_reference
import
CuBLASScaleMunger
from
references.blockwise_quantizer_reference
import
CuBLASScaleMunger
from
references.blockwise_fp8_gemm_reference
import
CuBLASRefBlockwiseGemm
from
references.blockwise_fp8_gemm_reference
import
CuBLASRefBlockwiseGemm
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
def
fp8_blockwise_gemm_supported
()
->
bool
:
def
fp8_blockwise_gemm_supported
()
->
bool
:
supported
,
_
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
supported
,
_
=
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
...
@@ -45,6 +52,11 @@ def cublas_gemm_fp8_blockwise_case(
...
@@ -45,6 +52,11 @@ def cublas_gemm_fp8_blockwise_case(
atol
:
float
=
0.0
,
atol
:
float
=
0.0
,
rtol
:
float
=
0.0
rtol
:
float
=
0.0
):
):
if
IS_HIP_EXTENSION
and
int8_simulation_fp8
:
if
use_bias
or
use_gelu
:
pytest
.
skip
(
"Bias and GELU not supported in int8 simulation mode on ROCm."
)
if
not
((
not
x_columnwise
and
not
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
)
or
(
not
x_columnwise
and
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
)
or
(
x_columnwise
and
w_columnwise
and
is_x_1d_scaled
and
is_w_1d_scaled
)):
pytest
.
skip
(
"Only 1Dx2D, 1Dx1D, and 2Dx1D block scaling supported in int8 simulation mode on ROCm."
)
if
x_dtype
==
torch
.
float8_e5m2
and
w_dtype
==
torch
.
float8_e5m2
:
if
x_dtype
==
torch
.
float8_e5m2
and
w_dtype
==
torch
.
float8_e5m2
:
pytest
.
skip
(
"FP8 GEMM doesn't support both a and b types being torch.float8_e5m2"
)
pytest
.
skip
(
"FP8 GEMM doesn't support both a and b types being torch.float8_e5m2"
)
if
not
(
is_x_1d_scaled
or
is_w_1d_scaled
):
if
not
(
is_x_1d_scaled
or
is_w_1d_scaled
):
...
@@ -157,27 +169,64 @@ def cublas_gemm_fp8_blockwise_case(
...
@@ -157,27 +169,64 @@ def cublas_gemm_fp8_blockwise_case(
aux_tensor_ref
=
aux_tensor
.
clone
()
if
use_gelu
else
None
aux_tensor_ref
=
aux_tensor
.
clone
()
if
use_gelu
else
None
bias_dtype
=
TE_DType
[
torch
.
bfloat16
if
bias
is
None
else
bias
.
dtype
]
bias_dtype
=
TE_DType
[
torch
.
bfloat16
if
bias
is
None
else
bias
.
dtype
]
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
if
IS_HIP_EXTENSION
and
int8_simulation_fp8
:
# We are just capturing out.
if
use_lightop_w8a8
([
block_len
,
block_len
]):
y
=
tex
.
generic_gemm
(
if
(
not
x_columnwise
and
not
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
):
qw
,
y
=
lightop
.
gemm_w8a8_asm
(
transa
,
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
block_len
,
block_len
],
out_dtype
,
'TN'
qx
,
)
transb
,
elif
(
not
x_columnwise
and
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
):
out
.
clone
()
if
accumulate
else
None
,
y
=
lightop
.
gemm_w8a8_xgrad_asm
(
out_quantizer
,
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
block_len
,
block_len
],
out_dtype
,
'TN'
TE_DType
[
out_dtype
],
)
bias
,
elif
(
x_columnwise
and
w_columnwise
and
is_x_1d_scaled
and
is_w_1d_scaled
):
bias_dtype
,
y
=
lightop
.
gemm_w8a8_wgrad_asm
(
use_gelu
,
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
out
.
clone
()
if
accumulate
else
None
,
accumulate
,
[
block_len
,
block_len
],
out_dtype
,
'TN'
aux_tensor
,
)
use_grad
,
else
:
workspace
,
assert
False
,
"Only 1Dx2D, 1Dx1D, and 2Dx1D block scaling supported in int8 simulation mode on ROCm."
workspace
.
shape
[
0
],
else
:
accumulate
,
if
(
not
x_columnwise
and
not
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
):
use_split_accumulator
,
y
,
_
=
w8a8_block_int8_matmul
(
)[
0
]
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
block_len
,
block_len
],
output_dtype
=
out_dtype
)
elif
(
not
x_columnwise
and
w_columnwise
and
is_x_1d_scaled
and
not
is_w_1d_scaled
):
y
,
_
=
w8a8_block_int8_matmul
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
block_len
,
block_len
],
output_dtype
=
out_dtype
)
elif
(
x_columnwise
and
w_columnwise
and
is_x_1d_scaled
and
is_w_1d_scaled
):
y
,
_
=
w8a8_block_int8_matmul_wgrad
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
out
.
clone
()
if
accumulate
else
None
,
accumulate
,
[
block_len
,
block_len
],
output_dtype
=
out_dtype
)
else
:
assert
False
,
"Only 1Dx2D, 1Dx1D, and 2Dx1D block scaling supported in int8 simulation mode on ROCm."
else
:
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y
=
tex
.
generic_gemm
(
qw
,
transa
,
qx
,
transb
,
out
.
clone
()
if
accumulate
else
None
,
out_quantizer
,
TE_DType
[
out_dtype
],
bias
,
bias_dtype
,
use_gelu
,
aux_tensor
,
use_grad
,
workspace
,
workspace
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
)[
0
]
# just in case of accumulation, make sure y_ref and y are not the same tensor
# just in case of accumulation, make sure y_ref and y are not the same tensor
assert
y_ref
is
not
y
,
"y_ref and y should not be the same tensor"
assert
y_ref
is
not
y
,
"y_ref and y should not be the same tensor"
...
@@ -227,6 +276,8 @@ def cublas_gemm_test_constraint_enforced(
...
@@ -227,6 +276,8 @@ def cublas_gemm_test_constraint_enforced(
expected_err_msg
=
"CUBLAS_STATUS_NOT_SUPPORTED"
,
expected_err_msg
=
"CUBLAS_STATUS_NOT_SUPPORTED"
,
expected_err_cls
=
RuntimeError
expected_err_cls
=
RuntimeError
):
):
if
IS_HIP_EXTENSION
:
pytest
.
skip
(
"ROCm does not support cuBLAS GEMM. No need to test constraint enforcement."
)
if
not
fp8_blockwise_gemm_supported
():
if
not
fp8_blockwise_gemm_supported
():
pytest
.
skip
(
"CUDA version does not support blockwise FP8 gemm."
)
pytest
.
skip
(
"CUDA version does not support blockwise FP8 gemm."
)
# Setup device and random seed
# Setup device and random seed
...
@@ -333,8 +384,8 @@ def cublas_gemm_test_constraint_enforced(
...
@@ -333,8 +384,8 @@ def cublas_gemm_test_constraint_enforced(
(
1024
,
4096
,
1024
),
(
1024
,
4096
,
1024
),
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float32
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float32
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"noise_type"
,
[
"normal"
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"noise_type"
,
[
"normal"
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_magnitude"
,
[
1
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_magnitude"
,
[
1
],
ids
=
str
)
...
@@ -389,8 +440,8 @@ def test_cublas_gemm_fp8_blockwise_shape_varying(
...
@@ -389,8 +440,8 @@ def test_cublas_gemm_fp8_blockwise_shape_varying(
(
320
,
256
,
336
),
(
320
,
256
,
336
),
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float32
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float32
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"noise_type"
,
[
"normal"
,
"uniform"
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"noise_type"
,
[
"normal"
,
"uniform"
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_magnitude"
,
[
1e-28
,
1
,
1e3
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_magnitude"
,
[
1e-28
,
1
,
1e3
],
ids
=
str
)
...
@@ -449,8 +500,8 @@ def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying(
...
@@ -449,8 +500,8 @@ def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying(
(
256
,
256
,
256
),
(
256
,
256
,
256
),
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float32
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float32
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"noise_type"
,
[
"normal"
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"noise_type"
,
[
"normal"
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_magnitude"
,
[
1e-3
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_magnitude"
,
[
1e-3
],
ids
=
str
)
...
@@ -511,8 +562,8 @@ def test_cublas_gemm_fp8_blockwise_bias(
...
@@ -511,8 +562,8 @@ def test_cublas_gemm_fp8_blockwise_bias(
(
4096
,
128
,
4096
),
(
4096
,
128
,
4096
),
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float32
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float32
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"noise_type"
,
[
"normal"
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"noise_type"
,
[
"normal"
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_magnitude"
,
[
1
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_magnitude"
,
[
1
],
ids
=
str
)
...
@@ -584,8 +635,8 @@ def test_cublas_gemm_fp8_blockwise_columnwise(
...
@@ -584,8 +635,8 @@ def test_cublas_gemm_fp8_blockwise_columnwise(
(
256
,
256
,
256
),
(
256
,
256
,
256
),
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"noise_type"
,
[
"normal"
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"noise_type"
,
[
"normal"
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_magnitude"
,
[
1
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_magnitude"
,
[
1
],
ids
=
str
)
...
@@ -913,8 +964,8 @@ def test_illegal_2D_by_2D_enforced(
...
@@ -913,8 +964,8 @@ def test_illegal_2D_by_2D_enforced(
(
256
,
128
,
252
,
False
,
False
),
(
256
,
128
,
252
,
False
,
False
),
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"x_dtype"
,
[
torch
.
float8_e4m3fn
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"w_dtype"
,
[
torch
.
float8_e4m3fn
]
if
not
int8_simulation_fp8
else
[
torch
.
int8
]
,
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
],
ids
=
str
)
@
pytest
.
mark
.
parametrize
(
"accumulate"
,
[
False
],
ids
=
[
"no_accumulate"
])
@
pytest
.
mark
.
parametrize
(
"accumulate"
,
[
False
],
ids
=
[
"no_accumulate"
])
@
pytest
.
mark
.
parametrize
(
"use_split_accumulator"
,
[
True
],
ids
=
[
"split_acc"
])
@
pytest
.
mark
.
parametrize
(
"use_split_accumulator"
,
[
True
],
ids
=
[
"split_acc"
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment