OpenDAS / TransformerEngine / Commits

Commit f7c66e28
Authored Nov 03, 2025 by zhaochao
Parent 87682fe2

[DCU] fix some bugs in test_numerics.py

Showing 2 changed files with 56 additions and 5 deletions (+56, -5):
  tests/pytorch/test_numerics.py                 +55  -4
  transformer_engine/pytorch/module/_common.py    +1  -1
tests/pytorch/test_numerics.py (+55, -4)

@@ -56,7 +56,7 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend
 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
 sm_80plus = get_device_compute_capability() >= (8, 0)
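Note on the hunk above: is_fp8_block_scaling_available() returns an (available, reason) pair, and the reason was previously discarded as "_". Binding it to reason_for_no_fp8_block_scaling lets the skip guards added throughout this file report why a test was skipped. Below is a minimal, self-contained sketch of the pattern, with a stand-in availability helper rather than the real FP8GlobalStateManager call:

    import pytest

    def is_fp8_block_scaling_available():
        # Stand-in for FP8GlobalStateManager.is_fp8_block_scaling_available(),
        # assumed here to return an (available, reason) pair as in the diff.
        return False, "FP8 block scaling is not supported on this device."

    fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
        is_fp8_block_scaling_available()
    )

    def test_fp8_block_scaling_feature():
        # The guard pattern this commit adds to many tests: skip with the
        # captured reason when the current device cannot run the recipe.
        if not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        # ... actual test body ...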
@@ -606,6 +606,13 @@ def _test_e2e_selective_recompute(
 def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     config = model_configs[model]
@@ -714,8 +721,15 @@ def _test_e2e_full_recompute(
 def test_gpt_full_activation_recompute(
     dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
 ):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
@@ -1301,9 +1315,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     fuse_wgrad_accumulation = True
     fp8_model_params = False
     fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -1818,6 +1837,12 @@ def test_grouped_linear_accuracy(
     use_cutlass=False,
 ):
     fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
@@ -1863,7 +1888,8 @@ def test_grouped_linear_accuracy(
         weight_i = getattr(grouped_linear, f"weight{i}")
         weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
         sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
     outputs_ref = _test_grouped_linear_accuracy(
         sequential_linear,
         num_gemms,
@@ -1886,7 +1912,8 @@ def test_grouped_linear_accuracy(
         fuse_wgrad_accumulation,
         delay_wgrad_compute,
     )
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"
     for o, o_ref in zip(outputs, outputs_ref):
         if use_cutlass:
             torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
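Note on the two IS_HIP_EXTENSION hunks above: the test sets NVTE_FORCE_ROCM_GEMM to "1" while the sequential reference outputs are computed and resets it to "0" before the comparison loop. What the flag selects inside the library is not stated in the commit; the sketch below only shows a safer way to scope such a flag (restoring it even if the reference computation raises), not how the test itself is written:

    import os
    from contextlib import contextmanager

    @contextmanager
    def scoped_env(name, value):
        # Temporarily set an environment variable, restoring the previous
        # value (or removing it) when the block exits, even on exceptions.
        old = os.environ.get(name)
        os.environ[name] = value
        try:
            yield
        finally:
            if old is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old

    # Hypothetical usage mirroring the diff:
    # with scoped_env("NVTE_FORCE_ROCM_GEMM", "1"):
    #     outputs_ref = _test_grouped_linear_accuracy(sequential_linear, num_gemms, ...)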
@@ -1956,6 +1983,12 @@ def test_grouped_linear_accuracy_save_original_input(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2162,8 +2195,14 @@ def test_padding_grouped_linear_accuracy(
     fp8_model_params,
     parallel_mode=None,
 ):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2235,6 +2274,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2446,8 +2491,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
 @pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("recipe", fp8_recipes)
 def test_gpt_fp8_parameters(dtype, bs, model, recipe):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
transformer_engine/pytorch/module/_common.py (+1, -1)

@@ -53,7 +53,7 @@ def apply_normalization(
     normalization_func = _get_normalization_func(normalization, True)
     inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
-    if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32):
+    if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32) and not zero_centered_gamma:
         out, rsigma = rmsnorm_forward(inputmat, ln_weight, ln_out, eps, True)
         return out, None, rsigma
     else:
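Note on the hunk above: the fused rmsnorm_forward fast path is now skipped when zero_centered_gamma is set, falling back to the generic normalization_func. A plausible reading, which the commit message does not spell out, is that with zero-centered gamma the stored weight is an offset from one, so the effective scale is (1 + gamma) rather than gamma. A pure-PyTorch reference for the two conventions, useful when checking either path:

    import torch

    def rmsnorm_ref(x, gamma, eps=1e-5, zero_centered_gamma=False):
        # Reference RMSNorm over the last dimension. With zero_centered_gamma
        # the parameter stores (gamma - 1), so the effective scale is 1 + gamma.
        rsigma = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        weight = 1.0 + gamma if zero_centered_gamma else gamma
        return x * rsigma * weight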