jerrrrry / infinicore · Commits

Commit 28ef01ca (Unverified)
Authored Dec 10, 2025 by thatPepe, committed by GitHub on Dec 10, 2025

Merge pull request #717 from InfiniTensor/issue/716

issue/716: Add save feature for existing test cases

Parents: e7e96a29, a8875c9a

Showing 5 changed files with 401 additions and 41 deletions (+401 -41)
test/infinicore/framework/__init__.py   +3   -0
test/infinicore/framework/config.py     +39  -18
test/infinicore/framework/reporter.py   +292 -0
test/infinicore/framework/runner.py     +63  -4
test/infinicore/run.py                  +4   -19
test/infinicore/framework/__init__.py

@@ -2,6 +2,7 @@ from .base import TestConfig, TestRunner, BaseOperatorTest
from .test_case import TestCase, TestResult
from .benchmark import BenchmarkUtils, BenchmarkResult
from .config import (
    add_common_test_args,
    get_args,
    get_hardware_args_group,
    get_test_devices,
...

@@ -36,7 +37,9 @@ __all__ = [
    "TestConfig",
    "TestResult",
    "TestRunner",
    "TestReporter",
    # Core functions
    "add_common_test_args",
    "compare_results",
    "convert_infinicore_to_torch",
    "create_test_comparator",
...
test/infinicore/framework/config.py

@@ -44,6 +44,42 @@ def get_hardware_args_group(parser):
    return hardware_group


def add_common_test_args(parser: argparse.ArgumentParser):
    """
    Adds common test/execution arguments to the passed parser object.
    Includes: bench, debug, verbose, save args.
    """
    # Create an argument group to make help info clearer
    group = parser.add_argument_group("Common Execution Options")
    group.add_argument(
        "--bench",
        nargs="?",
        const="both",
        choices=["host", "device", "both"],
        help="Enable performance benchmarking mode. "
        "Options: host (CPU time only), device (GPU time only), both (default)",
    )
    group.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode for detailed tensor comparison",
    )
    group.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose mode to stop on first error with full traceback",
    )
    group.add_argument(
        "--save",
        nargs="?",
        const="test_report.json",
        default=None,
        help="Save test results to a JSON file. Defaults to 'test_report.json' if no filename provided.",
    )


def get_args():
    """Parse command line arguments for operator testing"""
...

@@ -77,14 +113,6 @@ Examples:
    )

    # Core testing options
    parser.add_argument(
        "--bench",
        nargs="?",
        const="both",
        choices=["host", "device", "both"],
        help="Enable performance benchmarking mode. "
        "Options: host (CPU time only), device (GPU time only), both (default)",
    )
    parser.add_argument(
        "--num_prerun",
        type=lambda x: max(0, int(x)),
...

@@ -97,16 +125,9 @@ Examples:
        default=1000,
        help="Number of iterations for benchmarking (default: 1000)",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode for detailed tensor comparison",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose mode to stop on first error with full traceback",
    )

    # Call the common method to add arguments
    add_common_test_args(parser)

    # Device options using shared hardware info
    hardware_group = get_hardware_args_group(parser)
...
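
The new --save flag (like the existing --bench) uses argparse's nargs="?" plus const pattern, so the value is optional. A minimal standalone sketch of how the parsed value resolves, reusing only the --save definition shown above:

import argparse

# Sketch: just the "--save" definition from add_common_test_args,
# to show how nargs="?" + const + default interact.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--save",
    nargs="?",
    const="test_report.json",
    default=None,
    help="Save test results to a JSON file.",
)

print(parser.parse_args([]).save)                      # None -> saving disabled
print(parser.parse_args(["--save"]).save)              # "test_report.json" (const)
print(parser.parse_args(["--save", "ops.json"]).save)  # "ops.json"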
test/infinicore/framework/reporter.py (new file, 0 → 100644)

import json
import os
from datetime import datetime
from typing import List, Dict, Any, Union
from dataclasses import is_dataclass

from .base import TensorSpec
from .devices import InfiniDeviceEnum


class TestReporter:
    """
    Handles report generation and file saving for test results.
    """

    @staticmethod
    def prepare_report_entry(
        op_name: str,
        test_cases: List[Any],
        args: Any,
        op_paths: Dict[str, str],
        results_list: List[Any]
    ) -> List[Dict[str, Any]]:
        """
        Combines static test case info with dynamic execution results.
        """
        # 1. Normalize results
        results_map = {}
        if isinstance(results_list, list):
            results_map = {i: res for i, res in enumerate(results_list)}
        elif isinstance(results_list, dict):
            results_map = results_list
        else:
            results_map = {0: results_list} if results_list else {}

        # 2. Global Args
        global_args = {
            k: getattr(args, k)
            for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
            if hasattr(args, k)
        }

        grouped_entries: Dict[int, Dict[str, Any]] = {}

        # 3. Iterate Test Cases
        for idx, tc in enumerate(test_cases):
            res = results_map.get(idx)
            dev_id = getattr(res, "device", 0) if res else 0

            # --- A. Initialize Group ---
            if dev_id not in grouped_entries:
                device_id_map = {
                    v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")
                }
                dev_str = device_id_map.get(dev_id, str(dev_id))

                grouped_entries[dev_id] = {
                    "operator": op_name,
                    "device": dev_str,
                    "torch_op": op_paths.get("torch") or "unknown",
                    "infinicore_op": op_paths.get("infinicore") or "unknown",
                    "args": global_args,
                    "testcases": []
                }

            # --- B. Build Kwargs ---
            display_kwargs = {}

            # B1. Process existing kwargs
            for k, v in tc.kwargs.items():
                # Handle Inplace: "out": index -> "out": "input_name"
                if k == "out" and isinstance(v, int):
                    if 0 <= v < len(tc.inputs):
                        display_kwargs[k] = tc.inputs[v].name
                    else:
                        display_kwargs[k] = f"Invalid_Index_{v}"
                else:
                    display_kwargs[k] = (
                        TestReporter._spec_to_dict(v) if isinstance(v, TensorSpec) else v
                    )

            # B2. Inject Outputs into Kwargs
            if hasattr(tc, "output_specs") and tc.output_specs:
                for i, spec in enumerate(tc.output_specs):
                    display_kwargs[f"out_{i}"] = TestReporter._spec_to_dict(spec)
            elif tc.output_spec:
                if "out" not in display_kwargs:
                    display_kwargs["out"] = TestReporter._spec_to_dict(tc.output_spec)

            # --- C. Build Test Case Dictionary ---
            case_data = {
                "description": tc.description,
                "inputs": [TestReporter._spec_to_dict(i) for i in tc.inputs],
                "kwargs": display_kwargs,
                "comparison_target": tc.comparison_target,
                "tolerance": tc.tolerance,
            }

            # --- D. Inject Result ---
            if res:
                case_data["result"] = TestReporter._fmt_result(res)
            else:
                case_data["result"] = {"status": {"success": False, "error": "No result"}}

            grouped_entries[dev_id]["testcases"].append(case_data)

        return list(grouped_entries.values())

    @staticmethod
    def save_all_results(save_path: str, total_results: List[Dict[str, Any]]):
        """
        Saves the report list to a JSON file with specific custom formatting
        """
        directory, filename = os.path.split(save_path)
        name, ext = os.path.splitext(filename)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
        final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")

        # Define indentation levels for cleaner code
        indent_4 = ' ' * 4
        indent_8 = ' ' * 8
        indent_12 = ' ' * 12
        indent_16 = ' ' * 16
        indent_20 = ' ' * 20

        print(f"💾 Saving to: {final_path}")

        try:
            with open(final_path, "w", encoding="utf-8") as f:
                f.write("[\n")

                for i, entry in enumerate(total_results):
                    f.write(f"{indent_4}{{\n")
                    keys = list(entry.keys())

                    for j, key in enumerate(keys):
                        val = entry[key]
                        comma = "," if j < len(keys) - 1 else ""

                        # -------------------------------------------------
                        # Special Handling for 'testcases' list formatting
                        # -------------------------------------------------
                        if key == "testcases" and isinstance(val, list):
                            f.write(f'{indent_8}"{key}": [\n')

                            for c_idx, case_item in enumerate(val):
                                f.write(f"{indent_12}{{\n")
                                case_keys = list(case_item.keys())

                                for k_idx, c_key in enumerate(case_keys):
                                    c_val = case_item[c_key]

                                    # [Logic A] Skip fields we merged manually after 'kwargs'
                                    if c_key in ["comparison_target", "tolerance"]:
                                        continue

                                    # Check comma for standard logic (might be overridden below)
                                    c_comma = "," if k_idx < len(case_keys) - 1 else ""

                                    # [Logic B] Handle 'kwargs' + Grouped Fields
                                    if c_key == "kwargs":
                                        # 1. Use Helper for kwargs (Fill/Flow logic)
                                        TestReporter._write_smart_field(
                                            f, c_key, c_val, indent_16, indent_20, close_comma=","
                                        )

                                        # 2. Write subsequent comparison_target and tolerance (on a new line)
                                        cmp_v = json.dumps(case_item.get("comparison_target"), ensure_ascii=False)
                                        tol_v = json.dumps(case_item.get("tolerance"), ensure_ascii=False)

                                        remaining_keys = [
                                            k for k in case_keys[k_idx + 1:]
                                            if k not in ("comparison_target", "tolerance")
                                        ]
                                        line_comma = "," if remaining_keys else ""

                                        f.write(
                                            f'{indent_16}"comparison_target": {cmp_v}, "tolerance": {tol_v}{line_comma}\n'
                                        )
                                        continue

                                    # [Logic C] Handle 'inputs' (Smart Wrap)
                                    if c_key == "inputs" and isinstance(c_val, list):
                                        TestReporter._write_smart_field(
                                            f, c_key, c_val, indent_16, indent_20, close_comma=c_comma
                                        )
                                        continue

                                    # [Logic D] Standard fields (description, result, output_spec, etc.)
                                    else:
                                        c_val_str = json.dumps(c_val, ensure_ascii=False)
                                        f.write(f'{indent_16}"{c_key}": {c_val_str}{c_comma}\n')

                                close_comma = "," if c_idx < len(val) - 1 else ""
                                f.write(f"{indent_12}}}{close_comma}\n")

                            f.write(f"{indent_8}]{comma}\n")

                        # -------------------------------------------------
                        # Standard top-level fields (operator, args, etc.)
                        # -------------------------------------------------
                        else:
                            k_str = json.dumps(key, ensure_ascii=False)
                            v_str = json.dumps(val, ensure_ascii=False)
                            f.write(f"{indent_8}{k_str}: {v_str}{comma}\n")

                    if i < len(total_results) - 1:
                        f.write(f"{indent_4}}},\n")
                    else:
                        f.write(f"{indent_4}}}\n")

                f.write("]\n")

            print(f" ✅ Saved (Structure Matched).")
        except Exception as e:
            import traceback; traceback.print_exc()
            print(f" ❌ Save failed: {e}")

    # --- Internal Helpers ---

    @staticmethod
    def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
        """
        Helper to write a JSON field (List or Dict) with smart wrapping.
        - If compact length <= 180: Write on one line.
        - If > 180: Use 'Fill/Flow' mode (multiple items per line, wrap when line is full).
        """
        # 1. Try Compact Mode
        compact_json = json.dumps(value, ensure_ascii=False)
        if len(compact_json) <= 180:
            f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
            return

        # 2. Fill/Flow Mode
        is_dict = isinstance(value, dict)
        open_char = '{' if is_dict else '['
        close_char = '}' if is_dict else ']'

        f.write(f'{indent}"{key}": {open_char}')

        # Normalize items for iteration
        if is_dict:
            items = list(value.items())
        else:
            items = value  # List

        # Initialize current line length tracking
        # Length includes indent + "key": [
        current_len = len(indent) + len(f'"{key}": {open_char}')

        for i, item in enumerate(items):
            # Format individual item string
            if is_dict:
                k, v = item
                val_str = json.dumps(v, ensure_ascii=False)
                item_str = f'"{k}": {val_str}'
            else:
                item_str = json.dumps(item, ensure_ascii=False)

            is_last = (i == len(items) - 1)
            item_comma = "" if is_last else ", "

            # Predict new length: current + item + comma
            if current_len + len(item_str) + len(item_comma) > 180:
                # Wrap to new line
                f.write(f'\n{sub_indent}')
                current_len = len(sub_indent)

            f.write(f'{item_str}{item_comma}')
            current_len += len(item_str) + len(item_comma)

        f.write(f'{close_char}{close_comma}\n')

    @staticmethod
    def _spec_to_dict(s):
        return {
            "name": getattr(s, "name", "unknown"),
            "shape": list(s.shape) if s.shape else None,
            "dtype": str(s.dtype).split(".")[-1],
            "strides": list(s.strides) if s.strides else None,
        }

    @staticmethod
    def _fmt_result(res):
        if not (is_dataclass(res) or hasattr(res, "success")):
            return str(res)

        get_time = lambda k: round(getattr(res, k, 0.0), 4)

        return {
            "status": {
                "success": getattr(res, "success", False),
                "error": getattr(res, "error_message", ""),
            },
            "perf_ms": {
                "torch": {
                    "host": get_time("torch_host_time"),
                    "device": get_time("torch_device_time"),
                },
                "infinicore": {
                    "host": get_time("infini_host_time"),
                    "device": get_time("infini_device_time"),
                },
            },
        }
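
The writer in save_all_results deliberately avoids json.dump with a fixed indent: _write_smart_field keeps a kwargs or inputs field on one line whenever its compact JSON fits in 180 characters and otherwise flows items across wrapped lines. A small illustrative check of that helper with an in-memory buffer (assuming test/infinicore is on sys.path; the sample data is made up):

import io
from framework.reporter import TestReporter

buf = io.StringIO()
short = [{"name": "a", "shape": [2, 3]}]  # compact JSON fits in 180 chars -> single line
long_list = [{"name": f"t{i}", "shape": [1024, 1024]} for i in range(20)]  # exceeds 180 chars -> fill/flow mode

TestReporter._write_smart_field(buf, "inputs", short, " " * 16, " " * 20, close_comma=",")
TestReporter._write_smart_field(buf, "inputs", long_list, " " * 16, " " * 20)
print(buf.getvalue())  # first field on one line; second wraps at roughly 180 columns

Note also that save_all_results("test_report.json", entries) never overwrites an existing report: a timestamp is appended to the stem, yielding a name such as test_report_20251210_120000_123.json.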
test/infinicore/framework/runner.py

@@ -3,19 +3,22 @@ Generic test runner that handles the common execution flow for all operators
"""

import sys
import os
import inspect
import re

from . import TestConfig, TestRunner, get_args, get_test_devices
from .reporter import TestReporter


class GenericTestRunner:
    """Generic test runner that handles the common execution flow"""

-    def __init__(self, operator_test_class):
+    def __init__(self, operator_test_class, args=None):
        """
        Args:
            operator_test_class: A class that implements BaseOperatorTest interface
        """
        self.operator_test = operator_test_class()
-        self.args = get_args()
+        self.args = args or get_args()

    def run(self):
        """Execute the complete test suite
...

@@ -50,6 +53,9 @@ class GenericTestRunner:
        # summary_passed returns True if no tests failed (skipped/partial are OK)
        summary_passed = runner.print_summary()

        if getattr(self.args, 'save', None):
            self._save_report(runner)

        # Both conditions must be True for overall success
        # - has_no_failures: no test failures during execution
        # - summary_passed: summary confirms no failures
...

@@ -62,5 +68,58 @@ class GenericTestRunner:
            0: All tests passed or were skipped/partial (no failures)
            1: One or more tests failed
        """
-        success, runner = self.run()
+        success, runner = self.run()
        sys.exit(0 if success else 1)

    def _save_report(self, runner):
        """
        Helper method to collect metadata and trigger report saving.
        """
        try:
            # 1. Prepare metadata (Paths)
            t_path = self._infer_op_path(self.operator_test.torch_operator, "torch")
            i_path = self._infer_op_path(self.operator_test.infinicore_operator, "infinicore")
            op_paths = {"torch": t_path, "infinicore": i_path}

            # 2. Generate Report Entries
            entries = TestReporter.prepare_report_entry(
                op_name=self.operator_test.operator_name,
                test_cases=self.operator_test.test_cases,
                args=self.args,
                op_paths=op_paths,
                results_list=runner.test_results
            )

            # 4. Save to File
            TestReporter.save_all_results(self.args.save, entries)
        except Exception as e:
            import traceback; traceback.print_exc()
            print(f"⚠️ Failed to save report: {e}")

    def _infer_op_path(self, method, lib_prefix):
        """
        Introspects the method source code to find calls like 'torch.add' or 'infinicore.mul'.
        Returns the full path string (e.g., 'torch.add') or None if not found.
        """
        try:
            source = inspect.getsource(method)
            # Regex to find 'lib.func' or 'lib.submodule.func'
            # Matches: 'torch.add', 'torch.nn.functional.relu'
            pattern = re.compile(
                rf"\b{lib_prefix}\.([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)", re.IGNORECASE
            )
            match = pattern.search(source)
            if match:
                # Return the matched string exactly as found in source code
                # or normalize it (e.g. lowercase lib_prefix + match)
                return f"{lib_prefix}.{match.group(1)}"
        except Exception:
            # Handle cases where source is not available (e.g. compiled modules)
            pass
        return None
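
_infer_op_path identifies the operator path purely by running the regex above over the test method's source. A standalone check of that pattern, using hypothetical source strings for illustration:

import re

lib_prefix = "torch"
pattern = re.compile(rf"\b{lib_prefix}\.([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)", re.IGNORECASE)

# Hypothetical operator-test bodies, just to show what the capture group returns.
print(pattern.search("return torch.add(a, b, out=out)").group(1))   # add
print(pattern.search("y = torch.nn.functional.relu(x)").group(1))   # nn.functional.relu
print(pattern.search("z = some_other_lib.mul(a, b)"))               # None -> reporter falls back to "unknown"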
test/infinicore/run.py

@@ -5,7 +5,7 @@ import traceback
from pathlib import Path
import importlib.util

-from framework import get_hardware_args_group
+from framework import get_hardware_args_group, add_common_test_args


def find_ops_directory(location=None):
...

@@ -650,24 +650,9 @@ def main():
        action="store_true",
        help="List all available test files without running them",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose mode to stop on first error with full traceback",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode to debug value mismatches",
    )
    parser.add_argument(
        "--bench",
        nargs="?",
        const="both",
        choices=["host", "device", "both"],
        help="Enable performance benchmarking mode. "
        "Options: host (CPU time only), device (GPU time only), both (default)",
    )

    # Call common method to add shared arguments (bench, debug, verbose, save...)
    add_common_test_args(parser)

    get_hardware_args_group(parser)
...
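
Removing run.py's local --verbose/--debug/--bench definitions in favor of the shared helper also avoids an argparse conflict: the same option string cannot be registered twice on one parser, so run.py could not keep its own --bench and call add_common_test_args(parser) as well. A minimal reproduction of that conflict (illustrative, independent of run.py):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--bench", nargs="?", const="both", choices=["host", "device", "both"])

try:
    # Registering the same option string again raises argparse.ArgumentError.
    parser.add_argument("--bench", action="store_true")
except argparse.ArgumentError as e:
    print(f"conflict: {e}")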