jerrrrry / infinicore · Commits

Commit 12cde8eb (unverified) — authored Dec 24, 2025 by thatPepe, committed via GitHub on Dec 24, 2025

Merge pull request #788 from InfiniTensor/issue/787

issue/787 - Split run ops test logic and fix kwargs name in report

Parents: 62fe6999, 7aece930
19 changed files with 1068 additions and 1218 deletions (+1068 −1218)
Changed files:

test/infinicore/framework/__init__.py              +17    −6
test/infinicore/framework/base.py                   +9    −9
test/infinicore/framework/benchmark.py              +1    −1
test/infinicore/framework/entities.py               +0   −15
test/infinicore/framework/executor.py              +81    −0
test/infinicore/framework/reporter.py               +0  −292
test/infinicore/framework/results.py              +396    −0
test/infinicore/framework/runner.py                 +4    −3
test/infinicore/framework/tensor.py                 +1    −1
test/infinicore/framework/utils/__init__.py         +0    −0
test/infinicore/framework/utils/compare_utils.py   +95  −242
test/infinicore/framework/utils/json_utils.py     +137    −0
test/infinicore/framework/utils/tensor_utils.py   +167    −0
test/infinicore/ops/adaptive_max_pool2d.py          +2    −1
test/infinicore/ops/embedding.py                    +1    −1
test/infinicore/ops/random_sample.py                +6    −6
test/infinicore/ops/sort.py                         +3    −2
test/infinicore/ops/std.py                          +3    −2
test/infinicore/run.py                            +145  −637
test/infinicore/framework/__init__.py  (+17 −6)

 from .base import TestConfig, TestRunner, BaseOperatorTest
-from .test_case import TestCase, TestResult
+from .entities import TestCase
 from .benchmark import BenchmarkUtils, BenchmarkResult
 from .config import (
     add_common_test_args,
...
@@ -9,35 +9,44 @@ from .config import (
 )
 from .datatypes import to_torch_dtype, to_infinicore_dtype
 from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
+from .results import TestTiming, OperatorResult, CaseResult, TestSummary
 from .runner import GenericTestRunner
 from .tensor import TensorSpec, TensorInitializer
-from .utils import (
+from .executor import TestExecutor
+from .utils.compare_utils import (
     compare_results,
     create_test_comparator,
     debug,
     get_tolerance,
 )
+from .utils.json_utils import save_json_report
+from .utils.tensor_utils import (
     infinicore_tensor_from_torch,
     convert_infinicore_to_torch,
     rearrange_tensor,
     is_broadcast,
     is_integer_dtype,
     is_complex_dtype,
     is_floating_dtype,
 )

 __all__ = [
     # Core types and classes
     "BaseOperatorTest",
+    "CaseResult",
     "GenericTestRunner",
     "InfiniDeviceEnum",
     "InfiniDeviceNames",
+    "OperatorResult",
     "TensorInitializer",
     "TensorSpec",
     "TestCase",
     "TestConfig",
     "TestResult",
+    "TestExecutor",
+    "TestSummary",
     "TestRunner",
-    "TestReporter",
+    "TestTiming",
     # Core functions
     "add_common_test_args",
     "compare_results",
...
@@ -50,6 +59,8 @@ __all__ = [
     "get_tolerance",
     "infinicore_tensor_from_torch",
     "rearrange_tensor",
+    # Json utilites
+    "save_json_report",
     # Utility functions
     "to_infinicore_dtype",
     "to_torch_dtype",
...
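With the re-export split above, downstream tests can keep importing everything from the framework package. A minimal sketch of the import surface after this change (assuming test/infinicore is on sys.path, as the ops scripts arrange; the print lines are illustrative only):

# Usage sketch, not part of the commit.
from framework import (
    CaseResult,        # per-case result dataclass (replaces entities.TestResult)
    TestSummary,       # aggregation and reporting
    TestExecutor,      # runs a single operator test file
    save_json_report,  # JSON report writer re-exported from utils.json_utils
)

result = CaseResult(success=True, return_code=0)
print(result.success, result.return_code)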
test/infinicore/framework/base.py  (+9 −9)

...
@@ -8,15 +8,15 @@ import infinicore
 import traceback
 from abc import ABC, abstractmethod
-from .test_case import TestCase, TestResult
+from .results import CaseResult
 from .datatypes import to_torch_dtype, to_infinicore_dtype
 from .devices import InfiniDeviceNames, torch_device_map
 from .tensor import TensorSpec, TensorInitializer
-from .utils import (
+from .utils.tensor_utils import (
     clone_torch_tensor,
-    create_test_comparator,
     infinicore_tensor_from_torch,
 )
+from .utils.compare_utils import create_test_comparator
 from .benchmark import BenchmarkUtils
...
@@ -84,7 +84,7 @@ class TestRunner:
         try:
             print(f"{test_case}")
-            # Execute test and get TestResult object
+            # Execute test and get CaseResult object
             test_result = test_func(device, test_case, self.config)
             self.test_results.append(test_result)
...
@@ -118,8 +118,8 @@ class TestRunner:
             print(f"\033[91m✗\033[0m {error_msg}")
             self.failed_tests.append(error_msg)
-            # Create a failed TestResult
-            failed_result = TestResult(
+            # Create a failed CaseResult
+            failed_result = CaseResult(
                 success=False,
                 return_code=-1,
                 error_message=str(e),
...
@@ -400,12 +400,12 @@ class BaseOperatorTest(ABC):
             config: Test configuration

         Returns:
-            TestResult: Test result object containing status and timing information
+            CaseResult: Test case result object containing status and timing information
         """
         device_str = torch_device_map[device]

-        # Initialize test result
-        test_result = TestResult(
+        # Initialize test case result
+        test_result = CaseResult(
             success=False,
             return_code=-1,  # Default to failure
             test_case=test_case,
...
test/infinicore/framework/benchmark.py  (+1 −1)

...
@@ -5,7 +5,7 @@ Benchmarking utilities for the InfiniCore testing framework
 import time
 import torch
 import infinicore
-from .utils import synchronize_device
+from .utils.tensor_utils import synchronize_device


 class BenchmarkUtils:
...
test/infinicore/framework/test_case.py → test/infinicore/framework/entities.py  (renamed, +0 −15)

...
@@ -7,21 +7,6 @@ from typing import List, Dict, Any, Optional, Tuple
 from .tensor import TensorSpec

-@dataclass
-class TestResult:
-    """Test result data structure"""
-    success: bool
-    return_code: int  # 0: success, -1: failure, -2: skipped, -3: partial
-    torch_host_time: float = 0.0
-    torch_device_time: float = 0.0
-    infini_host_time: float = 0.0
-    infini_device_time: float = 0.0
-    error_message: str = ""
-    test_case: Any = None
-    device: Any = None

 class TestCase:
     """Test case with all configuration included"""
...
test/infinicore/framework/executor.py  (new file, +81)

import sys
import importlib.util
from io import StringIO
from contextlib import contextmanager
from .results import OperatorResult, TestSummary


@contextmanager
def capture_output():
    """Context manager: captures stdout and stderr."""
    new_out, new_err = StringIO(), StringIO()
    old_out, old_err = sys.stdout, sys.stderr
    try:
        sys.stdout, sys.stderr = new_out, new_err
        yield new_out, new_err
    finally:
        sys.stdout, sys.stderr = old_out, old_err


class TestExecutor:
    def execute(self, file_path) -> OperatorResult:
        result = OperatorResult(name=file_path.stem)
        try:
            # 1. Dynamically import the module
            module = self._import_module(file_path)

            # 2. Look for TestRunner
            if not hasattr(module, "GenericTestRunner"):
                raise ImportError("No GenericTestRunner found in module")

            # 3. Look for TestClass (subclass of BaseOperatorTest)
            test_class = self._find_test_class(module)
            if not test_class:
                raise ImportError("No BaseOperatorTest subclass found")

            test_instance = test_class()
            runner_class = module.GenericTestRunner
            runner = runner_class(test_instance.__class__)

            # 4. Execute and capture output
            with capture_output() as (out, err):
                success, internal_runner = runner.run()

            # 5. Populate results
            result.success = success
            result.stdout = out.getvalue()
            result.stderr = err.getvalue()

            # Extract detailed results from internal_runner
            test_results = internal_runner.get_test_results() if internal_runner else []
            test_summary = TestSummary()
            test_summary.process_operator_result(result, test_results)

        except Exception as e:
            result.success = False
            result.error_message = str(e)
            result.stderr += f"\nExecutor Error: {str(e)}"
            result.return_code = -1

        return result

    def _import_module(self, path):
        module_name = f"op_test_{path.stem}"
        spec = importlib.util.spec_from_file_location(module_name, path)
        if not spec or not spec.loader:
            raise ImportError(f"Could not load spec from {path}")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module

    def _find_test_class(self, module):
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if isinstance(attr, type) and hasattr(attr, "__bases__"):
                # Simple check for base class name
                if any("BaseOperatorTest" in str(b) for b in attr.__bases__):
                    return attr
        return None
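A minimal sketch of driving the new executor directly, assuming an ops layout like this repository's test/infinicore/ops directory; the specific file path below is illustrative:

# Usage sketch, not part of the commit.
from pathlib import Path
from framework.executor import TestExecutor

executor = TestExecutor()
op_file = Path("ops/adaptive_max_pool2d.py")   # any operator test file works the same way
result = executor.execute(op_file)             # returns an OperatorResult
print(result.status_icon, result.name, result.status_text)
print(result.stdout)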
test/infinicore/framework/reporter.py  (deleted, −292)

The deleted module held the old reporting path: its aggregation and entry-building logic now lives in TestSummary (framework/results.py), and its custom JSON writer in framework/utils/json_utils.py. Outline of the removed file:

import json
import os
from datetime import datetime
from typing import List, Dict, Any, Union
from dataclasses import is_dataclass
from .base import TensorSpec
from .devices import InfiniDeviceEnum


class TestReporter:
    """
    Handles report generation and file saving for test results.
    """

    @staticmethod
    def prepare_report_entry(op_name, test_cases, args, op_paths, results_list):
        """
        Combines static test case info with dynamic execution results.
        """
        # Normalized results_list into an index->result map, collected the global args
        # (bench, num_prerun, num_iterations, verbose, debug), grouped test cases by
        # device id, rewrote integer "out" kwargs to the referenced input's raw .name,
        # injected output specs as out_<i> kwargs, and attached a formatted result per
        # case; replaced by TestSummary._prepare_entry_logic.
        ...

    @staticmethod
    def save_all_results(save_path, total_results):
        """
        Saves the report list to a JSON file with specific custom formatting
        """
        # Wrote a timestamped <name>_<timestamp><ext> file with hand-rolled indentation,
        # keeping short "kwargs"/"inputs" fields compact and wrapping long ones at 180
        # columns via _write_smart_field; carried over as save_json_report/_write_field
        # in utils/json_utils.py.
        ...

    @staticmethod
    def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
        """
        Helper to write a JSON field (List or Dict) with smart wrapping.
        - If compact length <= 180: Write on one line.
        - If > 180: Use 'Fill/Flow' mode (multiple items per line, wrap when line is full).
        """
        ...

    @staticmethod
    def _spec_to_dict(s):
        return {
            "name": getattr(s, "name", "unknown"),
            "shape": list(s.shape) if s.shape else None,
            "dtype": str(s.dtype).split(".")[-1],
            "strides": list(s.strides) if s.strides else None,
        }

    @staticmethod
    def _fmt_result(res):
        if not (is_dataclass(res) or hasattr(res, "success")):
            return str(res)
        get_time = lambda k: round(getattr(res, k, 0.0), 4)
        return {
            "status": {
                "success": getattr(res, "success", False),
                "error": getattr(res, "error_message", ""),
            },
            "perf_ms": {
                "torch": {"host": get_time("torch_host_time"), "device": get_time("torch_device_time")},
                "infinicore": {"host": get_time("infini_host_time"), "device": get_time("infini_device_time")},
            },
        }
test/infinicore/framework/results.py  (new file, +396)

from typing import List, Dict, Any
from dataclasses import dataclass, is_dataclass, field
from .devices import InfiniDeviceEnum
from .tensor import TensorSpec
from .utils.json_utils import save_json_report


@dataclass
class CaseResult:
    """Test case result data structure"""
    success: bool
    return_code: int  # 0: success, -1: failure, -2: skipped, -3: partial
    torch_host_time: float = 0.0
    torch_device_time: float = 0.0
    infini_host_time: float = 0.0
    infini_device_time: float = 0.0
    error_message: str = ""
    test_case: Any = None
    device: Any = None


@dataclass
class TestTiming:
    """Stores performance timing metrics."""
    torch_host: float = 0.0
    torch_device: float = 0.0
    infini_host: float = 0.0
    infini_device: float = 0.0
    # Added field to support the logic in print_summary
    operators_tested: int = 0


@dataclass
class OperatorResult:
    """Stores the execution results of a single operator."""
    name: str
    success: bool = False
    return_code: int = -1
    error_message: str = ""
    stdout: str = ""
    stderr: str = ""
    timing: TestTiming = field(default_factory=TestTiming)

    @property
    def status_icon(self):
        if self.return_code == 0:
            return "✅"
        if self.return_code == -2:
            return "⏭️"
        if self.return_code == -3:
            return "⚠️"
        return "❌"

    @property
    def status_text(self):
        if self.return_code == 0:
            return "PASSED"
        if self.return_code == -2:
            return "SKIPPED"
        if self.return_code == -3:
            return "PARTIAL"
        return "FAILED"


class TestSummary:
    """
    Test Summary class:
    1. Aggregates results (Timing & Status calculation).
    2. Handles Console Output (Live & Summary).
    3. Handles File Reporting (Data Preparation).
    """

    def __init__(self, verbose=False, bench_mode=None):
        self.verbose = verbose
        self.bench_mode = bench_mode
        self.report_entries = []  # Cache for JSON report

    # =========================================================
    # Part 1: Result Aggregation
    # =========================================================
    def process_operator_result(self, op_result, sub_results: List):
        """
        Updates the OperatorResult object in-place.
        """
        if not sub_results:
            op_result.return_code = -1
            return

        # 1. Analyze Return Code (Status)
        if op_result.success:
            op_result.return_code = 0
        else:
            has_failures = any(r.return_code == -1 for r in sub_results)
            has_partial = any(r.return_code == -3 for r in sub_results)
            has_skipped = any(r.return_code == -2 for r in sub_results)
            if has_failures:
                op_result.return_code = -1
            elif has_partial:
                op_result.return_code = -3
            elif has_skipped:
                op_result.return_code = -2
            else:
                op_result.return_code = -1

        # 2. Extract Timing (Aggregation)
        t = op_result.timing
        t.torch_host = sum(r.torch_host_time for r in sub_results)
        t.torch_device = sum(r.torch_device_time for r in sub_results)
        t.infini_host = sum(r.infini_host_time for r in sub_results)
        t.infini_device = sum(r.infini_device_time for r in sub_results)
        t.operators_tested = len(sub_results)

    # =========================================================
    # Part 2: Console Output (View)
    # =========================================================
    def list_tests(self, discoverer):
        ops_dir = discoverer.ops_dir
        operators = discoverer.get_available_operators()
        if operators:
            print(f"Available operator test files in {ops_dir}:")
            for operator in operators:
                print(f"  - {operator}")
            print(f"\nTotal: {len(operators)} operators")
        else:
            print(f"No valid operator tests found in {ops_dir}")
            raw_files = discoverer.get_raw_python_files()
            if raw_files:
                print(f"\n💡 Debug Hint: Found Python files but they are not valid tests:")
                print(f"{raw_files}")

    def print_header(self, ops_dir, count):
        print(f"InfiniCore Operator Test Runner")
        print(f"Directory: {ops_dir}")
        print(f"Tests found: {count}\n")

    def print_live_result(self, result):
        print(f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})")
        if result.stdout:
            print(result.stdout.rstrip())
        if result.stderr:
            print("\nSTDERR:", result.stderr.rstrip())
        if result.error_message:
            print(f"💥 Error: {result.error_message}")
        if result.stdout or result.stderr or self.verbose:
            print("-" * 40)

    def print_summary(self, results, cumulative_timing, ops_dir, total_expected=0):
        print(f"\n{'=' * 80}\nCUMULATIVE TEST SUMMARY\n{'=' * 80}")
        passed = [r for r in results if r.return_code == 0]
        failed = [r for r in results if r.return_code == -1]
        skipped = [r for r in results if r.return_code == -2]
        partial = [r for r in results if r.return_code == -3]
        total = len(results)

        print(f"Total tests run: {total}")
        if total_expected > 0 and total < total_expected:
            print(f"Total tests expected: {total_expected}")
            print(f"Tests not executed: {total_expected - total}")
        print(f"Passed: {len(passed)}")
        print(f"Failed: {len(failed)}")
        if skipped:
            print(f"Skipped: {len(skipped)}")
        if partial:
            print(f"Partial: {len(partial)}")

        # 1. Benchmark
        if cumulative_timing:
            self._print_timing(cumulative_timing)

        # 2. Lists
        if passed:
            self._print_op_list("✅ PASSED OPERATORS", passed)
        else:
            print(f"\n✅ PASSED OPERATORS: None")
        if failed:
            self._print_op_list("❌ FAILED OPERATORS", failed)
        if skipped:
            self._print_op_list("⏭️ SKIPPED OPERATORS", skipped)
        if partial:
            self._print_op_list("⚠️ PARTIAL IMPLEMENTATIONS", partial)

        # 3. Verdict
        if total > 0:
            executed_tests = total - len(skipped)
            if executed_tests > 0:
                success_rate = len(passed) / executed_tests * 100
                print(f"\nSuccess rate: {success_rate:.1f}%")
            if not failed:
                if skipped or partial:
                    print(f"\n⚠️ Tests completed with some operators not fully implemented")
                else:
                    print(f"\n🎉 All tests passed!")
            else:
                print(f"\n❌ {len(failed)} tests failed")

        if not failed and (skipped or partial):
            print(f"\n⚠️ Note: Some operators are not fully implemented")
            print(f"   Run individual tests for details on missing implementations")

        if self.verbose and failed:
            print(f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:")
            for r in failed[:3]:
                file_path = ops_dir / (r.name + ".py")
                print(f"   python {file_path} --verbose")
            if len(failed) > 3:
                print(f"   ... (and {len(failed) - 3} others)")

        return len(failed) == 0

    def _print_timing(self, t):
        print(f"{'-' * 40}")
        if hasattr(t, "operators_tested") and t.operators_tested > 0:
            print(f"BENCHMARK SUMMARY ({t.operators_tested} cases):")
            if self.bench_mode in ["host", "both"]:
                print(f"  [Host]   PyTorch:    {t.torch_host:10.3f} ms")
                print(f"  [Host]   InfiniCore: {t.infini_host:10.3f} ms")
            if self.bench_mode in ["device", "both"]:
                print(f"  [Device] PyTorch:    {t.torch_device:10.3f} ms")
                print(f"  [Device] InfiniCore: {t.infini_device:10.3f} ms")
        print(f"{'-' * 40}")

    def _print_op_list(self, title, result_list):
        print(f"\n{title} ({len(result_list)}):")
        names = [r.name for r in result_list]
        for i in range(0, len(names), 10):
            print("  " + ", ".join(names[i:i + 10]))

    # =========================================================
    # Part 3: Report Generation
    # =========================================================
    def collect_report_entry(self, op_name, test_cases, args, op_paths, results_list):
        """
        Prepares the data and adds it to the internal list.
        """
        entry = self._prepare_entry_logic(op_name, test_cases, args, op_paths, results_list)
        self.report_entries.extend(entry)

    def save_report(self, save_path):
        """
        Delegates the actual writing to save_json_report.
        """
        if not self.report_entries:
            return
        # Call the external utility
        save_json_report(save_path, self.report_entries)

    def _prepare_entry_logic(self, op_name, test_cases, args, op_paths, results_list):
        """
        Combines static test case info with dynamic execution results.
        Refactored to reduce duplication.
        """
        # 1. Normalize results
        results_map = (
            results_list
            if isinstance(results_list, dict)
            else {i: res for i, res in enumerate(results_list or [])}
        )

        # 2. Global Args
        global_args = {
            k: getattr(args, k)
            for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
            if hasattr(args, k)
        }

        grouped_entries = {}
        # Cache device enum map
        device_id_map = {
            v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")
        }

        for idx, tc in enumerate(test_cases):
            res = results_map.get(idx)
            dev_id = getattr(res, "device", 0) if res else 0

            # --- A. Initialize Group ---
            if dev_id not in grouped_entries:
                grouped_entries[dev_id] = {
                    "operator": op_name,
                    "device": device_id_map.get(dev_id, str(dev_id)),
                    "torch_op": op_paths.get("torch", "unknown"),
                    "infinicore_op": op_paths.get("infinicore", "unknown"),
                    "args": global_args,
                    "testcases": [],
                }

            # --- B. Helpers for Spec Processing ---
            def process_spec(spec, default_name):
                final_name = self._resolve_name(spec, default_name)
                # Call internal method (no need for external converters file)
                return self._spec_to_dict(spec, name=final_name)

            # --- C. Build Inputs ---
            processed_inputs = [
                process_spec(inp, f"in_{i}") for i, inp in enumerate(tc.inputs)
            ]

            # --- D. Build Kwargs ---
            display_kwargs = {}
            for k, v in tc.kwargs.items():
                if k == "out" and isinstance(v, int):
                    # Handle Inplace Index
                    if 0 <= v < len(tc.inputs):
                        display_kwargs[k] = self._resolve_name(tc.inputs[v], f"in_{v}")
                    else:
                        display_kwargs[k] = f"Invalid_Index_{v}"
                elif isinstance(v, TensorSpec):
                    display_kwargs[k] = process_spec(v, v.name)
                else:
                    display_kwargs[k] = v

            # --- E. Inject Outputs ---
            if getattr(tc, "output_specs", None):
                for i, spec in enumerate(tc.output_specs):
                    display_kwargs[f"out_{i}"] = process_spec(spec, f"out_{i}")
            elif tc.output_spec and "out" not in display_kwargs:
                display_kwargs["out"] = process_spec(tc.output_spec, "out")

            # --- F. Assemble Case Data ---
            case_data = {
                "description": tc.description,
                "inputs": processed_inputs,
                "kwargs": display_kwargs,
                "comparison_target": tc.comparison_target,
                "tolerance": tc.tolerance,
                "result": (
                    self._fmt_result(res)
                    if res
                    else {"status": {"success": False, "error": "No result"}}
                ),
            }
            grouped_entries[dev_id]["testcases"].append(case_data)

        return list(grouped_entries.values())

    # --- Internal Helpers ---
    def _resolve_name(self, obj, default_name):
        return getattr(obj, "name", None) or default_name

    def _spec_to_dict(self, s, name=None):
        return {
            "name": name if name else getattr(s, "name", "unknown"),
            "shape": list(s.shape) if s.shape else None,
            "dtype": str(s.dtype).split(".")[-1],
            "strides": list(s.strides) if s.strides else None,
        }

    def _fmt_result(self, res):
        if not (is_dataclass(res) or hasattr(res, "success")):
            return str(res)
        get_time = lambda k: round(getattr(res, k, 0.0), 4)
        return {
            "status": {
                "success": getattr(res, "success", False),
                "error": getattr(res, "error_message", ""),
            },
            "perf_ms": {
                "torch": {"host": get_time("torch_host_time"), "device": get_time("torch_device_time")},
                "infinicore": {"host": get_time("infini_host_time"), "device": get_time("infini_device_time")},
            },
        }
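The return-code convention (0 pass, −1 fail, −2 skip, −3 partial) and the aggregation rule in process_operator_result can be exercised in isolation. A small sketch, assuming the framework package is importable; the timing numbers are placeholders:

# Usage sketch, not part of the commit.
from framework.results import CaseResult, OperatorResult, TestSummary

cases = [
    CaseResult(success=True, return_code=0, torch_device_time=0.12, infini_device_time=0.10),
    CaseResult(success=False, return_code=-2, error_message="strided output not supported"),
]

op = OperatorResult(name="sort", success=False)
summary = TestSummary(bench_mode="device")
summary.process_operator_result(op, cases)  # no hard failures, one skip -> return_code becomes -2
print(op.status_icon, op.status_text)       # ⏭️ SKIPPED
print(op.timing.operators_tested)           # 2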
test/infinicore/framework/runner.py  (+4 −3)

...
@@ -7,7 +7,7 @@ import os
 import inspect
 import re
 from . import TestConfig, TestRunner, get_args, get_test_devices
-from .reporter import TestReporter
+from .results import TestSummary


 class GenericTestRunner:
...
@@ -89,7 +89,8 @@ class GenericTestRunner:
             op_paths = {"torch": t_path, "infinicore": i_path}

             # 2. Generate Report Entries
-            entries = TestReporter.prepare_report_entry(
+            test_summary = TestSummary()
+            entries = test_summary.collect_report_entry(
                 op_name=self.operator_test.operator_name,
                 test_cases=self.operator_test.test_cases,
                 args=self.args,
...
@@ -98,7 +99,7 @@ class GenericTestRunner:
             )

             # 4. Save to File
-            TestReporter.save_all_results(self.args.save, entries)
+            test_summary.save_report(self.args.save)
         except Exception as e:
             import traceback
...
test/infinicore/framework/tensor.py  (+1 −1)

...
@@ -3,7 +3,7 @@ import math
 from pathlib import Path
 from .datatypes import to_torch_dtype
 from .devices import torch_device_map
-from .utils import is_integer_dtype, is_complex_dtype
+from .utils.tensor_utils import is_integer_dtype, is_complex_dtype


 class TensorInitializer:
...
test/infinicore/framework/utils/__init__.py  (new file, empty)
test/infinicore/framework/utils.py → test/infinicore/framework/utils/compare_utils.py  (renamed, +95 −242)

Removed from this module (moved to utils/tensor_utils.py, shown below): synchronize_device, clone_torch_tensor, infinicore_tensor_from_torch, convert_infinicore_to_torch, rearrange_tensor, is_broadcast, is_integer_dtype, is_complex_dtype and is_floating_dtype, together with the old top-level torch/time/infinicore/numpy/.datatypes imports.

New module header:

import sys
from ..datatypes import to_torch_dtype
from .tensor_utils import (
    convert_infinicore_to_torch,
    is_integer_dtype,
    is_complex_dtype,
)


def compare_results(
...
...
@@ -351,89 +189,104 @@ def create_test_comparator(config, atol, rtol, mode_name="", equal_nan=False):
    return compare_test_results

The comparison helpers are retained and now follow create_test_comparator:

def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
    """
    Get tolerance settings based on data type
    """
    tolerance = tolerance_map.get(tensor_dtype, {"atol": default_atol, "rtol": default_rtol})
    return tolerance["atol"], tolerance["rtol"]


def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
    """
    Debug function to compare two tensors and print differences
    """
    # Handle complex types by converting to real representation for comparison
    if actual.is_complex() or desired.is_complex():
        actual = torch.view_as_real(actual)
        desired = torch.view_as_real(desired)
    elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
        actual = actual.to(torch.float32)
        desired = desired.to(torch.float32)
    print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
    import numpy as np
    np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True)


def print_discrepancy(actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True):
    """Print detailed tensor differences"""
    if actual.shape != expected.shape:
        raise ValueError("Tensors must have the same shape to compare.")

    import torch
    import sys

    is_terminal = sys.stdout.isatty()
    actual_isnan = torch.isnan(actual)
    expected_isnan = torch.isnan(expected)

    # Calculate difference mask
    nan_mismatch = (
        actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
    )
    diff_mask = nan_mismatch | (
        torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
    )
    diff_indices = torch.nonzero(diff_mask, as_tuple=False)
    delta = actual - expected

    # Display formatting
    col_width = [18, 20, 20, 20]
    decimal_places = [0, 12, 12, 12]
    total_width = sum(col_width) + sum(decimal_places)

    def add_color(text, color_code):
        if is_terminal:
            return f"\033[{color_code}m{text}\033[0m"
        else:
            return text

    if verbose:
        for idx in diff_indices:
            index_tuple = tuple(idx.tolist())
            actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
            expected_str = f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
            delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
            print(
                f" > Index: {str(index_tuple):<{col_width[0]}}"
                f"actual: {add_color(actual_str, 31)}"
                f"expect: {add_color(expected_str, 32)}"
                f"delta: {add_color(delta_str, 33)}"
            )

    print(f" - Actual dtype: {actual.dtype}")
    print(f" - Desired dtype: {expected.dtype}")
    print(f" - Atol: {atol}")
    print(f" - Rtol: {rtol}")
    print(f" - Equal NaN: {equal_nan}")
    print(f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)")
    print(f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}")
    print(f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}")
    print(f" - Min(delta)  : {torch.min(delta):<{col_width[1]}} | Max(delta)  : {torch.max(delta):<{col_width[2]}}")
    print("-" * total_width)
    return diff_indices
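get_tolerance expects a per-dtype map of {"atol": ..., "rtol": ...} entries and falls back to the defaults for unknown dtypes. A small sketch of that convention; the tolerance values here are made up for illustration, not taken from the repository:

# Usage sketch, not part of the commit.
import torch
from framework.utils.compare_utils import get_tolerance

tolerance_map = {
    torch.float16: {"atol": 1e-3, "rtol": 1e-2},
    torch.float32: {"atol": 1e-5, "rtol": 1e-4},
}
atol, rtol = get_tolerance(tolerance_map, torch.float16)
print(atol, rtol)                                           # 0.001 0.01
atol, rtol = get_tolerance(tolerance_map, torch.bfloat16)   # not in the map
print(atol, rtol)                                           # defaults: 0 0.001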
test/infinicore/framework/utils/json_utils.py  (new file, +137)

import json
import os
from datetime import datetime


def save_json_report(save_path, total_results):
    """
    Saves the report list to a JSON file with specific custom formatting
    (Compact for short lines, Expanded for long lines).
    """
    directory, filename = os.path.split(save_path)
    name, ext = os.path.splitext(filename)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
    final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")

    # Define Indentation
    I4, I8, I12, I16, I20 = " " * 4, " " * 8, " " * 12, " " * 16, " " * 20

    print(f"💾 Saving to: {final_path}")

    # Helper for JSON stringify to avoid repetition
    def _to_json(obj):
        return json.dumps(obj, ensure_ascii=False)

    try:
        with open(final_path, "w", encoding="utf-8") as f:
            f.write("[\n")
            for i, entry in enumerate(total_results):
                f.write(f"{I4}{{\n")
                keys = list(entry.keys())
                for j, key in enumerate(keys):
                    val = entry[key]
                    comma = "," if j < len(keys) - 1 else ""

                    # Special handling for 'testcases' list
                    if key == "testcases" and isinstance(val, list):
                        f.write(f'{I8}"{key}": [\n')
                        for c_idx, case_item in enumerate(val):
                            f.write(f"{I12}{{\n")
                            case_keys = list(case_item.keys())
                            # Filter out keys that we handle specially at the end
                            standard_keys = [
                                k for k in case_keys
                                if k not in ["comparison_target", "tolerance"]
                            ]
                            for k_idx, c_key in enumerate(standard_keys):
                                c_val = case_item[c_key]
                                # Determine comma logic
                                c_comma = (
                                    ","
                                    if k_idx < len(standard_keys) - 1
                                    or "comparison_target" in case_item
                                    else ""
                                )
                                if c_key in ["kwargs", "inputs"]:
                                    _write_field(f, c_key, c_val, I16, I20, close_comma=c_comma)
                                else:
                                    f.write(f'{I16}"{c_key}": {_to_json(c_val)}{c_comma}\n')

                            # Handle trailing comparison/tolerance fields uniformly
                            if "comparison_target" in case_item:
                                cmp = _to_json(case_item.get("comparison_target"))
                                tol = _to_json(case_item.get("tolerance"))
                                f.write(f'{I16}"comparison_target": {cmp}, "tolerance": {tol}\n')

                            close_case = "," if c_idx < len(val) - 1 else ""
                            f.write(f"{I12}}}{close_case}\n")
                        f.write(f"{I8}]{comma}\n")
                    else:
                        # Standard top-level fields
                        f.write(f"{I8}{_to_json(key)}: {_to_json(val)}{comma}\n")

                close_entry = "}," if i < len(total_results) - 1 else "}"
                f.write(f"{I4}{close_entry}\n")
            f.write("]\n")
        print(f" ✅ Saved.")
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f" ❌ Save failed: {e}")


def _write_field(f, key, value, indent, sub_indent, close_comma=""):
    """
    Internal Helper: Write a JSON field with wrapping.
    """
    # 1. Try Compact Mode
    compact_json = json.dumps(value, ensure_ascii=False)
    if len(compact_json) <= 180:
        f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
        return

    # 2. Fill/Flow Mode
    is_dict = isinstance(value, dict)
    open_char = "{" if is_dict else "["
    close_char = "}" if is_dict else "]"
    f.write(f'{indent}"{key}": {open_char}')

    if is_dict:
        items = list(value.items())
    else:
        items = value

    current_len = len(indent) + len(f'"{key}": {open_char}')
    for i, item in enumerate(items):
        if is_dict:
            k, v = item
            val_str = json.dumps(v, ensure_ascii=False)
            item_str = f'"{k}": {val_str}'
        else:
            item_str = json.dumps(item, ensure_ascii=False)
        is_last = i == len(items) - 1
        item_comma = "" if is_last else ", "
        if current_len + len(item_str) + len(item_comma) > 180:
            f.write(f"\n{sub_indent}")
            current_len = len(sub_indent)
        f.write(f"{item_str}{item_comma}")
        current_len += len(item_str) + len(item_comma)
    f.write(f"{close_char}{close_comma}\n")
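save_json_report takes the grouped entries that TestSummary._prepare_entry_logic produces and writes them to a timestamped file next to the given path. A sketch of the expected entry shape; the operator name, op paths and device string below are placeholders, not values from the repository:

# Usage sketch, not part of the commit.
from framework.utils.json_utils import save_json_report

entries = [{
    "operator": "add",                       # placeholder
    "device": "CPU",                         # placeholder device name
    "torch_op": "unknown",
    "infinicore_op": "unknown",
    "args": {"bench": False, "verbose": False},
    "testcases": [{
        "description": "basic",
        "inputs": [{"name": "in_0", "shape": [2, 3], "dtype": "float32", "strides": None}],
        "kwargs": {},
        "comparison_target": None,
        "tolerance": {},
        "result": {"status": {"success": True, "error": ""}},
    }],
}]
save_json_report("report.json", entries)  # writes report_<timestamp>.json in the same directory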
test/infinicore/framework/utils/tensor_utils.py  (new file, +167)

import torch
import infinicore
from ..datatypes import to_infinicore_dtype, to_torch_dtype


# =================================================================
# Device & Synchronization
# =================================================================
def synchronize_device(torch_device):
    """Device synchronization"""
    if torch_device == "cuda":
        torch.cuda.synchronize()
    elif torch_device == "npu":
        torch.npu.synchronize()
    elif torch_device == "mlu":
        torch.mlu.synchronize()
    elif torch_device == "musa":
        torch.musa.synchronize()


# =================================================================
# Tensor Operations & Conversions
# =================================================================
def clone_torch_tensor(torch_tensor):
    cloned = torch_tensor.clone().detach()
    if not torch_tensor.is_contiguous():
        cloned = rearrange_tensor(cloned, torch_tensor.stride())
    return cloned


def infinicore_tensor_from_torch(torch_tensor):
    infini_device = infinicore.device(torch_tensor.device.type, 0)
    if torch_tensor.is_contiguous():
        return infinicore.from_blob(
            torch_tensor.data_ptr(),
            list(torch_tensor.shape),
            dtype=to_infinicore_dtype(torch_tensor.dtype),
            device=infini_device,
        )
    else:
        return infinicore.strided_from_blob(
            torch_tensor.data_ptr(),
            list(torch_tensor.shape),
            list(torch_tensor.stride()),
            dtype=to_infinicore_dtype(torch_tensor.dtype),
            device=infini_device,
        )


def convert_infinicore_to_torch(infini_result):
    """
    Convert infinicore tensor to PyTorch tensor for comparison

    Args:
        infini_result: infinicore tensor result

    Returns:
        torch.Tensor: PyTorch tensor with infinicore data
    """
    torch_result_from_infini = torch.zeros(
        infini_result.shape,
        dtype=to_torch_dtype(infini_result.dtype),
        device=infini_result.device.type,
    )
    if not infini_result.is_contiguous():
        torch_result_from_infini = rearrange_tensor(
            torch_result_from_infini, infini_result.stride()
        )
    temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
    temp_tensor.copy_(infini_result)
    return torch_result_from_infini


def rearrange_tensor(tensor, new_strides):
    """
    Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides.
    """
    import torch

    shape = tensor.shape
    new_size = [0] * len(shape)
    left = 0
    right = 0
    for i in range(len(shape)):
        if new_strides[i] >= 0:
            new_size[i] = (shape[i] - 1) * new_strides[i] + 1
            right += new_strides[i] * (shape[i] - 1)
        else:
            # TODO: Support negative strides in the future
            # new_size[i] = (shape[i] - 1) * (-new_strides[i]) + 1
            # left += new_strides[i] * (shape[i] - 1)
            raise ValueError("Negative strides are not supported yet")

    # Create a new tensor with zeros
    new_tensor = torch.zeros((right - left + 1,), dtype=tensor.dtype, device=tensor.device)

    # Generate indices for original tensor based on original strides
    indices = [torch.arange(s) for s in shape]
    mesh = torch.meshgrid(*indices, indexing="ij")

    # Flatten indices for linear indexing
    linear_indices = [m.flatten() for m in mesh]

    # Calculate new positions based on new strides
    new_positions = sum(
        linear_indices[i] * new_strides[i] for i in range(len(shape))
    ).to(tensor.device)
    offset = -left
    new_positions += offset

    # Copy the original data to the new tensor
    new_tensor.reshape(-1).index_add_(0, new_positions, tensor.reshape(-1))
    new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
    return new_tensor


def is_broadcast(strides):
    """
    Check if strides indicate a broadcasted tensor

    Args:
        strides: Tensor strides or None

    Returns:
        bool: True if the tensor is broadcasted (has zero strides)
    """
    if strides is None:
        return False
    return any(s == 0 for s in strides)


# =================================================================
# Type Checks (Moved here to avoid circular imports in check.py)
# =================================================================
def is_integer_dtype(dtype):
    """Check if dtype is integer type"""
    return dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.bool]


def is_complex_dtype(dtype):
    """Check if dtype is complex type"""
    return dtype in [torch.complex64, torch.complex128]


def is_floating_dtype(dtype):
    """Check if dtype is floating-point type"""
    return dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]
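rearrange_tensor materializes a tensor with explicitly chosen (non-negative) strides while preserving element values, which is how the framework builds strided test inputs. A small sketch, assuming framework.utils.tensor_utils is importable; the shapes and strides are arbitrary examples:

# Usage sketch, not part of the commit.
import torch
from framework.utils.tensor_utils import rearrange_tensor, is_broadcast

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # contiguous strides (3, 1)
y = rearrange_tensor(x, [4, 1])                         # same values, padded row stride
print(y.stride())             # (4, 1)
print(torch.equal(x, y))      # True: element values are preserved
print(is_broadcast((0, 1)))   # True: a zero stride marks a broadcast dimension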
test/infinicore/ops/adaptive_max_pool2d.py  (+2 −1)

...
@@ -7,6 +7,7 @@ import torch
 import infinicore
 from framework import (
     BaseOperatorTest,
+    CaseResult,
     TensorSpec,
     TestCase,
     GenericTestRunner,
...
@@ -76,7 +77,7 @@ class OpTest(BaseOperatorTest):
             and isinstance(test_case.inputs[0], TensorSpec)
             and test_case.inputs[0].strides is not None
         ):
-            return TestResult(
+            return CaseResult(
                 success=False,
                 return_code=-2,
                 test_case=test_case,
...
test/infinicore/ops/embedding.py  (+1 −1)

...
@@ -6,7 +6,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
 import torch
 from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
 from framework.tensor import TensorInitializer
-from framework.utils import (
+from framework.utils.tensor_utils import (
     convert_infinicore_to_torch,
     infinicore_tensor_from_torch,
     to_torch_dtype,
...
test/infinicore/ops/random_sample.py  (+6 −6)

...
@@ -222,8 +222,8 @@ class OpTest(BaseOperatorTest):
         # Re-run operations with the same logits to get results for comparison
         # prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
-        from framework.base import TestResult
-        from framework.utils import (
+        from framework.results import CaseResult
+        from framework.utils.tensor_utils import (
             convert_infinicore_to_torch,
             infinicore_tensor_from_torch,
         )
...
@@ -268,8 +268,8 @@ class OpTest(BaseOperatorTest):
         # Check if indices are equal (standard case)
         if ic_idx == ref_idx:
-            # Return a successful TestResult object
-            return TestResult(
+            # Return a successful CaseResult object
+            return CaseResult(
                 success=True,
                 return_code=0,
                 test_case=test_case,
...
@@ -283,8 +283,8 @@ class OpTest(BaseOperatorTest):
             logits_ic = logits_tensor[ic_idx].item()
             if logits_ic == logits_ref:
                 # Valid: different indices but same logits value
-                # Return a successful TestResult object
-                return TestResult(
+                # Return a successful CaseResult object
+                return CaseResult(
                     success=True,
                     return_code=0,
                     test_case=test_case,
...
test/infinicore/ops/sort.py  (+3 −2)

...
@@ -7,6 +7,7 @@ import torch
 import infinicore
 from framework import (
     BaseOperatorTest,
+    CaseResult,
     TensorSpec,
     TestCase,
     GenericTestRunner,
...
@@ -180,7 +181,7 @@ class OpTest(BaseOperatorTest):
             and isinstance(test_case.inputs[0], TensorSpec)
             and test_case.inputs[0].strides is not None
         ):
-            return TestResult(
+            return CaseResult(
                 success=False,
                 return_code=-2,
                 test_case=test_case,
...
@@ -193,7 +194,7 @@ class OpTest(BaseOperatorTest):
         )
         for spec in output_specs:
             if isinstance(spec, TensorSpec) and spec.strides is not None:
-                return TestResult(
+                return CaseResult(
                     success=False,
                     return_code=-2,
                     test_case=test_case,
...
test/infinicore/ops/std.py  (+3 −2)

...
@@ -7,6 +7,7 @@ import torch
 import infinicore
 from framework import (
     BaseOperatorTest,
+    CaseResult,
     TensorSpec,
     TestCase,
     GenericTestRunner,
...
@@ -122,7 +123,7 @@ class OpTest(BaseOperatorTest):
             and isinstance(test_case.inputs[0], TensorSpec)
             and test_case.inputs[0].strides is not None
         ):
-            return TestResult(
+            return CaseResult(
                 success=False,
                 return_code=-2,
                 test_case=test_case,
...
@@ -135,7 +136,7 @@ class OpTest(BaseOperatorTest):
             and isinstance(test_case.output_spec, TensorSpec)
             and test_case.output_spec.strides is not None
         ):
-            return TestResult(
+            return CaseResult(
                 success=False,
                 return_code=-2,
                 test_case=test_case,
...
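The ops diffs above all use the same pattern: cases these operators cannot run yet (strided inputs or outputs) are reported as skipped (return_code −2) rather than failed. A sketch of that pattern as a standalone helper; the helper name and error message are illustrative, not from the repository:

# Usage sketch, not part of the commit.
from framework import CaseResult, TensorSpec

def maybe_skip_strided(test_case):
    """Return a skipped CaseResult if the first input declares explicit strides, else None."""
    spec = test_case.inputs[0]
    if isinstance(spec, TensorSpec) and spec.strides is not None:
        return CaseResult(
            success=False,
            return_code=-2,  # skipped, not failed
            test_case=test_case,
            error_message="strided tensors not supported by this operator yet",
        )
    return None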
test/infinicore/run.py
View file @
12cde8eb
import
os
import
sys
import
argparse
import
traceback
from
pathlib
import
Path
import
importlib.util
# Import components from the unified framework package
from
framework.executor
import
TestExecutor
from
framework.results
import
TestSummary
,
TestTiming
from
framework
import
get_hardware_args_group
,
add_common_test_args
def
find_ops_directory
(
location
=
None
):
"""
Find the ops directory by searching from location upwards.
Args:
location: Starting directory for search (default: current file's parent)
Returns:
Path: Path to ops directory or None if not found
"""
if
location
is
None
:
location
=
Path
(
__file__
).
parent
/
"ops"
ops_dir
=
location
.
resolve
()
if
ops_dir
.
exists
()
and
any
(
ops_dir
.
glob
(
"*.py"
)):
return
ops_dir
return
None
def
get_available_operators
(
ops_dir
):
"""
Get list of available operators from ops directory.
Args:
ops_dir: Path to ops directory
Returns:
List of operator names
"""
if
not
ops_dir
or
not
ops_dir
.
exists
():
return
[]
test_files
=
list
(
ops_dir
.
glob
(
"*.py"
))
current_script
=
Path
(
__file__
).
name
test_files
=
[
f
for
f
in
test_files
if
f
.
name
!=
current_script
]
operators
=
[]
for
test_file
in
test_files
:
try
:
with
open
(
test_file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
content
=
f
.
read
()
if
"infinicore"
in
content
and
(
"BaseOperatorTest"
in
content
or
"GenericTestRunner"
in
content
):
operators
.
append
(
test_file
.
stem
)
except
:
continue
return
sorted
(
operators
)
def
import_operator_test
(
test_file_path
):
"""
Import an operator test module and return the test class instance.
Args:
test_file_path: Path to the test file
class TestDiscoverer:
    def __init__(self, ops_dir_path=None):
        self.ops_dir = self._resolve_dir(ops_dir_path)

    def _resolve_dir(self, path):
        if path:
            p = Path(path)
            if p.exists():
                return p
        # Default fallback logic: 'ops' directory under the parent of the current file's parent.
        # Note: Since this file is in 'infinicore/', we look at parent.
        # It is recommended to pass an explicit path in run.py.
        fallback = Path(__file__).parent / "ops"
        return fallback if fallback.exists() else None

    def get_available_operators(self):
        """Returns a list of names of all available operators."""
        if not self.ops_dir:
            return []
        files = self.scan()
        return sorted([f.stem for f in files])

    def get_raw_python_files(self):
        """
        Get all .py files in the directory (excluding run.py) without content validation.
        Used for debugging: helps identify files that exist but failed validation.
        """
        if not self.ops_dir or not self.ops_dir.exists():
            return []
        files = list(self.ops_dir.glob("*.py"))
        # Exclude run.py itself and __init__.py
        return [
            f.name
            for f in files
            if f.name != "run.py" and not f.name.startswith("__")
        ]
    Returns:
        tuple: (success, test_instance_or_error)
    """
    try:
        # Create a unique module name
        module_name = f"op_test_{test_file_path.stem}"
        # Load the module from file
        spec = importlib.util.spec_from_file_location(module_name, test_file_path)
        if spec is None or spec.loader is None:
            return False, f"Could not load module from {test_file_path}"
        module = importlib.util.module_from_spec(spec)
        # Add the module to sys.modules
        sys.modules[module_name] = module
        # Execute the module
        spec.loader.exec_module(module)
        # Find the test class (usually named OpTest)
        test_class = None
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if (
                isinstance(attr, type)
                and hasattr(attr, "__bases__")
                and any("BaseOperatorTest" in str(base) for base in attr.__bases__)
            ):
                test_class = attr
                break
        if test_class is None:
            return False, f"No test class found in {test_file_path}"
        # Create an instance
        test_instance = test_class()
        return True, test_instance
    except Exception as e:
        return False, f"Error importing {test_file_path}: {str(e)}"
def run_all_op_tests(
    ops_dir=None,
    specific_ops=None,
    bench=False,
    bench_mode="both",
    verbose=False,
    debug=False,
):
    """
    Run all operator test scripts in the ops directory using direct import.

    def scan(self, specific_ops=None):
        """Scans and returns a list of Path objects that meet the criteria."""
        if not self.ops_dir or not self.ops_dir.exists():
            return []

    Args:
        ops_dir (str, optional): Path to the ops directory. If None, uses auto-detection.
        specific_ops (list, optional): List of specific operator names to test.
        bench (bool): Whether benchmarking is enabled
        bench_mode (str): Benchmark mode - "host", "device", or "both"
        verbose (bool): Whether verbose mode is enabled

        # 1. Find all .py files
        files = list(self.ops_dir.glob("*.py"))

    Returns:
        dict: Results dictionary with test names as keys and (success, test_runner, stdout, stderr) as values.
    """
    if ops_dir is None:
        ops_dir = find_ops_directory()
    else:
        ops_dir = Path(ops_dir)

        target_ops_set = set(specific_ops) if specific_ops else None

    if not ops_dir or not ops_dir.exists():
        print(f"Error: Ops directory '{ops_dir}' does not exist.")
        return {}

        # 2. Filter out non-test files (via content check)
        valid_files = []
        for f in files:
            # A. Basic Name Filtering
            if f.name.startswith("_") or f.name == "run.py":
                continue

    print(f"Looking for test files in: {ops_dir}")

            # B. Specific Ops Filtering
            if target_ops_set and f.stem not in target_ops_set:
                continue

    # Find all Python test files
    test_files = list(ops_dir.glob("*.py"))

            # C. Content Check (Expensive I/O)
            # Only perform this check if the file passed the name filters above.
            if self._is_operator_test(f):
                valid_files.append(f)

    # Filter out this script itself and non-operator test files
    current_script = Path(__file__).name
    test_files = [f for f in test_files if f.name != current_script]

        return valid_files

    # Filter to include only files that look like operator tests
    operator_test_files = []
    for test_file in test_files:

    def _is_operator_test(self, file_path):
        """Checks if the file content contains operator test characteristics."""
        try:
            with open(test_file, "r", encoding="utf-8") as f:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
            # Look for characteristic patterns of operator tests
            if "infinicore" in content and (
                return "infinicore" in content and (
                    "BaseOperatorTest" in content or "GenericTestRunner" in content
                ):
                operator_test_files.append(test_file)
        except Exception as e:
            continue
    # Filter for specific operators if requested
    if specific_ops:
        filtered_files = []
        for test_file in operator_test_files:
            test_name = test_file.stem.lower()
            if any(op.lower() == test_name for op in specific_ops):
                filtered_files.append(test_file)
        operator_test_files = filtered_files
    if not operator_test_files:
        print(f"No operator test files found in {ops_dir}")
        print(f"Available Python files: {[f.name for f in test_files]}")
        return {}
    print(f"Found {len(operator_test_files)} operator test files:")
    for test_file in operator_test_files:
        print(f"  - {test_file.name}")
    results = {}
    cumulative_timing = {
        "total_torch_host_time": 0.0,
        "total_torch_device_time": 0.0,
        "total_infinicore_host_time": 0.0,
        "total_infinicore_device_time": 0.0,
        "operators_tested": 0,
    }
    for test_file in operator_test_files:
        test_name = test_file.stem
        try:
            # Import and run the test directly
            success, test_instance_or_error = import_operator_test(test_file)
            if not success:
                print(f"💥 {test_name}: ERROR - {test_instance_or_error}")
                results[test_name] = {
                    "success": False,
                    "return_code": -1,
                    "torch_host_time": 0.0,
                    "torch_device_time": 0.0,
                    "infini_host_time": 0.0,
                    "infini_device_time": 0.0,
                    "error_message": test_instance_or_error,
                    "test_runner": None,
                    "stdout": "",
                    "stderr": test_instance_or_error,
                }
                continue
            # Get the test runner class from the module
            test_module = sys.modules[f"op_test_{test_file.stem}"]
            if not hasattr(test_module, "GenericTestRunner"):
                print(f"💥 {test_name}: ERROR - No GenericTestRunner found")
                results[test_name] = {
                    "success": False,
                    "return_code": -1,
                    "torch_host_time": 0.0,
                    "torch_device_time": 0.0,
                    "infini_host_time": 0.0,
                    "infini_device_time": 0.0,
                    "error_message": "No GenericTestRunner found",
                    "test_runner": None,
                    "stdout": "",
                    "stderr": "No GenericTestRunner found",
                }
                continue
            # Create and run the test runner
            test_runner_class = test_module.GenericTestRunner
            runner_instance = test_runner_class(test_instance_or_error.__class__)
            # Temporarily redirect stdout to capture output
            from io import StringIO

            stdout_capture = StringIO()
            stderr_capture = StringIO()
            old_stdout = sys.stdout
            old_stderr = sys.stderr
            sys.stdout = stdout_capture
            sys.stderr = stderr_capture
            try:
                # Run the test
                test_success, test_runner = runner_instance.run()
                # Get captured output
                stdout_output = stdout_capture.getvalue()
                stderr_output = stderr_capture.getvalue()
                # Restore stdout/stderr
                sys.stdout = old_stdout
                sys.stderr = old_stderr
                # Print the captured output
                if stdout_output:
                    print(stdout_output.rstrip())
                if stderr_output:
                    print("\nSTDERR:")
                    print(stderr_output.rstrip())
                # Analyze test results
                test_results = test_runner.get_test_results() if test_runner else []
                # Determine overall test status
                if test_success:
                    return_code = 0
                    status_icon = "✅"
                    status_text = "PASSED"
                else:
                    # Check if there are any failed tests
                    has_failures = any(result.return_code == -1 for result in test_results)
                    has_partial = any(result.return_code == -3 for result in test_results)
                    has_skipped = any(result.return_code == -2 for result in test_results)
                    if has_failures:
                        return_code = -1
                        status_icon = "❌"
                        status_text = "FAILED"
                    elif has_partial:
                        return_code = -3
                        status_icon = "⚠️"
                        status_text = "PARTIAL"
                    elif has_skipped:
                        return_code = -2
                        status_icon = "⏭️"
                        status_text = "SKIPPED"
                    else:
                        return_code = -1
                        status_icon = "❌"
                        status_text = "FAILED"
                # Calculate timing for all four metrics
                torch_host_time = sum(result.torch_host_time for result in test_results)
                torch_device_time = sum(result.torch_device_time for result in test_results)
                infini_host_time = sum(result.infini_host_time for result in test_results)
                infini_device_time = sum(result.infini_device_time for result in test_results)
                results[test_name] = {
                    "success": test_success,
                    "return_code": return_code,
                    "torch_host_time": torch_host_time,
                    "torch_device_time": torch_device_time,
                    "infini_host_time": infini_host_time,
                    "infini_device_time": infini_device_time,
                    "error_message": "",
                    "test_runner": test_runner,
                    "stdout": stdout_output,
                    "stderr": stderr_output,
                }
                print(f"{status_icon} {test_name}: {status_text} (return code: {return_code})")
                # Extract benchmark timing if in bench mode
                if bench and test_success and return_code == 0:
                    cumulative_timing["total_torch_host_time"] += torch_host_time
                    cumulative_timing["total_torch_device_time"] += torch_device_time
                    cumulative_timing["total_infinicore_host_time"] += infini_host_time
                    cumulative_timing["total_infinicore_device_time"] += infini_device_time
                    cumulative_timing["operators_tested"] += 1
            except Exception as e:
                # Restore stdout/stderr in case of exception
                sys.stdout = old_stdout
                sys.stderr = old_stderr
                raise e
            # In verbose mode, stop execution on first failure
            if verbose and not test_success and return_code != 0:
                break
        except Exception as e:
            print(f"💥 {test_name}: ERROR - {str(e)}")
            results[test_name] = {
                "success": False,
                "return_code": -1,
                "torch_host_time": 0.0,
                "torch_device_time": 0.0,
                "infini_host_time": 0.0,
                "infini_device_time": 0.0,
                "error_message": str(e),
                "test_runner": None,
                "stdout": "",
                "stderr": str(e),
            }
            # In verbose mode, stop execution on any exception
            if verbose:
                print(f"\n{'!' * 60}")
                print(f"VERBOSE MODE: Stopping execution due to exception in {test_name}")
                print(f"{'!' * 60}")
                break
            if debug:
                traceback.print_exc()
                break
    return results, cumulative_timing
def print_summary(
    results,
    verbose=False,
    total_expected_tests=0,
    cumulative_timing=None,
    bench_mode="both",
):
    """Print a comprehensive summary of test results including benchmark data."""
    print(f"\n{'=' * 80}")
    print("CUMULATIVE TEST SUMMARY")
    print(f"{'=' * 80}")
    if not results:
        print("No tests were run.")
        return False
    # Count different types of results
    passed = 0
    failed = 0
    skipped = 0
    partial = 0
    passed_operators = []  # Store passed operator names
    failed_operators = []  # Store failed operator names
    skipped_operators = []  # Store skipped operator names
    partial_operators = []  # Store partial operator names
    for test_name, result_data in results.items():
        return_code = result_data["return_code"]
        if return_code == 0:
            passed += 1
            passed_operators.append(test_name)
        elif return_code == -2:  # Special code for skipped tests
            skipped += 1
            skipped_operators.append(test_name)
        elif return_code == -3:  # Special code for partial tests
            partial += 1
            partial_operators.append(test_name)
        else:
            failed += 1
            failed_operators.append(test_name)
    total = len(results)
    print(f"Total tests run: {total}")
    if total_expected_tests > 0 and total < total_expected_tests:
        print(f"Total tests expected: {total_expected_tests}")
        print(f"Tests not executed: {total_expected_tests - total}")
    print(f"Passed: {passed}")
    print(f"Failed: {failed}")
    if skipped > 0:
        print(f"Skipped: {skipped}")
    if partial > 0:
        print(f"Partial: {partial}")
    # Print benchmark summary if cumulative_timing data is available
    if cumulative_timing and cumulative_timing["operators_tested"] > 0:
        print(f"{'-' * 40}")
        print("BENCHMARK SUMMARY:")
        print(f"  Operators Tested: {cumulative_timing['operators_tested']}")
        # Display timing based on bench_mode
        if bench_mode in ["host", "both"]:
            print(f"  PyTorch Host Total Time: {cumulative_timing['total_torch_host_time']:12.3f} ms")
            print(f"  InfiniCore Host Total Time: {cumulative_timing['total_infinicore_host_time']:12.3f} ms")
        if bench_mode in ["device", "both"]:
            print(f"  PyTorch Device Total Time: {cumulative_timing['total_torch_device_time']:12.3f} ms")
            print(f"  InfiniCore Device Total Time: {cumulative_timing['total_infinicore_device_time']:12.3f} ms")
        print(f"{'-' * 40}")
    # Display passed operators
    if passed_operators:
        print(f"\n✅ PASSED OPERATORS ({len(passed_operators)}):")
        # Display operators in groups of 10 per line
        for i in range(0, len(passed_operators), 10):
            line_ops = passed_operators[i : i + 10]
            print("  " + ", ".join(line_ops))
    else:
        print(f"\n✅ PASSED OPERATORS: None")
    # Display failed operators (if any)
    if failed_operators:
        print(f"\n❌ FAILED OPERATORS ({len(failed_operators)}):")
        for i in range(0, len(failed_operators), 10):
            line_ops = failed_operators[i : i + 10]
            print("  " + ", ".join(line_ops))
    # Display skipped operators (if any)
    if skipped_operators:
        print(f"\n⏭️ SKIPPED OPERATORS ({len(skipped_operators)}):")
        for i in range(0, len(skipped_operators), 10):
            line_ops = skipped_operators[i : i + 10]
            print("  " + ", ".join(line_ops))
    # Display partial operators (if any)
    if partial_operators:
        print(f"\n⚠️ PARTIAL OPERATORS ({len(partial_operators)}):")
        for i in range(0, len(partial_operators), 10):
            line_ops = partial_operators[i : i + 10]
            print("  " + ", ".join(line_ops))
    if total > 0:
        # Calculate success rate based on actual executed tests
        executed_tests = passed + failed + partial
        if executed_tests > 0:
            success_rate = passed / executed_tests * 100
            print(f"\nSuccess rate: {success_rate:.1f}%")
    if verbose and total < total_expected_tests:
        print(f"\n💡 Verbose mode: Execution stopped after first failure")
        print(f"   {total_expected_tests - total} tests were not executed")
    if failed == 0:
        if skipped > 0 or partial > 0:
            print(f"\n⚠️ Tests completed with some operators not implemented")
            print(f"   - {skipped} tests skipped (both operators not implemented)")
            print(f"   - {partial} tests partial (one operator not implemented)")
        else:
            print(f"\n🎉 All tests passed!")
        return True
    else:
        print(f"\n❌ {failed} tests failed")
        return False
def list_available_tests(ops_dir=None):
    """List all available operator test files."""
    if ops_dir is None:
        ops_dir = find_ops_directory()
    else:
        ops_dir = Path(ops_dir)
    if not ops_dir or not ops_dir.exists():
        print(f"Error: Ops directory '{ops_dir}' does not exist.")
        return
    operators = get_available_operators(ops_dir)
    if operators:
        print(f"Available operator test files in {ops_dir}:")
        for operator in operators:
            print(f"  - {operator}")
        print(f"\nTotal: {len(operators)} operators")
    else:
        print(f"No operator test files found in {ops_dir}")
        # Show available Python files for debugging
        test_files = list(ops_dir.glob("*.py"))
        current_script = Path(__file__).name
        test_files = [f for f in test_files if f.name != current_script]
        if test_files:
            print(f"Available Python files: {[f.name for f in test_files]}")
        except:
            return False
def generate_help_epilog(ops_dir):
def generate_help_epilog(ops_dir=None):
    """
    Generate dynamic help epilog with available operators and hardware platforms.

    Args:
        ops_dir: Path to ops directory

    Returns:
        str: Formatted help text

    Generate dynamic help epilog containing available operators and hardware platforms.
    Maintains the original output format for backward compatibility.
    """
    # Get available operators
    operators = get_available_operators(ops_dir)
    # === Adapter: Use TestDiscoverer to get operator list ===
    # Temporarily instantiate a Discoverer just to fetch the list
    discoverer = TestDiscoverer(ops_dir)
    operators = discoverer.get_available_operators()
    # Build epilog text
    # Build epilog text (fully replicating original logic)
    epilog_parts = []
    # Examples section
...
...
@@ -628,17 +162,12 @@ def generate_help_epilog(ops_dir):
def main():
    """Main entry point with comprehensive command line argument parsing."""
    # First, find ops directory for dynamic help generation
    ops_dir = find_ops_directory()
    """Main entry point for the InfiniCore Operator Test Runner."""
    parser = argparse.ArgumentParser(
        description="Run InfiniCore operator tests across multiple hardware platforms",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=generate_help_epilog(ops_dir),
        epilog=generate_help_epilog(),
    )
    # Core options
    parser.add_argument(
        "--ops-dir", type=str, help="Path to the ops directory (default: auto-detect)"
    )
...
...
@@ -650,118 +179,97 @@ def main():
        action="store_true",
        help="List all available test files without running them",
    )
    # Call common method to add shared arguments (bench, debug, verbose, save...)
    add_common_test_args(parser)
    # Add common test arguments (including --save, --bench, etc.)
    add_common_test_args(parser)
    get_hardware_args_group(parser)
    # Parse known args first, leave the rest for the test scripts
    args, unknown_args = parser.parse_known_args()
    # Handle list command
    if args.list:
        list_available_tests(args.ops_dir)
        return
    # Auto-detect ops directory if not provided
    if args.ops_dir is None:
        ops_dir = find_ops_directory()
        if not ops_dir:
            print("Error: Could not auto-detect ops directory. Please specify with --ops-dir")
            sys.exit(1)
    else:
        ops_dir = Path(args.ops_dir)
        if not ops_dir.exists():
            print(f"Error: Ops directory '{ops_dir}' does not exist.")
            sys.exit(1)
    # Show what extra arguments will be passed
    if unknown_args:
        print(f"Passing extra arguments to test scripts: {unknown_args}")
    # Get available operators for display
    available_operators = get_available_operators(ops_dir)
    print(f"InfiniCore Operator Test Runner")
    print(f"Operating directory: {ops_dir}")
    print(f"Available operators: {len(available_operators)}")
    # 1. Discovery
    discoverer = TestDiscoverer(args.ops_dir)
    if args.list:
        print("Available operators:", discoverer.get_available_operators())
        return
    if args.verbose:
        print(f"Verbose mode: ENABLED (will stop on first error with full traceback)")
    if args.bench:
        bench_mode = args.bench if args.bench != "both" else "both"
        print(f"Benchmark mode: {bench_mode.upper()} timing")
        print(f"Benchmark mode: {args.bench.upper()} timing")
    target_ops = None
    if args.ops:
        # Validate requested operators
        valid_ops = []
        invalid_ops = []
        for op in args.ops:
            if op in available_operators:
                valid_ops.append(op)
            else:
                invalid_ops.append(op)
        # Get all available operator names
        available_ops = set(discoverer.get_available_operators())
        requested_ops = set(args.ops)
        # Classify using set operations
        valid_ops = list(requested_ops & available_ops)  # Intersection: Valid ops
        invalid_ops = list(requested_ops - available_ops)  # Difference: Invalid ops
        # Warn if there are invalid operators
        if invalid_ops:
            print(f"Warning: Unknown operators: {', '.join(invalid_ops)}")
            print(f"Available operators: {', '.join(available_operators)}")
            print(f"⚠️ Warning: The following requested operators were not found:")
            print(f"   {', '.join(invalid_ops)}")
            print(f"   (Use --list to see available operators)")
        if valid_ops:
            print(f"Testing operators: {', '.join(valid_ops)}")
            total_expected_tests = len(valid_ops)
        else:
            print("No valid operators specified. Running all available tests.")
            total_expected_tests = len(available_operators)
    else:
        print("Testing all available operators")
        total_expected_tests = len(available_operators)
    print()
    # Run all tests
    results, cumulative_timing = run_all_op_tests(
        ops_dir=ops_dir,
        specific_ops=args.ops,
        bench=bool(args.bench),
        bench_mode=args.bench if args.bench else "both",
        verbose=args.verbose,
        debug=args.debug,
    )
        if not valid_ops:
            # Case A: User input provided, but ALL were invalid.
            print(f"⚠️ No valid operators remained from your list.")
            print(f"🔄 Fallback: Proceeding to run ALL available tests...")
    # Print summary and exit with appropriate code
    all_passed = print_summary(
        else:
            # Case B: At least some valid operators found.
            print(f"🎯 Targeted operators: {', '.join(valid_ops)}")
            target_ops = valid_ops
    test_files = discoverer.scan(target_ops)
    if not test_files:
        print("No tests found.")
        sys.exit(0)
    # 2. Preparation
    executor = TestExecutor()
    cumulative_timing = TestTiming()
    test_summary = TestSummary(args.verbose, args.bench)
    results = []
    test_summary.print_header(discoverer.ops_dir, len(test_files))
    # 3. Execution Loop
    for f in test_files:
        result = executor.execute(f)
        results.append(result)
        # Real-time reporting and printing of stdout
        test_summary.print_live_result(result)
        # Accumulate timing
        if result.success:
            cumulative_timing.torch_host += result.timing.torch_host
            cumulative_timing.infini_host += result.timing.infini_host
            cumulative_timing.torch_device += result.timing.torch_device
            cumulative_timing.infini_device += result.timing.infini_device
            cumulative_timing.operators_tested += 1
        # Fail fast in verbose mode
        if args.verbose and not result.success:
            print("\nStopping due to failure in verbose mode.")
            break
    # 4. Final Report & Save
    all_passed = test_summary.print_summary(
        results,
        args.verbose,
        total_expected_tests,
        cumulative_timing,
        bench_mode=args.bench if args.bench else "both",
    )
    # Check if there were any tests with missing implementations
    has_missing_implementations = any(
        result_data["return_code"] in [-2, -3]
        for result_data in results.values()
        cumulative_timing if args.bench else None,
        ops_dir=discoverer.ops_dir,
        total_expected=len(test_files),
    )
    if all_passed and has_missing_implementations:
        print(f"\n⚠️ Note: Some operators are not fully implemented")
        print(f"   Run individual tests for details on missing implementations")
    if args.verbose and not all_passed:
        print(f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:")
        failed_ops = [
            name
            for name, result_data in results.items()
            if result_data["return_code"] == -1
        ]
        for op in failed_ops[:3]:  # Show first 3 failed operators
            print(f"   python {ops_dir / (op + '.py')} --verbose")
    sys.exit(0 if all_passed else 1)
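Taken together, the rewritten run.py delegates to three framework pieces: TestDiscoverer finds test files, TestExecutor runs each one, and TestSummary plus TestTiming aggregate and report. A condensed sketch of that flow, using only the method names that appear in the diff above (scan, execute, print_header, print_live_result, print_summary); argument details beyond what is shown there are assumptions:

from framework.executor import TestExecutor
from framework.results import TestSummary, TestTiming


def run_selected(ops_dir=None, ops=None, verbose=False, bench=None):
    # Hypothetical helper for illustration; mirrors the main() flow above.
    discoverer = TestDiscoverer(ops_dir)      # TestDiscoverer is defined in run.py above
    test_files = discoverer.scan(ops)         # optionally filter by operator name
    executor = TestExecutor()
    timing = TestTiming()
    summary = TestSummary(verbose, bench)
    summary.print_header(discoverer.ops_dir, len(test_files))

    results = []
    for test_file in test_files:
        result = executor.execute(test_file)
        results.append(result)
        summary.print_live_result(result)
        if result.success:
            timing.torch_host += result.timing.torch_host
            timing.infini_host += result.timing.infini_host
            timing.torch_device += result.timing.torch_device
            timing.infini_device += result.timing.infini_device
            timing.operators_tested += 1
        if verbose and not result.success:
            break  # fail fast, mirroring --verbose behaviour

    return summary.print_summary(
        results,
        timing if bench else None,
        ops_dir=discoverer.ops_dir,
        total_expected=len(test_files),
    )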
...
...