Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
2e5b2342
Unverified
Commit
2e5b2342
authored
Nov 06, 2025
by
PanZezhong1725
Committed by
GitHub
Nov 06, 2025
Browse files
issue/547 - improved test output (#550)
parents
bf3395f5
991f534c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
29 deletions
+54
-29
test/infinicore/framework/base.py
test/infinicore/framework/base.py
+51
-26
test/infinicore/framework/utils.py
test/infinicore/framework/utils.py
+3
-3
No files found.
test/infinicore/framework/base.py
View file @
2e5b2342
...
...
@@ -63,14 +63,22 @@ class TestCase:
if
inp
.
init_mode
!=
TensorInitializer
.
RANDOM
else
""
)
if
hasattr
(
inp
,
"is_contiguous"
)
and
not
inp
.
is_contiguous
:
input_strs
.
append
(
f
"strided_tensor
{
inp
.
shape
}{
dtype_str
}{
init_str
}
"
)
# Show shape and strides for non-contiguous tensors
if
(
hasattr
(
inp
,
"is_contiguous"
)
and
not
inp
.
is_contiguous
and
inp
.
strides
):
strides_str
=
f
", strides=
{
inp
.
strides
}
"
input_strs
.
append
(
f
"tensor
{
inp
.
shape
}{
strides_str
}{
dtype_str
}{
init_str
}
"
)
else
:
input_strs
.
append
(
f
"tensor
{
inp
.
shape
}{
dtype_str
}{
init_str
}
"
)
else
:
input_strs
.
append
(
str
(
inp
))
base_str
=
f
"TestCase(mode=
{
mode_str
}
, inputs=[
{
'
,
'
.
join
(
input_strs
)
}
]"
base_str
=
f
"TestCase(mode=
{
mode_str
}
, inputs=[
{
'
;
'
.
join
(
input_strs
)
}
]"
if
self
.
output
:
dtype_str
=
f
", dtype=
{
self
.
output
.
dtype
}
"
if
self
.
output
.
dtype
else
""
init_str
=
(
...
...
@@ -78,7 +86,16 @@ class TestCase:
if
self
.
output
.
init_mode
!=
TensorInitializer
.
RANDOM
else
""
)
base_str
+=
f
", output=tensor
{
self
.
output
.
shape
}{
dtype_str
}{
init_str
}
"
# Show shape and strides for non-contiguous output tensors
if
(
hasattr
(
self
.
output
,
"is_contiguous"
)
and
not
self
.
output
.
is_contiguous
and
self
.
output
.
strides
):
strides_str
=
f
", strides=
{
self
.
output
.
strides
}
"
base_str
+=
f
", output=tensor
{
self
.
output
.
shape
}{
strides_str
}{
dtype_str
}{
init_str
}
"
else
:
base_str
+=
f
", output=tensor
{
self
.
output
.
shape
}{
dtype_str
}{
init_str
}
"
if
self
.
kwargs
:
base_str
+=
f
", kwargs=
{
self
.
kwargs
}
"
if
self
.
description
:
...
...
@@ -131,24 +148,30 @@ class TestRunner:
if
self
.
config
.
dtype_combinations
:
for
dtype_combo
in
self
.
config
.
dtype_combinations
:
try
:
test_func
(
device
,
test
_
case
,
dtype_combo
,
self
.
config
)
# Print
test
case
info first
combo_str
=
self
.
_format_dtype_combo
(
dtype_combo
)
print
(
f
"✓
{
test_case
}
with
{
combo_str
}
passed"
)
print
(
f
"
{
test_case
}
with
{
combo_str
}
"
)
test_func
(
device
,
test_case
,
dtype_combo
,
self
.
config
)
print
(
f
"
\033
[92m✓
\033
[0m Passed"
)
except
Exception
as
e
:
combo_str
=
self
.
_format_dtype_combo
(
dtype_combo
)
error_msg
=
f
"
{
test_case
}
with
{
combo_str
}
on
{
InfiniDeviceNames
[
device
]
}
:
{
e
}
"
print
(
f
"
✗
{
error_msg
}
"
)
error_msg
=
f
"
Error
:
{
e
}
"
print
(
f
"
\033
[91m✗
\033
[0m
{
error_msg
}
"
)
self
.
failed_tests
.
append
(
error_msg
)
if
self
.
config
.
debug
:
raise
else
:
for
dtype
in
tensor_dtypes
:
try
:
# Print test case info first
print
(
f
"
{
test_case
}
with
{
dtype
}
"
)
test_func
(
device
,
test_case
,
dtype
,
self
.
config
)
print
(
f
"
✓
{
test_case
}
with
{
dtype
}
p
assed"
)
print
(
f
"
\033
[92m✓
\033
[0m P
assed"
)
except
Exception
as
e
:
error_msg
=
f
"
{
test_case
}
with
{
dtype
}
on
{
InfiniDeviceNames
[
device
]
}
:
{
e
}
"
print
(
f
"
✗
{
error_msg
}
"
)
error_msg
=
f
"
Error
:
{
e
}
"
print
(
f
"
\033
[91m✗
\033
[0m
{
error_msg
}
"
)
self
.
failed_tests
.
append
(
error_msg
)
if
self
.
config
.
debug
:
raise
...
...
@@ -214,7 +237,7 @@ class BaseOperatorTest(ABC):
raise
NotImplementedError
(
"torch_operator not implemented"
)
def
infinicore_operator
(
self
,
*
inputs
,
out
=
None
,
**
kwargs
):
"""Unified Infini
c
ore operator function - can be overridden or return None"""
"""Unified Infini
C
ore operator function - can be overridden or return None"""
raise
NotImplementedError
(
"infinicore_operator not implemented"
)
def
create_strided_tensor
(
...
...
@@ -321,9 +344,7 @@ class BaseOperatorTest(ABC):
# If neither operator is implemented, skip the test
if
not
torch_implemented
and
not
infini_implemented
:
print
(
f
"⚠
{
self
.
operator_name
}
{
mode_name
}
: Both operators not implemented - test skipped"
)
print
(
f
"⚠ Both operators not implemented - test skipped"
)
return
# If only one operator is implemented, run it without comparison
...
...
@@ -332,7 +353,7 @@ class BaseOperatorTest(ABC):
"torch_operator"
if
not
torch_implemented
else
"infinicore_operator"
)
print
(
f
"⚠
{
self
.
operator_name
}
{
mode_name
}
:
{
missing_op
}
not implemented - running single operator without comparison"
f
"⚠
{
missing_op
}
not implemented - running single operator without comparison"
)
# Run the available operator for benchmarking if requested
...
...
@@ -342,8 +363,9 @@ class BaseOperatorTest(ABC):
def
torch_op
():
return
self
.
torch_operator
(
*
inputs
,
**
kwargs
)
print
(
f
"
{
mode_name
}
:"
)
profile_operation
(
f
"PyTorch
{
self
.
operator_name
}
{
mode_name
}
"
,
"PyTorch
"
,
torch_op
,
device_str
,
config
.
num_prerun
,
...
...
@@ -354,8 +376,9 @@ class BaseOperatorTest(ABC):
def
infini_op
():
return
self
.
infinicore_operator
(
*
infini_inputs
,
**
kwargs
)
print
(
f
"
{
mode_name
}
:"
)
profile_operation
(
f
"Infini
core
{
self
.
operator_name
}
{
mode_name
}
"
,
"Infini
Core
"
,
infini_op
,
device_str
,
config
.
num_prerun
,
...
...
@@ -388,21 +411,22 @@ class BaseOperatorTest(ABC):
)
compare_fn
=
create_test_comparator
(
config
,
comparison_dtype
,
mode_name
=
f
"
{
self
.
operator_name
}
{
mode_name
}
"
config
,
comparison_dtype
,
mode_name
=
f
"
{
mode_name
}
"
)
is_valid
=
compare_fn
(
infini_result
,
torch_result
)
assert
is_valid
,
f
"
{
self
.
operator_name
}
{
mode_name
}
t
es
t
failed"
assert
is_valid
,
f
"
{
mode_name
}
r
es
ult comparison
failed"
if
config
.
bench
:
print
(
f
"
{
mode_name
}
:"
)
profile_operation
(
f
"PyTorch
{
self
.
operator_name
}
{
mode_name
}
"
,
"PyTorch
"
,
torch_op
,
device_str
,
config
.
num_prerun
,
config
.
num_iterations
,
)
profile_operation
(
f
"Infini
core
{
self
.
operator_name
}
{
mode_name
}
"
,
"Infini
Core
"
,
infini_op
,
device_str
,
config
.
num_prerun
,
...
...
@@ -464,21 +488,22 @@ class BaseOperatorTest(ABC):
test_case
,
dtype_config
,
torch_output
)
compare_fn
=
create_test_comparator
(
config
,
comparison_dtype
,
mode_name
=
f
"
{
self
.
operator_name
}
{
mode_name
}
"
config
,
comparison_dtype
,
mode_name
=
f
"
{
mode_name
}
"
)
is_valid
=
compare_fn
(
infini_output
,
torch_output
)
assert
is_valid
,
f
"
{
self
.
operator_name
}
{
mode_name
}
t
es
t
failed"
assert
is_valid
,
f
"
{
mode_name
}
r
es
ult comparison
failed"
if
config
.
bench
:
print
(
f
"
{
mode_name
}
:"
)
profile_operation
(
f
"PyTorch
{
self
.
operator_name
}
{
mode_name
}
"
,
"PyTorch
"
,
torch_op_inplace
,
device_str
,
config
.
num_prerun
,
config
.
num_iterations
,
)
profile_operation
(
f
"Infini
core
{
self
.
operator_name
}
{
mode_name
}
"
,
"Infini
Core
"
,
infini_op_inplace
,
device_str
,
config
.
num_prerun
,
...
...
test/infinicore/framework/utils.py
View file @
2e5b2342
...
...
@@ -34,7 +34,7 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
# Timed execution
elapsed
=
timed_op
(
lambda
:
func
(),
num_iterations
,
torch_device
)
print
(
f
"
{
desc
}
time:
{
elapsed
*
1000
:
6
f
}
ms"
)
print
(
f
"
{
desc
}
time:
{
elapsed
*
1000
:
6
f
}
ms"
)
def
is_integer_dtype
(
dtype
):
...
...
@@ -157,7 +157,7 @@ def print_discrepancy(
print
(
f
" - Min(delta) :
{
torch
.
min
(
delta
):
<
{
col_width
[
1
]
}}
| Max(delta) :
{
torch
.
max
(
delta
):
<
{
col_width
[
2
]
}}
"
)
print
(
"-"
*
total_width
+
"
\n
"
)
print
(
"-"
*
total_width
)
return
diff_indices
...
...
@@ -273,7 +273,7 @@ def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
def
compare_test_results
(
infini_result
,
torch_result
):
if
config
.
debug
and
mode_name
:
print
(
f
"
\
n\
033
[94mDEBUG INFO -
{
mode_name
}
:
\033
[0m"
)
print
(
f
"
\033
[94mDEBUG INFO -
{
mode_name
}
:
\033
[0m"
)
return
compare_results
(
infini_result
,
torch_result
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment