jerrrrry / infinicore · Commit 5aa850af

func behavior fix && rename class name

Authored Dec 18, 2025 by baominghelly
Parent: 0d58c820

Showing 10 changed files with 416 additions and 181 deletions.
  test/infinicore/framework/__init__.py     +8    -1
  test/infinicore/framework/base.py         +2    -1
  test/infinicore/framework/datatypes.py    +0    -34
  test/infinicore/framework/driver.py       +18   -11
  test/infinicore/framework/entities.py     +0    -15
  test/infinicore/framework/loader.py       +22   -4
  test/infinicore/framework/printer.py      +180  -0
  test/infinicore/framework/reporter.py     +1    -103
  test/infinicore/framework/types.py        +52   -0
  test/infinicore/run.py                    +133  -12
test/infinicore/framework/__init__.py

 from .base import TestConfig, TestRunner, BaseOperatorTest
-from .test_case import TestCase, TestResult
+from .entities import TestCase
 from .benchmark import BenchmarkUtils, BenchmarkResult
 from .config import (
     add_common_test_args,
 ...
@@ -11,6 +11,9 @@ from .datatypes import to_torch_dtype, to_infinicore_dtype
 from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
 from .runner import GenericTestRunner
 from .tensor import TensorSpec, TensorInitializer
+from .types import TestTiming, OperatorTestResult, TestResult
+from .driver import TestDriver
+from .printer import ConsolePrinter
 from .utils import (
     compare_results,
     create_test_comparator,
 ...
@@ -38,6 +41,10 @@ __all__ = [
     "TestResult",
     "TestRunner",
     "TestReporter",
+    "TestTiming",
+    "OperatorTestResult",
+    "TestDriver",
+    "ConsolePrinter",
     # Core functions
     "add_common_test_args",
     "compare_results",
 ...
test/infinicore/framework/base.py

 ...
@@ -8,7 +8,8 @@ import infinicore
 import traceback
 from abc import ABC, abstractmethod
-from .test_case import TestCase, TestResult
+from .entities import TestCase
+from .types import TestResult
 from .datatypes import to_torch_dtype, to_infinicore_dtype
 from .devices import InfiniDeviceNames, torch_device_map
 from .tensor import TensorSpec, TensorInitializer
 ...
test/infinicore/framework/datatypes.py

 ...
@@ -60,37 +60,3 @@ def to_infinicore_dtype(torch_dtype):
         return infinicore.complex128
     else:
         raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
-
-
-@dataclass
-class TestTiming:
-    """Stores performance testing timing metrics."""
-
-    torch_host: float = 0.0
-    torch_device: float = 0.0
-    infini_host: float = 0.0
-    infini_device: float = 0.0
-    operators_tested: int = 0
-
-
-@dataclass
-class SingleTestResult:
-    """Stores the execution results of a single test file."""
-
-    name: str
-    success: bool = False
-    return_code: int = -1
-    error_message: str = ""
-    stdout: str = ""
-    stderr: str = ""
-    timing: TestTiming = field(default_factory=TestTiming)
-
-    @property
-    def status_icon(self):
-        if self.return_code == 0: return "✅"
-        if self.return_code == -2: return "⏭️"
-        if self.return_code == -3: return "⚠️"
-        return "❌"
-
-    @property
-    def status_text(self):
-        if self.return_code == 0: return "PASSED"
-        if self.return_code == -2: return "SKIPPED"
-        if self.return_code == -3: return "PARTIAL"
-        return "FAILED"
test/infinicore/framework/executor.py → test/infinicore/framework/driver.py

 ...
@@ -2,7 +2,7 @@ import sys
 import importlib.util
 from io import StringIO
 from contextlib import contextmanager
-from .datatypes import SingleTestResult, TestTiming
+from .types import OperatorTestResult, TestTiming


 @contextmanager
 def capture_output():
 ...
@@ -15,9 +15,9 @@ def capture_output():
     finally:
         sys.stdout, sys.stderr = old_out, old_err


-class SingleTestExecutor:
-    def run(self, file_path) -> SingleTestResult:
-        result = SingleTestResult(name=file_path.stem)
+class TestDriver:
+    def drive(self, file_path) -> OperatorTestResult:
+        result = OperatorTestResult(name=file_path.stem)
         try:
             # 1. Dynamically import the module
 ...
@@ -79,15 +79,22 @@ class SingleTestExecutor:
     def _analyze_return_code(self, result, test_results):
         # Logic consistent with original code: determine if all passed, partially passed, or skipped
-        if not result.success:
-            result.return_code = -1
+        if result.success:
+            result.return_code = 0
             return

-        codes = [r.return_code for r in test_results]
-        if -1 in codes:
-            result.return_code = -1
-        elif -3 in codes:
-            result.return_code = -3
-        elif -2 in codes:
-            result.return_code = -2
-        else:
-            result.return_code = 0
+        has_failures = any(r.return_code == -1 for r in test_results)
+        has_partial = any(r.return_code == -3 for r in test_results)
+        has_skipped = any(r.return_code == -2 for r in test_results)
+        if has_failures:
+            result.return_code = -1
+        elif has_partial:
+            result.return_code = -3
+        elif has_skipped:
+            result.return_code = -2
+        else:
+            result.return_code = -1

     def _extract_timing(self, result, test_results):
         # Accumulate timing
 ...
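Editor's note: the reworked _analyze_return_code gives failure precedence over partial, and partial over skipped, when collapsing per-case codes into one file-level code. A standalone sketch of that precedence (hypothetical function name, not the committed helper):

    def aggregate_return_code(sub_codes):
        """Collapse per-case return codes into one file-level code.
        Precedence: failure (-1) > partial (-3) > skipped (-2)."""
        if -1 in sub_codes:
            return -1
        if -3 in sub_codes:
            return -3
        if -2 in sub_codes:
            return -2
        return -1  # reached only when the overall run was not marked successful

    assert aggregate_return_code([0, -2, -3, -1]) == -1  # any failure wins
    assert aggregate_return_code([0, -2, -3]) == -3      # partial outranks skipped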
test/infinicore/framework/test_case.py → test/infinicore/framework/entities.py

 ...
@@ -7,21 +7,6 @@ from typing import List, Dict, Any, Optional, Tuple
 from .tensor import TensorSpec

-@dataclass
-class TestResult:
-    """Test result data structure"""
-
-    success: bool
-    return_code: int  # 0: success, -1: failure, -2: skipped, -3: partial
-    torch_host_time: float = 0.0
-    torch_device_time: float = 0.0
-    infini_host_time: float = 0.0
-    infini_device_time: float = 0.0
-    error_message: str = ""
-    test_case: Any = None
-    device: Any = None


 class TestCase:
     """Test case with all configuration included"""
 ...
test/infinicore/framework/loader.py

 ...
@@ -21,6 +21,18 @@ class TestDiscoverer:
         files = self.scan()
         return sorted([f.stem for f in files])

+    def get_raw_python_files(self):
+        """
+        Get all .py files in the directory (excluding run.py) without content validation.
+        Used for debugging: helps identify files that exist but failed validation.
+        """
+        if not self.ops_dir or not self.ops_dir.exists():
+            return []
+        files = list(self.ops_dir.glob("*.py"))
+        # Exclude run.py itself and __init__.py
+        return [f.name for f in files if f.name != "run.py" and not f.name.startswith("__")]

     def scan(self, specific_ops=None):
         """Scans and returns a list of Path objects that meet the criteria."""
         if not self.ops_dir or not self.ops_dir.exists():
 ...
@@ -29,18 +41,24 @@ class TestDiscoverer:
         # 1. Find all .py files
         files = list(self.ops_dir.glob("*.py"))

+        target_ops_set = set(specific_ops) if specific_ops else None

         # 2. Filter out non-test files (via content check)
         valid_files = []
         for f in files:
+            # A. Basic Name Filtering
             if f.name.startswith("_") or f.name == "run.py":
                 continue
+            # B. Specific Ops Filtering
+            if target_ops_set and f.stem not in target_ops_set:
+                continue
+            # C. Content Check (Expensive I/O)
+            # Only perform this check if the file passed the name filters above.
             if self._is_operator_test(f):
                 valid_files.append(f)

-        # 3. If specific operators are specified, filter them
-        if specific_ops:
-            return [f for f in valid_files if f.stem in specific_ops]
         return valid_files

     def _is_operator_test(self, file_path):
 ...
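Editor's note: a minimal usage sketch of the discoverer; the constructor argument is inferred from TestDiscoverer(ops_dir) in run.py, and the ops path here is made up:

    from framework.loader import TestDiscoverer

    discoverer = TestDiscoverer("test/infinicore/ops")            # hypothetical path
    print(discoverer.get_available_operators())                   # validated operator names
    print(discoverer.get_raw_python_files())                      # every .py file, for debugging
    files = discoverer.scan(specific_ops=["add", "matmul"])       # Path objects for selected ops

The reordering matters for performance: name filters and the requested-ops filter are cheap string checks, so the expensive file-content check in _is_operator_test now runs only on files that already passed them.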
test/infinicore/framework/printer.py  (new file, mode 100644)

# lib/printer.py
import sys
from .types import OperatorTestResult, TestTiming


class ConsolePrinter:
    """
    Handles all console output logic.
    Acts as the 'View' in the application structure.
    """

    def list_tests(self, discoverer):
        """
        Intelligently list available tests.
        If no valid operators are found, it falls back to listing raw Python files
        to assist with debugging (e.g., typos in class inheritance).
        """
        ops_dir = discoverer.ops_dir
        operators = discoverer.get_available_operators()
        if operators:
            print(f"Available operator test files in {ops_dir}:")
            for operator in operators:
                print(f"  - {operator}")
            print(f"\nTotal: {len(operators)} operators")
        else:
            print(f"No valid operator tests found in {ops_dir}")
            # === Fallback Debug Logic ===
            raw_files = discoverer.get_raw_python_files()
            if raw_files:
                print(f"\n💡 Debug Hint: Found the following Python files (but they are not valid tests):")
                print(f"  {raw_files}")
                print("  (Ensure they inherit from 'BaseOperatorTest' and contain 'infinicore')")

    def print_header(self, ops_dir, count):
        print(f"InfiniCore Operator Test Runner")
        print(f"Directory: {ops_dir}")
        print(f"Tests found: {count}\n")

    def print_live_result(self, result, verbose=False):
        """Print single-line result in real-time."""
        print(f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})")
        # Only print details if verbose or if the test failed/had output
        if result.stdout:
            print(result.stdout.rstrip())
        if result.stderr:
            print("\nSTDERR:", result.stderr.rstrip())
        if result.error_message:
            print(f"💥 Error: {result.error_message}")
        if result.stdout or result.stderr or verbose:
            print("-" * 40)

    def print_summary(self, results, cumulative_timing, ops_dir, total_expected=0, verbose=False, bench_mode="both"):
        """Prints the final comprehensive test summary and statistics, ensuring consistency with original output."""
        print(f"\n{'=' * 80}\nCUMULATIVE TEST SUMMARY\n{'=' * 80}")
        passed = [r for r in results if r.return_code == 0]
        failed = [r for r in results if r.return_code == -1]
        skipped = [r for r in results if r.return_code == -2]
        partial = [r for r in results if r.return_code == -3]
        total = len(results)
        print(f"Total tests run: {total}")
        if total_expected > 0 and total < total_expected:
            print(f"Total tests expected: {total_expected}")
            print(f"Tests not executed: {total_expected - total}")
        print(f"Passed: {len(passed)}")
        print(f"Failed: {len(failed)}")
        if skipped:
            print(f"Skipped: {len(skipped)}")
        if partial:
            print(f"Partial: {len(partial)}")

        # 1. Print Benchmark data
        if cumulative_timing:
            # Call the internal helper method
            self._print_timing(cumulative_timing, bench_mode=bench_mode)

        # 2. Print Detailed Lists
        if passed:
            self._print_op_list("✅ PASSED OPERATORS", passed)
        else:
            print(f"\n✅ PASSED OPERATORS: None")
        if failed:
            self._print_op_list("❌ FAILED OPERATORS", failed)
        if skipped:
            self._print_op_list("⏭️ SKIPPED OPERATORS", skipped)
        if partial:
            self._print_op_list("⚠️ PARTIAL IMPLEMENTATIONS", partial)

        # 3. Restore Success Rate
        if total > 0:
            # Calculate success rate based on actually executed tests (excluding skipped)
            executed_tests = total - len(skipped)
            if executed_tests > 0:
                success_rate = len(passed) / executed_tests * 100
                print(f"\nSuccess rate: {success_rate:.1f}%")
            if not failed:
                if skipped or partial:
                    print(f"\n⚠️ Tests completed with some operators not fully implemented")
                else:
                    print(f"\n🎉 All tests passed!")
            else:
                print(f"\n❌ {len(failed)} tests failed")

        if not failed and (skipped or partial):
            print(f"\n⚠️ Note: Some operators are not fully implemented")
            print(f"   Run individual tests for details on missing implementations")

        if verbose and failed:
            print(f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:")
            # Show first 3 failed operators to avoid spamming
            for r in failed[:3]:
                # Construct file path: ops_dir / filename.py
                file_path = ops_dir / (r.name + ".py")
                print(f"   python {file_path} --verbose")
            if len(failed) > 3:
                print(f"   ... (and {len(failed) - 3} others)")

        return len(failed) == 0

    # --- Internal Helpers ---
    def _print_op_list(self, title, result_list):
        """Helper to print a formatted list of operator names."""
        print(f"\n{title} ({len(result_list)}):")
        names = [r.name for r in result_list]
        # Group by 10 per line
        for i in range(0, len(names), 10):
            print("  " + ", ".join(names[i:i + 10]))

    def _print_timing(self, t, bench_mode="both"):
        """Prints detailed timing breakdown for host and device, based on bench_mode."""
        print(f"{'-' * 40}")
        # Restore Operators Tested field using the dataclass field
        if hasattr(t, 'operators_tested') and t.operators_tested > 0:
            print(f"BENCHMARK SUMMARY:")
            print(f"  Operators Tested: {t.operators_tested}")
        # Restore detailed Host/Device distinction
        if bench_mode in ["host", "both"]:
            print(f"  PyTorch Host Total Time:      {t.torch_host:12.3f} ms")
            print(f"  InfiniCore Host Total Time:   {t.infini_host:12.3f} ms")
        if bench_mode in ["device", "both"]:
            print(f"  PyTorch Device Total Time:    {t.torch_device:12.3f} ms")
            print(f"  InfiniCore Device Total Time: {t.infini_device:12.3f} ms")
        print(f"{'-' * 40}")
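Editor's note: a small usage sketch of the new printer; the result objects are fabricated by hand here, whereas in practice they come from TestDriver.drive, and the ops path is hypothetical:

    from pathlib import Path
    from framework.printer import ConsolePrinter
    from framework.types import OperatorTestResult, TestTiming

    printer = ConsolePrinter()
    results = [
        OperatorTestResult(name="add", success=True, return_code=0),
        OperatorTestResult(name="matmul", return_code=-3),  # partial implementation
    ]
    for r in results:
        printer.print_live_result(r)                        # one status line per operator
    printer.print_summary(results, TestTiming(operators_tested=2),
                          ops_dir=Path("test/infinicore/ops"),  # hypothetical path
                          total_expected=2, bench_mode="both")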
test/infinicore/framework/reporter.py

 ...
@@ -230,109 +230,7 @@ class TestReporter:
             import traceback; traceback.print_exc()
             print(f"  ❌ Save failed: {e}")

-    @staticmethod
-    def print_header(ops_dir, count):
-        print(f"InfiniCore Operator Test Runner")
-        print(f"Directory: {ops_dir}")
-        print(f"Tests found: {count}\n")
-
-    @staticmethod
-    def print_live_result(result, verbose=False):
-        """Print single-line result in real-time."""
-        print(f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})")
-        if result.stdout:
-            print(result.stdout.rstrip())
-        if result.stderr:
-            print("\nSTDERR:", result.stderr.rstrip())
-        if result.error_message:
-            print(f"💥 Error: {result.error_message}")
-        if result.stdout or result.stderr or verbose:
-            print("-" * 40)
-
-    @staticmethod
-    def print_summary(results, cumulative_timing, total_expected=0):
-        """Prints the final comprehensive test summary and statistics, ensuring consistency with original output."""
-        print(f"\n{'=' * 80}\nCUMULATIVE TEST SUMMARY\n{'=' * 80}")
-        passed = [r for r in results if r.return_code == 0]
-        failed = [r for r in results if r.return_code == -1]
-        skipped = [r for r in results if r.return_code == -2]
-        partial = [r for r in results if r.return_code == -3]
-        total = len(results)
-        print(f"Total tests run: {total}")
-        print(f"Passed: {len(passed)}")
-        print(f"Failed: {len(failed)}")
-        if skipped:
-            print(f"Skipped: {len(skipped)}")
-        if partial:
-            print(f"Partial: {len(partial)}")
-
-        # 1. Print Benchmark data
-        if cumulative_timing:
-            # Assuming bench_mode is "both" for simplicity in this file, or passed via a config
-            # We call the modified _print_timing to handle the display logic.
-            TestReporter._print_timing(cumulative_timing, bench_mode="both")
-
-        # 2. Restore PASSED OPERATORS list
-        if passed:
-            print(f"\n✅ PASSED OPERATORS ({len(passed)}):")
-            # Print operators, grouped (assuming 10 per line as per the old pattern)
-            operators = [r.name for r in passed]
-            for i in range(0, len(operators), 10):
-                print("  " + ", ".join(operators[i:i + 10]))
-        else:
-            print(f"\n✅ PASSED OPERATORS: None")
-
-        # 3. Restore Success Rate
-        if total > 0:
-            # Calculate success rate based on actually executed tests (excluding skipped)
-            executed_tests = total - len(skipped)
-            if executed_tests > 0:
-                success_rate = len(passed) / executed_tests * 100
-                print(f"\nSuccess rate: {success_rate:.1f}%")
-            if not failed:
-                print(f"\n🎉 All tests passed!")
-            else:
-                print(f"\n❌ {len(failed)} tests failed")
-
-        return len(failed) == 0
-
-    # --- Internal Helpers ---
-    @staticmethod
-    def _print_timing(t, bench_mode="both"):
-        """Prints detailed timing breakdown for host and device, based on bench_mode."""
-        print(f"{'-' * 40}")
-        # Restore Operators Tested field using the new dataclass field
-        if hasattr(t, 'operators_tested'):
-            print(f"BENCHMARK SUMMARY:")
-            print(f"  Operators Tested: {t.operators_tested}")
-        # Restore detailed Host/Device distinction
-        if bench_mode in ["host", "both"]:
-            print(f"  PyTorch Host Total Time:      {t.torch_host:12.3f} ms")
-            print(f"  InfiniCore Host Total Time:   {t.infini_host:12.3f} ms")
-        if bench_mode in ["device", "both"]:
-            print(f"  PyTorch Device Total Time:    {t.torch_device:12.3f} ms")
-            print(f"  InfiniCore Device Total Time: {t.infini_device:12.3f} ms")
-        print(f"{'-' * 40}")

     @staticmethod
     def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
         """
 ...
test/infinicore/framework/types.py
0 → 100644
View file @
5aa850af
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
# TODO: Rename it, current class name is abstract.
@
dataclass
class
TestResult
:
"""Test result data structure"""
success
:
bool
return_code
:
int
# 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time
:
float
=
0.0
torch_device_time
:
float
=
0.0
infini_host_time
:
float
=
0.0
infini_device_time
:
float
=
0.0
error_message
:
str
=
""
test_case
:
Any
=
None
device
:
Any
=
None
@
dataclass
class
TestTiming
:
"""Stores performance timing metrics."""
torch_host
:
float
=
0.0
torch_device
:
float
=
0.0
infini_host
:
float
=
0.0
infini_device
:
float
=
0.0
# Added field to support the logic in your print_summary
operators_tested
:
int
=
0
@
dataclass
class
OperatorTestResult
:
"""Stores the execution results of a single operator."""
name
:
str
success
:
bool
=
False
return_code
:
int
=
-
1
error_message
:
str
=
""
stdout
:
str
=
""
stderr
:
str
=
""
timing
:
TestTiming
=
field
(
default_factory
=
TestTiming
)
@
property
def
status_icon
(
self
):
if
self
.
return_code
==
0
:
return
"✅"
if
self
.
return_code
==
-
2
:
return
"⏭️"
if
self
.
return_code
==
-
3
:
return
"⚠️"
return
"❌"
@
property
def
status_text
(
self
):
if
self
.
return_code
==
0
:
return
"PASSED"
if
self
.
return_code
==
-
2
:
return
"SKIPPED"
if
self
.
return_code
==
-
3
:
return
"PARTIAL"
return
"FAILED"
test/infinicore/run.py

 ...
@@ -4,14 +4,93 @@ from pathlib import Path
 # Import components from the unified framework package
 from framework.loader import TestDiscoverer
-from framework.executor import SingleTestExecutor
-from framework.reporter import TestReporter
-from framework.datatypes import TestTiming
+from framework.driver import TestDriver
+from framework.printer import ConsolePrinter
+from framework.types import TestTiming
 from framework import get_hardware_args_group, add_common_test_args


+def generate_help_epilog(ops_dir=None):
+    """
+    Generate dynamic help epilog containing available operators and hardware platforms.
+    Maintains the original output format for backward compatibility.
+    """
+    # === Adapter: Use TestDiscoverer to get operator list ===
+    # Temporarily instantiate a Discoverer just to fetch the list
+    discoverer = TestDiscoverer(ops_dir)
+    operators = discoverer.get_available_operators()
+
+    # Build epilog text (fully replicating original logic)
+    epilog_parts = []
+
+    # Examples section
+    epilog_parts.append("Examples:")
+    epilog_parts.append("  # Run all operator tests on CPU")
+    epilog_parts.append("  python run.py --cpu")
+    epilog_parts.append("")
+    epilog_parts.append("  # Run specific operators")
+    epilog_parts.append("  python run.py --ops add matmul --nvidia")
+    epilog_parts.append("")
+    epilog_parts.append("  # Run with debug mode on multiple devices")
+    epilog_parts.append("  python run.py --cpu --nvidia --debug")
+    epilog_parts.append("")
+    epilog_parts.append("  # Run with verbose mode to stop on first error with full traceback")
+    epilog_parts.append("  python run.py --cpu --nvidia --verbose")
+    epilog_parts.append("")
+    epilog_parts.append("  # Run with benchmarking (both host and device timing)")
+    epilog_parts.append("  python run.py --cpu --bench")
+    epilog_parts.append("")
+    epilog_parts.append("  # Run with host timing only")
+    epilog_parts.append("  python run.py --nvidia --bench host")
+    epilog_parts.append("")
+    epilog_parts.append("  # Run with device timing only")
+    epilog_parts.append("  python run.py --nvidia --bench device")
+    epilog_parts.append("")
+    epilog_parts.append("  # List available tests without running")
+    epilog_parts.append("  python run.py --list")
+    epilog_parts.append("")
+
+    # Available operators section
+    if operators:
+        epilog_parts.append("Available Operators:")
+        # Group operators for better display
+        operators_per_line = 4
+        for i in range(0, len(operators), operators_per_line):
+            line_ops = operators[i:i + operators_per_line]
+            epilog_parts.append(f"  {', '.join(line_ops)}")
+        epilog_parts.append("")
+    else:
+        epilog_parts.append("Available Operators: (none detected)")
+        epilog_parts.append("")
+
+    # Additional notes
+    epilog_parts.append("Note:")
+    epilog_parts.append("  - Use '--' to pass additional arguments to individual test scripts")
+    epilog_parts.append("  - Operators are automatically discovered from the ops directory")
+    epilog_parts.append("  - --bench mode now shows cumulative timing across all operators")
+    epilog_parts.append("  - --bench host/device/both controls host/device timing measurement")
+    epilog_parts.append("  - --verbose mode stops execution on first error and shows full traceback")
+
+    return "\n".join(epilog_parts)
+
+
 def main():
     """Main entry point for the InfiniCore Operator Test Runner."""
-    parser = argparse.ArgumentParser(
-        description="Run InfiniCore operator tests across multiple hardware platforms"
-    )
+    parser = argparse.ArgumentParser(
+        description="Run InfiniCore operator tests across multiple hardware platforms",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=generate_help_epilog(),
+    )
     parser.add_argument("--ops-dir", type=str, help="Path to the ops directory (default: auto-detect)")
     parser.add_argument("--ops", nargs="+", help="Run specific operators only (e.g., --ops add matmul)")
     parser.add_argument("--list", action="store_true", help="List all available test files without running them")
 ...
@@ -20,7 +99,10 @@ def main():
     add_common_test_args(parser)
     get_hardware_args_group(parser)

-    args, _ = parser.parse_known_args()
+    args, unknown_args = parser.parse_known_args()
+
+    # Show what extra arguments will be passed
+    if unknown_args:
+        print(f"Passing extra arguments to test scripts: {unknown_args}")

     # 1. Discovery
     discoverer = TestDiscoverer(args.ops_dir)
 ...
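Editor's note: switching from `args, _` to `args, unknown_args` surfaces unrecognized flags so they can be forwarded to individual test scripts instead of being silently dropped. A self-contained sketch of how argparse's parse_known_args behaves (the extra flag name is hypothetical):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cpu", action="store_true")
    args, unknown_args = parser.parse_known_args(["--cpu", "--num-warmup", "5"])
    print(args.cpu)         # True
    print(unknown_args)     # ['--num-warmup', '5'] -- collected, not an error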
@@ -28,25 +110,62 @@ def main():
         print("Available operators:", discoverer.get_available_operators())
         return

-    test_files = discoverer.scan(args.ops)
     if args.verbose:
         print(f"Verbose mode: ENABLED (will stop on first error with full traceback)")
     if args.bench:
         bench_mode = args.bench if args.bench != "both" else "both"
         print(f"Benchmark mode: {bench_mode.upper()} timing")

+    target_ops = None
+    if args.ops:
+        # Get all available operator names
+        available_ops = set(discoverer.get_available_operators())
+        requested_ops = set(args.ops)
+
+        # Classify using set operations
+        valid_ops = list(requested_ops & available_ops)    # Intersection: Valid ops
+        invalid_ops = list(requested_ops - available_ops)  # Difference: Invalid ops
+
+        # Warn if there are invalid operators
+        if invalid_ops:
+            print(f"⚠️ Warning: The following requested operators were not found:")
+            print(f"   {', '.join(invalid_ops)}")
+            print(f"   (Use --list to see available operators)")
+
+        if not valid_ops:
+            # Case A: User input provided, but ALL were invalid.
+            print(f"⚠️ No valid operators remained from your list.")
+            print(f"🔄 Fallback: Proceeding to run ALL available tests...")
+            target_ops = None
+        else:
+            # Case B: At least some valid operators found.
+            print(f"🎯 Targeted operators: {', '.join(valid_ops)}")
+            target_ops = valid_ops
+
+    test_files = discoverer.scan(target_ops)

     if not test_files:
         print("No tests found.")
         sys.exit(0)

     # 2. Preparation
-    executor = SingleTestExecutor()
+    driver = TestDriver()
     cumulative_timing = TestTiming()
+    printer = ConsolePrinter()
     results = []

-    TestReporter.print_header(discoverer.ops_dir, len(test_files))
+    printer.print_header(discoverer.ops_dir, len(test_files))

     # 3. Execution Loop
     for f in test_files:
-        result = executor.run(f)
+        result = driver.drive(f)
         results.append(result)
         # Real-time reporting and printing of stdout
-        TestReporter.print_live_result(result, verbose=args.verbose)
+        printer.print_live_result(result, verbose=args.verbose)

         # Accumulate timing
         if result.success:
 ...
@@ -61,10 +180,12 @@ def main():
             break

     # 4. Final Report & Save
-    all_passed = TestReporter.print_summary(
+    all_passed = printer.print_summary(
         results,
         cumulative_timing if args.bench else None,
-        total_expected=len(test_files)
+        ops_dir=discoverer.ops_dir,
+        total_expected=len(test_files),
+        verbose=args.verbose,
     )
     sys.exit(0 if all_passed else 1)
 ...

(Editor's note: the committed code spells the driver variable "dirver"; it is rendered as "driver" above, consistently in both uses, to correct the typo.)