Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
bbf18d2c
Unverified
Commit
bbf18d2c
authored
Dec 12, 2025
by
thatPepe
Committed by
GitHub
Dec 12, 2025
Browse files
Merge pull request #758 from InfiniTensor/issue/757
issue/757 - support equal_nan in test debug
parents
f53b8435
7a55b415
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
81 additions
and
28 deletions
+81
-28
test/infinicore/framework/base.py
test/infinicore/framework/base.py
+12
-2
test/infinicore/framework/config.py
test/infinicore/framework/config.py
+15
-4
test/infinicore/framework/runner.py
test/infinicore/framework/runner.py
+20
-15
test/infinicore/framework/utils.py
test/infinicore/framework/utils.py
+34
-7
No files found.
test/infinicore/framework/base.py
View file @
bbf18d2c
...
...
@@ -30,8 +30,10 @@ class TestConfig:
num_prerun
=
10
,
num_iterations
=
1000
,
verbose
=
False
,
equal_nan
=
False
,
):
self
.
debug
=
debug
self
.
equal_nan
=
equal_nan
self
.
bench
=
bench
self
.
num_prerun
=
num_prerun
self
.
num_iterations
=
num_iterations
...
...
@@ -540,7 +542,11 @@ class BaseOperatorTest(ABC):
rtol
=
test_case
.
tolerance
.
get
(
"rtol"
,
1e-3
)
compare_fn
=
create_test_comparator
(
config
,
atol
,
rtol
,
f
"
{
test_case
.
description
}
- output_
{
i
}
"
config
,
atol
,
rtol
,
f
"
{
test_case
.
description
}
- output_
{
i
}
"
,
equal_nan
=
config
.
equal_nan
,
)
is_valid
=
compare_fn
(
infini_out
,
torch_out
)
...
...
@@ -589,7 +595,11 @@ class BaseOperatorTest(ABC):
rtol
=
test_case
.
tolerance
.
get
(
"rtol"
,
1e-3
)
compare_fn
=
create_test_comparator
(
config
,
atol
,
rtol
,
test_case
.
description
config
,
atol
,
rtol
,
test_case
.
description
,
equal_nan
=
config
.
equal_nan
,
)
is_valid
=
compare_fn
(
infini_comparison
,
torch_comparison
)
...
...
test/infinicore/framework/config.py
View file @
bbf18d2c
...
...
@@ -44,6 +44,7 @@ def get_hardware_args_group(parser):
return
hardware_group
def
add_common_test_args
(
parser
:
argparse
.
ArgumentParser
):
"""
Adds common test/execution arguments to the passed parser object.
...
...
@@ -60,13 +61,19 @@ def add_common_test_args(parser: argparse.ArgumentParser):
help
=
"Enable performance benchmarking mode. "
"Options: host (CPU time only), device (GPU time only), both (default)"
,
)
group
.
add_argument
(
"--debug"
,
action
=
"store_true"
,
help
=
"Enable debug mode for detailed tensor comparison"
,
)
group
.
add_argument
(
"--eq_nan"
,
action
=
"store_true"
,
help
=
"Enable equal_nan for tensor comparison"
,
)
group
.
add_argument
(
"--verbose"
,
action
=
"store_true"
,
...
...
@@ -81,6 +88,7 @@ def add_common_test_args(parser: argparse.ArgumentParser):
help
=
"Save test results to a JSON file. Defaults to 'test_report.json' if no filename provided."
,
)
def
get_args
():
"""Parse command line arguments for operator testing"""
parser
=
argparse
.
ArgumentParser
(
...
...
@@ -100,9 +108,12 @@ Examples:
# Run with benchmarking - device timing only
python test_operator.py --nvidia --bench device
# Run with debug mode on multiple devices
# Run with
basic
debug mode on multiple devices
python test_operator.py --cpu --nvidia --debug
# Run with eq_nan debug mode to treat NaN as equal
python test_operator.py --cpu --nvidia --debug --eq_nan
# Run with verbose mode to stop on first error with full traceback
python test_operator.py --cpu --nvidia --verbose
...
...
@@ -216,7 +227,7 @@ def get_test_devices(args):
devices_to_test
.
append
(
InfiniDeviceEnum
.
HYGON
)
except
ImportError
:
print
(
"Warning: Hygon DCU support not available"
)
if
args
.
qy
:
try
:
# Iluvatar GPU detection
...
...
test/infinicore/framework/runner.py
View file @
bbf18d2c
...
...
@@ -9,6 +9,7 @@ import re
from
.
import
TestConfig
,
TestRunner
,
get_args
,
get_test_devices
from
.reporter
import
TestReporter
class
GenericTestRunner
:
"""Generic test runner that handles the common execution flow"""
...
...
@@ -33,7 +34,8 @@ class GenericTestRunner:
bench
=
self
.
args
.
bench
,
num_prerun
=
self
.
args
.
num_prerun
,
num_iterations
=
self
.
args
.
num_iterations
,
verbose
=
self
.
args
.
verbose
,
# Pass verbose flag to TestConfig
verbose
=
self
.
args
.
verbose
,
equal_nan
=
self
.
args
.
eq_nan
,
)
runner
=
TestRunner
(
self
.
operator_test
.
test_cases
,
config
)
...
...
@@ -53,9 +55,9 @@ class GenericTestRunner:
# summary_passed returns True if no tests failed (skipped/partial are OK)
summary_passed
=
runner
.
print_summary
()
if
getattr
(
self
.
args
,
'
save
'
,
None
):
if
getattr
(
self
.
args
,
"
save
"
,
None
):
self
.
_save_report
(
runner
)
# Both conditions must be True for overall success
# - has_no_failures: no test failures during execution
# - summary_passed: summary confirms no failures
...
...
@@ -68,7 +70,7 @@ class GenericTestRunner:
0: All tests passed or were skipped/partial (no failures)
1: One or more tests failed
"""
success
,
runner
=
self
.
run
()
success
,
runner
=
self
.
run
()
sys
.
exit
(
0
if
success
else
1
)
...
...
@@ -77,15 +79,14 @@ class GenericTestRunner:
Helper method to collect metadata and trigger report saving.
"""
try
:
# 1. Prepare metadata (Paths)
t_path
=
self
.
_infer_op_path
(
self
.
operator_test
.
torch_operator
,
"torch"
)
i_path
=
self
.
_infer_op_path
(
self
.
operator_test
.
infinicore_operator
,
"infinicore"
)
op_paths
=
{
"torch"
:
t_path
,
"infinicore"
:
i_path
}
i_path
=
self
.
_infer_op_path
(
self
.
operator_test
.
infinicore_operator
,
"infinicore"
)
op_paths
=
{
"torch"
:
t_path
,
"infinicore"
:
i_path
}
# 2. Generate Report Entries
entries
=
TestReporter
.
prepare_report_entry
(
...
...
@@ -93,14 +94,16 @@ class GenericTestRunner:
test_cases
=
self
.
operator_test
.
test_cases
,
args
=
self
.
args
,
op_paths
=
op_paths
,
results_list
=
runner
.
test_results
results_list
=
runner
.
test_results
,
)
# 4. Save to File
TestReporter
.
save_all_results
(
self
.
args
.
save
,
entries
)
except
Exception
as
e
:
import
traceback
;
traceback
.
print_exc
()
import
traceback
traceback
.
print_exc
()
print
(
f
"⚠️ Failed to save report:
{
e
}
"
)
def
_infer_op_path
(
self
,
method
,
lib_prefix
):
...
...
@@ -113,7 +116,9 @@ class GenericTestRunner:
# Regex to find 'lib.func' or 'lib.submodule.func'
# Matches: 'torch.add', 'torch.nn.functional.relu'
pattern
=
re
.
compile
(
rf
"\b
{
lib_prefix
}
\.([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)"
,
re
.
IGNORECASE
)
pattern
=
re
.
compile
(
rf
"\b
{
lib_prefix
}
\.([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)"
,
re
.
IGNORECASE
)
match
=
pattern
.
search
(
source
)
if
match
:
# Return the matched string exactly as found in source code
...
...
test/infinicore/framework/utils.py
View file @
bbf18d2c
...
...
@@ -91,6 +91,7 @@ def print_discrepancy(
print
(
f
" - Desired dtype:
{
expected
.
dtype
}
"
)
print
(
f
" - Atol:
{
atol
}
"
)
print
(
f
" - Rtol:
{
rtol
}
"
)
print
(
f
" - Equal NaN:
{
equal_nan
}
"
)
print
(
f
" - Mismatched elements:
{
len
(
diff_indices
)
}
/
{
actual
.
numel
()
}
(
{
len
(
diff_indices
)
/
actual
.
numel
()
*
100
}
%)"
)
...
...
@@ -169,7 +170,7 @@ def convert_infinicore_to_torch(infini_result):
def
compare_results
(
infini_result
,
torch_result
,
atol
=
1e-5
,
rtol
=
1e-5
,
debug_mode
=
False
infini_result
,
torch_result
,
atol
=
1e-5
,
rtol
=
1e-5
,
equal_nan
=
False
,
debug_mode
=
False
):
"""
Generic function to compare infinicore result with PyTorch reference result
...
...
@@ -180,6 +181,7 @@ def compare_results(
torch_result: PyTorch tensor reference result (single or tuple)
atol: absolute tolerance (for floating-point only)
rtol: relative tolerance (for floating-point only)
equal_nan: whether to treat NaN as equal
debug_mode: whether to enable debug output
Returns:
...
...
@@ -194,7 +196,9 @@ def compare_results(
all_match
=
True
for
i
,
(
infini_out
,
torch_out
)
in
enumerate
(
zip
(
infini_result
,
torch_result
)):
match
=
compare_results
(
infini_out
,
torch_out
,
atol
,
rtol
,
debug_mode
)
match
=
compare_results
(
infini_out
,
torch_out
,
atol
,
rtol
,
equal_nan
,
debug_mode
)
all_match
=
all_match
and
match
return
all_match
...
...
@@ -241,7 +245,13 @@ def compare_results(
# Debug mode: detailed comparison
if
debug_mode
:
debug
(
torch_result_from_infini
,
torch_result
,
atol
=
atol
,
rtol
=
rtol
)
debug
(
torch_result_from_infini
,
torch_result
,
atol
=
atol
,
rtol
=
rtol
,
equal_nan
=
equal_nan
,
)
# Choose comparison method based on data type
if
is_integer_dtype
(
torch_result_from_infini
.
dtype
)
or
is_integer_dtype
(
...
...
@@ -257,10 +267,18 @@ def compare_results(
):
# Complex number comparison - compare real and imaginary parts separately
real_close
=
torch
.
allclose
(
torch_result_from_infini
.
real
,
torch_result
.
real
,
atol
=
atol
,
rtol
=
rtol
torch_result_from_infini
.
real
,
torch_result
.
real
,
atol
=
atol
,
rtol
=
rtol
,
equal_nan
=
equal_nan
,
)
imag_close
=
torch
.
allclose
(
torch_result_from_infini
.
imag
,
torch_result
.
imag
,
atol
=
atol
,
rtol
=
rtol
torch_result_from_infini
.
imag
,
torch_result
.
imag
,
atol
=
atol
,
rtol
=
rtol
,
equal_nan
=
equal_nan
,
)
result_equal
=
real_close
and
imag_close
if
debug_mode
and
not
result_equal
:
...
...
@@ -273,11 +291,15 @@ def compare_results(
else
:
# Tolerance-based comparison for floating-point types
return
torch
.
allclose
(
torch_result_from_infini
,
torch_result
,
atol
=
atol
,
rtol
=
rtol
torch_result_from_infini
,
torch_result
,
atol
=
atol
,
rtol
=
rtol
,
equal_nan
=
equal_nan
,
)
def
create_test_comparator
(
config
,
atol
,
rtol
,
mode_name
=
""
):
def
create_test_comparator
(
config
,
atol
,
rtol
,
mode_name
=
""
,
equal_nan
=
False
):
"""
Create a test-specific comparison function
...
...
@@ -286,6 +308,7 @@ def create_test_comparator(config, atol, rtol, mode_name=""):
atol: absolute tolerance (for floating-point only)
rtol: relative tolerance (for floating-point only)
mode_name: operation mode name for debug output
equal_nan: whether to treat NaN as equal
Returns:
callable: function that takes (infini_result, torch_result) and returns bool
...
...
@@ -294,6 +317,9 @@ def create_test_comparator(config, atol, rtol, mode_name=""):
def
compare_test_results
(
infini_result
,
torch_result
):
if
config
.
debug
and
mode_name
:
print
(
f
"
\033
[94mDEBUG INFO -
{
mode_name
}
:
\033
[0m"
)
print
(
f
"
\033
[94m Equal NaN:
{
'enabled'
if
equal_nan
else
'disabled'
}
\033
[0m"
)
# For integer types, override tolerance to require exact equality
actual_atol
=
atol
...
...
@@ -316,6 +342,7 @@ def create_test_comparator(config, atol, rtol, mode_name=""):
torch_result
,
atol
=
actual_atol
,
rtol
=
actual_rtol
,
equal_nan
=
equal_nan
,
debug_mode
=
config
.
debug
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment