Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
a999ed68
Commit
a999ed68
authored
Nov 12, 2025
by
wooway777
Committed by
MaYuhang
Nov 13, 2025
Browse files
issue/573 - support more data types
parent
7542c51d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
99 additions
and
4 deletions
+99
-4
test/infinicore/framework/__init__.py
test/infinicore/framework/__init__.py
+7
-0
test/infinicore/framework/datatypes.py
test/infinicore/framework/datatypes.py
+12
-0
test/infinicore/framework/tensor.py
test/infinicore/framework/tensor.py
+42
-1
test/infinicore/framework/utils.py
test/infinicore/framework/utils.py
+38
-3
No files found.
test/infinicore/framework/__init__.py
View file @
a999ed68
...
@@ -9,6 +9,9 @@ from .utils import (
...
@@ -9,6 +9,9 @@ from .utils import (
profile_operation
,
profile_operation
,
rearrange_tensor
,
rearrange_tensor
,
convert_infinicore_to_torch
,
convert_infinicore_to_torch
,
is_integer_dtype
,
is_complex_dtype
,
is_floating_dtype
,
)
)
from
.config
import
(
from
.config
import
(
get_args
,
get_args
,
...
@@ -46,4 +49,8 @@ __all__ = [
...
@@ -46,4 +49,8 @@ __all__ = [
"to_infinicore_dtype"
,
"to_infinicore_dtype"
,
"to_torch_dtype"
,
"to_torch_dtype"
,
"torch_device_map"
,
"torch_device_map"
,
# Type checking utilities
"is_integer_dtype"
,
"is_complex_dtype"
,
"is_floating_dtype"
,
]
]
test/infinicore/framework/datatypes.py
View file @
a999ed68
...
@@ -8,6 +8,8 @@ def to_torch_dtype(infini_dtype):
...
@@ -8,6 +8,8 @@ def to_torch_dtype(infini_dtype):
return
torch
.
float16
return
torch
.
float16
elif
infini_dtype
==
infinicore
.
float32
:
elif
infini_dtype
==
infinicore
.
float32
:
return
torch
.
float32
return
torch
.
float32
elif
infini_dtype
==
infinicore
.
float64
:
return
torch
.
float64
elif
infini_dtype
==
infinicore
.
bfloat16
:
elif
infini_dtype
==
infinicore
.
bfloat16
:
return
torch
.
bfloat16
return
torch
.
bfloat16
elif
infini_dtype
==
infinicore
.
int8
:
elif
infini_dtype
==
infinicore
.
int8
:
...
@@ -22,6 +24,10 @@ def to_torch_dtype(infini_dtype):
...
@@ -22,6 +24,10 @@ def to_torch_dtype(infini_dtype):
return
torch
.
uint8
return
torch
.
uint8
elif
infini_dtype
==
infinicore
.
bool
:
elif
infini_dtype
==
infinicore
.
bool
:
return
torch
.
bool
return
torch
.
bool
elif
infini_dtype
==
infinicore
.
complex64
:
return
torch
.
complex64
elif
infini_dtype
==
infinicore
.
complex128
:
return
torch
.
complex128
else
:
else
:
raise
ValueError
(
f
"Unsupported infinicore dtype:
{
infini_dtype
}
"
)
raise
ValueError
(
f
"Unsupported infinicore dtype:
{
infini_dtype
}
"
)
...
@@ -30,6 +36,8 @@ def to_infinicore_dtype(torch_dtype):
...
@@ -30,6 +36,8 @@ def to_infinicore_dtype(torch_dtype):
"""Convert PyTorch data type to infinicore data type"""
"""Convert PyTorch data type to infinicore data type"""
if
torch_dtype
==
torch
.
float32
:
if
torch_dtype
==
torch
.
float32
:
return
infinicore
.
float32
return
infinicore
.
float32
elif
torch_dtype
==
torch
.
float64
:
return
infinicore
.
float64
elif
torch_dtype
==
torch
.
float16
:
elif
torch_dtype
==
torch
.
float16
:
return
infinicore
.
float16
return
infinicore
.
float16
elif
torch_dtype
==
torch
.
bfloat16
:
elif
torch_dtype
==
torch
.
bfloat16
:
...
@@ -46,5 +54,9 @@ def to_infinicore_dtype(torch_dtype):
...
@@ -46,5 +54,9 @@ def to_infinicore_dtype(torch_dtype):
return
infinicore
.
uint8
return
infinicore
.
uint8
elif
torch_dtype
==
torch
.
bool
:
elif
torch_dtype
==
torch
.
bool
:
return
infinicore
.
bool
return
infinicore
.
bool
elif
torch_dtype
==
torch
.
complex64
:
return
infinicore
.
complex64
elif
torch_dtype
==
torch
.
complex128
:
return
infinicore
.
complex128
else
:
else
:
raise
ValueError
(
f
"Unsupported torch dtype:
{
torch_dtype
}
"
)
raise
ValueError
(
f
"Unsupported torch dtype:
{
torch_dtype
}
"
)
test/infinicore/framework/tensor.py
View file @
a999ed68
import
torch
import
torch
import
math
from
pathlib
import
Path
from
pathlib
import
Path
from
.datatypes
import
to_torch_dtype
from
.datatypes
import
to_torch_dtype
from
.devices
import
torch_device_map
from
.devices
import
torch_device_map
from
.utils
import
is_integer_dtype
from
.utils
import
is_integer_dtype
,
is_complex_dtype
class
TensorInitializer
:
class
TensorInitializer
:
...
@@ -52,7 +53,12 @@ class TensorInitializer:
...
@@ -52,7 +53,12 @@ class TensorInitializer:
return
TensorInitializer
.
_create_integer_tensor
(
return
TensorInitializer
.
_create_integer_tensor
(
shape
,
torch_dtype
,
torch_device_str
,
mode
,
**
kwargs
shape
,
torch_dtype
,
torch_device_str
,
mode
,
**
kwargs
)
)
elif
is_complex_dtype
(
torch_dtype
):
return
TensorInitializer
.
_create_complex_tensor
(
shape
,
torch_dtype
,
torch_device_str
,
mode
,
**
kwargs
)
# Handle real floating-point types
if
mode
==
TensorInitializer
.
RANDOM
:
if
mode
==
TensorInitializer
.
RANDOM
:
return
torch
.
rand
(
shape
,
dtype
=
torch_dtype
,
device
=
torch_device_str
)
return
torch
.
rand
(
shape
,
dtype
=
torch_dtype
,
device
=
torch_device_str
)
elif
mode
==
TensorInitializer
.
ZEROS
:
elif
mode
==
TensorInitializer
.
ZEROS
:
...
@@ -88,6 +94,7 @@ class TensorInitializer:
...
@@ -88,6 +94,7 @@ class TensorInitializer:
@
staticmethod
@
staticmethod
def
_create_integer_tensor
(
shape
,
torch_dtype
,
torch_device_str
,
mode
,
**
kwargs
):
def
_create_integer_tensor
(
shape
,
torch_dtype
,
torch_device_str
,
mode
,
**
kwargs
):
"""Create integer tensor"""
if
mode
==
TensorInitializer
.
RANDOM
:
if
mode
==
TensorInitializer
.
RANDOM
:
if
torch_dtype
==
torch
.
bool
:
if
torch_dtype
==
torch
.
bool
:
return
torch
.
randint
(
return
torch
.
randint
(
...
@@ -135,6 +142,40 @@ class TensorInitializer:
...
@@ -135,6 +142,40 @@ class TensorInitializer:
0
,
100
,
shape
,
dtype
=
torch_dtype
,
device
=
torch_device_str
0
,
100
,
shape
,
dtype
=
torch_dtype
,
device
=
torch_device_str
)
)
@
staticmethod
def
_create_complex_tensor
(
shape
,
torch_dtype
,
torch_device_str
,
mode
,
**
kwargs
):
"""Create complex tensor (complex64 or complex128)"""
if
mode
==
TensorInitializer
.
RANDOM
:
# Create complex tensor with random real and imaginary parts
real_part
=
torch
.
rand
(
shape
,
device
=
torch_device_str
)
imag_part
=
torch
.
rand
(
shape
,
device
=
torch_device_str
)
complex_tensor
=
torch
.
complex
(
real_part
,
imag_part
)
return
complex_tensor
.
to
(
torch_dtype
)
elif
mode
==
TensorInitializer
.
ZEROS
:
return
torch
.
zeros
(
shape
,
dtype
=
torch_dtype
,
device
=
torch_device_str
)
elif
mode
==
TensorInitializer
.
ONES
:
return
torch
.
ones
(
shape
,
dtype
=
torch_dtype
,
device
=
torch_device_str
)
elif
mode
==
TensorInitializer
.
MANUAL
:
tensor
=
kwargs
.
get
(
"set_tensor"
)
if
tensor
is
None
:
raise
ValueError
(
"Manual mode requires set_tensor"
)
if
list
(
tensor
.
shape
)
!=
list
(
shape
):
raise
ValueError
(
f
"Shape mismatch: expected
{
shape
}
, got
{
tensor
.
shape
}
"
)
return
tensor
.
to
(
torch_dtype
).
to
(
torch_device_str
)
elif
mode
==
TensorInitializer
.
BINARY
:
tensor
=
kwargs
.
get
(
"set_tensor"
)
if
tensor
is
None
:
raise
ValueError
(
"Binary mode requires set_tensor"
)
return
tensor
.
to
(
torch_dtype
).
to
(
torch_device_str
)
else
:
# Default to random complex values
real_part
=
torch
.
rand
(
shape
,
device
=
torch_device_str
)
imag_part
=
torch
.
rand
(
shape
,
device
=
torch_device_str
)
complex_tensor
=
torch
.
complex
(
real_part
,
imag_part
)
return
complex_tensor
.
to
(
torch_dtype
)
@
staticmethod
@
staticmethod
def
_create_strided_tensor
(
def
_create_strided_tensor
(
shape
,
strides
,
torch_dtype
,
torch_device_str
,
mode
,
**
kwargs
shape
,
strides
,
torch_dtype
,
torch_device_str
,
mode
,
**
kwargs
...
...
test/infinicore/framework/utils.py
View file @
a999ed68
...
@@ -42,7 +42,11 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
...
@@ -42,7 +42,11 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
"""
Debug function to compare two tensors and print differences
Debug function to compare two tensors and print differences
"""
"""
if
actual
.
dtype
==
torch
.
bfloat16
or
desired
.
dtype
==
torch
.
bfloat16
:
# Handle complex types by converting to real representation for comparison
if
actual
.
is_complex
()
or
desired
.
is_complex
():
actual
=
torch
.
view_as_real
(
actual
)
desired
=
torch
.
view_as_real
(
desired
)
elif
actual
.
dtype
==
torch
.
bfloat16
or
desired
.
dtype
==
torch
.
bfloat16
:
actual
=
actual
.
to
(
torch
.
float32
)
actual
=
actual
.
to
(
torch
.
float32
)
desired
=
desired
.
to
(
torch
.
float32
)
desired
=
desired
.
to
(
torch
.
float32
)
...
@@ -162,8 +166,6 @@ def convert_infinicore_to_torch(infini_result):
...
@@ -162,8 +166,6 @@ def convert_infinicore_to_torch(infini_result):
Args:
Args:
infini_result: infinicore tensor result
infini_result: infinicore tensor result
dtype: infinicore data type
device_str: torch device string
Returns:
Returns:
torch.Tensor: PyTorch tensor with infinicore data
torch.Tensor: PyTorch tensor with infinicore data
...
@@ -259,6 +261,24 @@ def compare_results(
...
@@ -259,6 +261,24 @@ def compare_results(
if
debug_mode
and
not
result_equal
:
if
debug_mode
and
not
result_equal
:
print
(
"Integer tensor comparison failed - requiring exact equality"
)
print
(
"Integer tensor comparison failed - requiring exact equality"
)
return
result_equal
return
result_equal
elif
is_complex_dtype
(
torch_result_from_infini
.
dtype
)
or
is_complex_dtype
(
torch_result
.
dtype
):
# Complex number comparison - compare real and imaginary parts separately
real_close
=
torch
.
allclose
(
torch_result_from_infini
.
real
,
torch_result
.
real
,
atol
=
atol
,
rtol
=
rtol
)
imag_close
=
torch
.
allclose
(
torch_result_from_infini
.
imag
,
torch_result
.
imag
,
atol
=
atol
,
rtol
=
rtol
)
result_equal
=
real_close
and
imag_close
if
debug_mode
and
not
result_equal
:
print
(
"Complex tensor comparison failed"
)
if
not
real_close
:
print
(
" Real parts don't match"
)
if
not
imag_close
:
print
(
" Imaginary parts don't match"
)
return
result_equal
else
:
else
:
# Tolerance-based comparison for floating-point types
# Tolerance-based comparison for floating-point types
return
torch
.
allclose
(
return
torch
.
allclose
(
...
@@ -382,3 +402,18 @@ def is_integer_dtype(dtype):
...
@@ -382,3 +402,18 @@ def is_integer_dtype(dtype):
torch
.
uint8
,
torch
.
uint8
,
torch
.
bool
,
torch
.
bool
,
]
]
def
is_complex_dtype
(
dtype
):
"""Check if dtype is complex type"""
return
dtype
in
[
torch
.
complex64
,
torch
.
complex128
]
def
is_floating_dtype
(
dtype
):
"""Check if dtype is floating-point type"""
return
dtype
in
[
torch
.
float16
,
torch
.
float32
,
torch
.
float64
,
torch
.
bfloat16
,
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment