Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
5b7ef9c5
Commit
5b7ef9c5
authored
Oct 31, 2025
by
wooway777
Committed by
MaYuhang
Nov 03, 2025
Browse files
issue/540 - support more dtypes in test framework
parent
a5e20fcf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
124 additions
and
27 deletions
+124
-27
test/infinicore/framework/datatypes.py
test/infinicore/framework/datatypes.py
+12
-0
test/infinicore/framework/tensor.py
test/infinicore/framework/tensor.py
+36
-4
test/infinicore/framework/utils.py
test/infinicore/framework/utils.py
+76
-23
No files found.
test/infinicore/framework/datatypes.py
View file @
5b7ef9c5
...
@@ -10,10 +10,16 @@ def to_torch_dtype(infini_dtype):
...
@@ -10,10 +10,16 @@ def to_torch_dtype(infini_dtype):
return
torch
.
float32
return
torch
.
float32
elif
infini_dtype
==
infinicore
.
bfloat16
:
elif
infini_dtype
==
infinicore
.
bfloat16
:
return
torch
.
bfloat16
return
torch
.
bfloat16
elif
infini_dtype
==
infinicore
.
int8
:
return
torch
.
int8
elif
infini_dtype
==
infinicore
.
int16
:
return
torch
.
int16
elif
infini_dtype
==
infinicore
.
int32
:
elif
infini_dtype
==
infinicore
.
int32
:
return
torch
.
int32
return
torch
.
int32
elif
infini_dtype
==
infinicore
.
int64
:
elif
infini_dtype
==
infinicore
.
int64
:
return
torch
.
int64
return
torch
.
int64
elif
infini_dtype
==
infinicore
.
uint8
:
return
torch
.
uint8
else
:
else
:
raise
ValueError
(
f
"Unsupported infinicore dtype:
{
infini_dtype
}
"
)
raise
ValueError
(
f
"Unsupported infinicore dtype:
{
infini_dtype
}
"
)
...
@@ -26,9 +32,15 @@ def to_infinicore_dtype(torch_dtype):
...
@@ -26,9 +32,15 @@ def to_infinicore_dtype(torch_dtype):
return
infinicore
.
float16
return
infinicore
.
float16
elif
torch_dtype
==
torch
.
bfloat16
:
elif
torch_dtype
==
torch
.
bfloat16
:
return
infinicore
.
bfloat16
return
infinicore
.
bfloat16
elif
torch_dtype
==
torch
.
int8
:
return
infinicore
.
int8
elif
torch_dtype
==
torch
.
int16
:
return
infinicore
.
int16
elif
torch_dtype
==
torch
.
int32
:
elif
torch_dtype
==
torch
.
int32
:
return
infinicore
.
int32
return
infinicore
.
int32
elif
torch_dtype
==
torch
.
int64
:
elif
torch_dtype
==
torch
.
int64
:
return
infinicore
.
int64
return
infinicore
.
int64
elif
torch_dtype
==
torch
.
uint8
:
return
infinicore
.
uint8
else
:
else
:
raise
ValueError
(
f
"Unsupported torch dtype:
{
torch_dtype
}
"
)
raise
ValueError
(
f
"Unsupported torch dtype:
{
torch_dtype
}
"
)
test/infinicore/framework/tensor.py
View file @
5b7ef9c5
import
torch
import
torch
import
infinicore
from
pathlib
import
Path
from
pathlib
import
Path
from
.datatypes
import
to_torch_dtype
from
.datatypes
import
to_torch_dtype
from
.devices
import
torch_device_map
from
.devices
import
torch_device_map
from
.utils
import
is_integer_dtype
class
TensorInitializer
:
class
TensorInitializer
:
...
@@ -38,6 +40,10 @@ class TensorInitializer:
...
@@ -38,6 +40,10 @@ class TensorInitializer:
torch_device_str
=
torch_device_map
[
device
]
torch_device_str
=
torch_device_map
[
device
]
torch_dtype
=
to_torch_dtype
(
dtype
)
torch_dtype
=
to_torch_dtype
(
dtype
)
# Handle integer types differently for random initialization
if
mode
==
TensorInitializer
.
RANDOM
and
is_integer_dtype
(
dtype
):
mode
=
TensorInitializer
.
RANDINT
# Use randint for integer types
# Handle strided tensors - calculate required storage size
# Handle strided tensors - calculate required storage size
if
strides
is
not
None
:
if
strides
is
not
None
:
# Calculate the required storage size for strided tensor
# Calculate the required storage size for strided tensor
...
@@ -61,9 +67,22 @@ class TensorInitializer:
...
@@ -61,9 +67,22 @@ class TensorInitializer:
storage_size
,
dtype
=
torch_dtype
,
device
=
torch_device_str
storage_size
,
dtype
=
torch_dtype
,
device
=
torch_device_str
)
)
elif
mode
==
TensorInitializer
.
RANDINT
:
elif
mode
==
TensorInitializer
.
RANDINT
:
# For integer types, use appropriate range
if
is_integer_dtype
(
dtype
):
if
dtype
==
infinicore
.
uint8
:
low
,
high
=
0
,
256
elif
dtype
==
infinicore
.
int8
:
low
,
high
=
-
128
,
128
elif
dtype
==
infinicore
.
int16
:
low
,
high
=
-
32768
,
32768
else
:
# int32, int64, uint32
low
,
high
=
-
1000
,
1000
else
:
low
,
high
=
-
1000
,
1000
base_tensor
=
torch
.
randint
(
base_tensor
=
torch
.
randint
(
-
2000000000
,
low
,
2000000000
,
high
,
(
storage_size
,),
(
storage_size
,),
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
device
=
torch_device_str
,
device
=
torch_device_str
,
...
@@ -92,9 +111,22 @@ class TensorInitializer:
...
@@ -92,9 +111,22 @@ class TensorInitializer:
elif
mode
==
TensorInitializer
.
ONES
:
elif
mode
==
TensorInitializer
.
ONES
:
tensor
=
torch
.
ones
(
shape
,
dtype
=
torch_dtype
,
device
=
torch_device_str
)
tensor
=
torch
.
ones
(
shape
,
dtype
=
torch_dtype
,
device
=
torch_device_str
)
elif
mode
==
TensorInitializer
.
RANDINT
:
elif
mode
==
TensorInitializer
.
RANDINT
:
# For integer types, use appropriate range
if
is_integer_dtype
(
dtype
):
if
dtype
==
infinicore
.
uint8
:
low
,
high
=
0
,
256
elif
dtype
==
infinicore
.
int8
:
low
,
high
=
-
128
,
128
elif
dtype
==
infinicore
.
int16
:
low
,
high
=
-
32768
,
32768
else
:
# int32, int64, uint32
low
,
high
=
-
1000
,
1000
else
:
low
,
high
=
-
1000
,
1000
tensor
=
torch
.
randint
(
tensor
=
torch
.
randint
(
-
2000000000
,
low
,
2000000000
,
high
,
shape
,
shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
device
=
torch_device_str
,
device
=
torch_device_str
,
...
...
test/infinicore/framework/utils.py
View file @
5b7ef9c5
...
@@ -37,25 +37,52 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
...
@@ -37,25 +37,52 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
print
(
f
"
{
desc
}
time:
{
elapsed
*
1000
:
6
f
}
ms"
)
print
(
f
"
{
desc
}
time:
{
elapsed
*
1000
:
6
f
}
ms"
)
def
debug
(
actual
,
desired
,
atol
=
0
,
rtol
=
1e-2
,
equal_nan
=
False
,
verbose
=
True
):
def
is_integer_dtype
(
dtype
):
"""Check if dtype is integer type"""
return
dtype
in
[
infinicore
.
int8
,
infinicore
.
int16
,
infinicore
.
int32
,
infinicore
.
int64
,
infinicore
.
uint8
,
]
def
is_float_dtype
(
dtype
):
"""Check if dtype is floating point type"""
return
dtype
in
[
infinicore
.
float16
,
infinicore
.
float32
,
infinicore
.
bfloat16
]
def
debug
(
actual
,
desired
,
atol
=
0
,
rtol
=
1e-2
,
equal_nan
=
False
,
verbose
=
True
,
dtype
=
None
):
"""
"""
Debug function to compare two tensors and print differences
Debug function to compare two tensors and print differences
"""
"""
# Convert to float32 for bfloat16 comparison
if
actual
.
dtype
==
torch
.
bfloat16
or
desired
.
dtype
==
torch
.
bfloat16
:
if
actual
.
dtype
==
torch
.
bfloat16
or
desired
.
dtype
==
torch
.
bfloat16
:
actual
=
actual
.
to
(
torch
.
float32
)
actual
=
actual
.
to
(
torch
.
float32
)
desired
=
desired
.
to
(
torch
.
float32
)
desired
=
desired
.
to
(
torch
.
float32
)
print_discrepancy
(
actual
,
desired
,
atol
,
rtol
,
equal_nan
,
verbose
)
print_discrepancy
(
actual
,
desired
,
atol
,
rtol
,
equal_nan
,
verbose
,
dtype
)
import
numpy
as
np
# Use appropriate comparison based on dtype
if
dtype
and
is_integer_dtype
(
dtype
):
# For integer types, require exact equality
import
numpy
as
np
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_array_equal
(
actual
.
cpu
(),
desired
.
cpu
())
actual
.
cpu
(),
desired
.
cpu
(),
rtol
,
atol
,
equal_nan
,
verbose
=
True
else
:
)
# For float types, use allclose
import
numpy
as
np
np
.
testing
.
assert_allclose
(
actual
.
cpu
(),
desired
.
cpu
(),
rtol
,
atol
,
equal_nan
,
verbose
=
True
)
def
print_discrepancy
(
def
print_discrepancy
(
actual
,
expected
,
atol
=
0
,
rtol
=
1e-3
,
equal_nan
=
True
,
verbose
=
True
actual
,
expected
,
atol
=
0
,
rtol
=
1e-3
,
equal_nan
=
True
,
verbose
=
True
,
dtype
=
None
):
):
"""Print detailed tensor differences"""
"""Print detailed tensor differences"""
if
actual
.
shape
!=
expected
.
shape
:
if
actual
.
shape
!=
expected
.
shape
:
...
@@ -69,13 +96,21 @@ def print_discrepancy(
...
@@ -69,13 +96,21 @@ def print_discrepancy(
actual_isnan
=
torch
.
isnan
(
actual
)
actual_isnan
=
torch
.
isnan
(
actual
)
expected_isnan
=
torch
.
isnan
(
expected
)
expected_isnan
=
torch
.
isnan
(
expected
)
# Calculate difference mask
# Calculate difference mask based on dtype
nan_mismatch
=
(
if
dtype
and
is_integer_dtype
(
dtype
):
actual_isnan
^
expected_isnan
if
equal_nan
else
actual_isnan
|
expected_isnan
# For integer types, exact equality required
)
diff_mask
=
actual
!=
expected
diff_mask
=
nan_mismatch
|
(
else
:
torch
.
abs
(
actual
-
expected
)
>
(
atol
+
rtol
*
torch
.
abs
(
expected
))
# For float types, use tolerance-based comparison
)
nan_mismatch
=
(
actual_isnan
^
expected_isnan
if
equal_nan
else
actual_isnan
|
expected_isnan
)
diff_mask
=
nan_mismatch
|
(
torch
.
abs
(
actual
-
expected
)
>
(
atol
+
rtol
*
torch
.
abs
(
expected
))
)
diff_indices
=
torch
.
nonzero
(
diff_mask
,
as_tuple
=
False
)
diff_indices
=
torch
.
nonzero
(
diff_mask
,
as_tuple
=
False
)
delta
=
actual
-
expected
delta
=
actual
-
expected
...
@@ -107,8 +142,9 @@ def print_discrepancy(
...
@@ -107,8 +142,9 @@ def print_discrepancy(
print
(
f
" - Actual dtype:
{
actual
.
dtype
}
"
)
print
(
f
" - Actual dtype:
{
actual
.
dtype
}
"
)
print
(
f
" - Desired dtype:
{
expected
.
dtype
}
"
)
print
(
f
" - Desired dtype:
{
expected
.
dtype
}
"
)
print
(
f
" - Atol:
{
atol
}
"
)
if
not
(
dtype
and
is_integer_dtype
(
dtype
)):
print
(
f
" - Rtol:
{
rtol
}
"
)
print
(
f
" - Atol:
{
atol
}
"
)
print
(
f
" - Rtol:
{
rtol
}
"
)
print
(
print
(
f
" - Mismatched elements:
{
len
(
diff_indices
)
}
/
{
actual
.
numel
()
}
(
{
len
(
diff_indices
)
/
actual
.
numel
()
*
100
}
%)"
f
" - Mismatched elements:
{
len
(
diff_indices
)
}
/
{
actual
.
numel
()
}
(
{
len
(
diff_indices
)
/
actual
.
numel
()
*
100
}
%)"
)
)
...
@@ -130,6 +166,10 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
...
@@ -130,6 +166,10 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
"""
"""
Get tolerance settings based on data type
Get tolerance settings based on data type
"""
"""
# For integer types, return zero tolerance (exact match required)
if
is_integer_dtype
(
tensor_dtype
):
return
0
,
0
tolerance
=
tolerance_map
.
get
(
tolerance
=
tolerance_map
.
get
(
tensor_dtype
,
{
"atol"
:
default_atol
,
"rtol"
:
default_rtol
}
tensor_dtype
,
{
"atol"
:
default_atol
,
"rtol"
:
default_rtol
}
)
)
...
@@ -162,8 +202,6 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
...
@@ -162,8 +202,6 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
Args:
Args:
infini_result: infinicore tensor result
infini_result: infinicore tensor result
torch_reference: PyTorch tensor reference (for shape and device)
torch_reference: PyTorch tensor reference (for shape and device)
dtype: infinicore data type
device_str: torch device string
Returns:
Returns:
torch.Tensor: PyTorch tensor with infinicore data
torch.Tensor: PyTorch tensor with infinicore data
...
@@ -179,7 +217,7 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
...
@@ -179,7 +217,7 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
def
compare_results
(
def
compare_results
(
infini_result
,
torch_result
,
atol
=
1e-5
,
rtol
=
1e-5
,
debug_mode
=
False
infini_result
,
torch_result
,
atol
=
1e-5
,
rtol
=
1e-5
,
debug_mode
=
False
,
dtype
=
None
):
):
"""
"""
Generic function to compare infinicore result with PyTorch reference result
Generic function to compare infinicore result with PyTorch reference result
...
@@ -190,6 +228,7 @@ def compare_results(
...
@@ -190,6 +228,7 @@ def compare_results(
atol: absolute tolerance
atol: absolute tolerance
rtol: relative tolerance
rtol: relative tolerance
debug_mode: whether to enable debug output
debug_mode: whether to enable debug output
dtype: infinicore data type for comparison logic
Returns:
Returns:
bool: True if results match within tolerance
bool: True if results match within tolerance
...
@@ -197,12 +236,21 @@ def compare_results(
...
@@ -197,12 +236,21 @@ def compare_results(
# Convert infinicore result to PyTorch tensor for comparison
# Convert infinicore result to PyTorch tensor for comparison
torch_result_from_infini
=
convert_infinicore_to_torch
(
infini_result
,
torch_result
)
torch_result_from_infini
=
convert_infinicore_to_torch
(
infini_result
,
torch_result
)
# Choose comparison method based on dtype
if
dtype
and
is_integer_dtype
(
dtype
):
# For integer types, require exact equality
result
=
torch
.
equal
(
torch_result_from_infini
,
torch_result
)
else
:
# For float types, use tolerance-based comparison
result
=
torch
.
allclose
(
torch_result_from_infini
,
torch_result
,
atol
=
atol
,
rtol
=
rtol
)
# Debug mode: detailed comparison
# Debug mode: detailed comparison
if
debug_mode
:
if
debug_mode
:
debug
(
torch_result_from_infini
,
torch_result
,
atol
=
atol
,
rtol
=
rtol
)
debug
(
torch_result_from_infini
,
torch_result
,
atol
=
atol
,
rtol
=
rtol
,
dtype
=
dtype
)
# Check if results match within tolerance
return
result
return
torch
.
allclose
(
torch_result_from_infini
,
torch_result
,
atol
=
atol
,
rtol
=
rtol
)
def
create_test_comparator
(
config
,
dtype
,
tolerance_map
=
None
,
mode_name
=
""
):
def
create_test_comparator
(
config
,
dtype
,
tolerance_map
=
None
,
mode_name
=
""
):
...
@@ -227,7 +275,12 @@ def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
...
@@ -227,7 +275,12 @@ def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
if
config
.
debug
and
mode_name
:
if
config
.
debug
and
mode_name
:
print
(
f
"
\n\033
[94mDEBUG INFO -
{
mode_name
}
:
\033
[0m"
)
print
(
f
"
\n\033
[94mDEBUG INFO -
{
mode_name
}
:
\033
[0m"
)
return
compare_results
(
return
compare_results
(
infini_result
,
torch_result
,
atol
=
atol
,
rtol
=
rtol
,
debug_mode
=
config
.
debug
infini_result
,
torch_result
,
atol
=
atol
,
rtol
=
rtol
,
debug_mode
=
config
.
debug
,
dtype
=
dtype
,
)
)
return
compare_test_results
return
compare_test_results
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment