Commit a999ed68 authored by wooway777's avatar wooway777 Committed by MaYuhang
Browse files

issue/573 - support more data types

parent 7542c51d
......@@ -9,6 +9,9 @@ from .utils import (
profile_operation,
rearrange_tensor,
convert_infinicore_to_torch,
is_integer_dtype,
is_complex_dtype,
is_floating_dtype,
)
from .config import (
get_args,
......@@ -46,4 +49,8 @@ __all__ = [
"to_infinicore_dtype",
"to_torch_dtype",
"torch_device_map",
# Type checking utilities
"is_integer_dtype",
"is_complex_dtype",
"is_floating_dtype",
]
......@@ -8,6 +8,8 @@ def to_torch_dtype(infini_dtype):
return torch.float16
elif infini_dtype == infinicore.float32:
return torch.float32
elif infini_dtype == infinicore.float64:
return torch.float64
elif infini_dtype == infinicore.bfloat16:
return torch.bfloat16
elif infini_dtype == infinicore.int8:
......@@ -22,6 +24,10 @@ def to_torch_dtype(infini_dtype):
return torch.uint8
elif infini_dtype == infinicore.bool:
return torch.bool
elif infini_dtype == infinicore.complex64:
return torch.complex64
elif infini_dtype == infinicore.complex128:
return torch.complex128
else:
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")
......@@ -30,6 +36,8 @@ def to_infinicore_dtype(torch_dtype):
"""Convert PyTorch data type to infinicore data type"""
if torch_dtype == torch.float32:
return infinicore.float32
elif torch_dtype == torch.float64:
return infinicore.float64
elif torch_dtype == torch.float16:
return infinicore.float16
elif torch_dtype == torch.bfloat16:
......@@ -46,5 +54,9 @@ def to_infinicore_dtype(torch_dtype):
return infinicore.uint8
elif torch_dtype == torch.bool:
return infinicore.bool
elif torch_dtype == torch.complex64:
return infinicore.complex64
elif torch_dtype == torch.complex128:
return infinicore.complex128
else:
raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
import torch
import math
from pathlib import Path
from .datatypes import to_torch_dtype
from .devices import torch_device_map
from .utils import is_integer_dtype
from .utils import is_integer_dtype, is_complex_dtype
class TensorInitializer:
......@@ -52,7 +53,12 @@ class TensorInitializer:
return TensorInitializer._create_integer_tensor(
shape, torch_dtype, torch_device_str, mode, **kwargs
)
elif is_complex_dtype(torch_dtype):
return TensorInitializer._create_complex_tensor(
shape, torch_dtype, torch_device_str, mode, **kwargs
)
# Handle real floating-point types
if mode == TensorInitializer.RANDOM:
return torch.rand(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ZEROS:
......@@ -88,6 +94,7 @@ class TensorInitializer:
@staticmethod
def _create_integer_tensor(shape, torch_dtype, torch_device_str, mode, **kwargs):
"""Create integer tensor"""
if mode == TensorInitializer.RANDOM:
if torch_dtype == torch.bool:
return torch.randint(
......@@ -135,6 +142,40 @@ class TensorInitializer:
0, 100, shape, dtype=torch_dtype, device=torch_device_str
)
@staticmethod
def _create_complex_tensor(shape, torch_dtype, torch_device_str, mode, **kwargs):
"""Create complex tensor (complex64 or complex128)"""
if mode == TensorInitializer.RANDOM:
# Create complex tensor with random real and imaginary parts
real_part = torch.rand(shape, device=torch_device_str)
imag_part = torch.rand(shape, device=torch_device_str)
complex_tensor = torch.complex(real_part, imag_part)
return complex_tensor.to(torch_dtype)
elif mode == TensorInitializer.ZEROS:
return torch.zeros(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ONES:
return torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.MANUAL:
tensor = kwargs.get("set_tensor")
if tensor is None:
raise ValueError("Manual mode requires set_tensor")
if list(tensor.shape) != list(shape):
raise ValueError(
f"Shape mismatch: expected {shape}, got {tensor.shape}"
)
return tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.BINARY:
tensor = kwargs.get("set_tensor")
if tensor is None:
raise ValueError("Binary mode requires set_tensor")
return tensor.to(torch_dtype).to(torch_device_str)
else:
# Default to random complex values
real_part = torch.rand(shape, device=torch_device_str)
imag_part = torch.rand(shape, device=torch_device_str)
complex_tensor = torch.complex(real_part, imag_part)
return complex_tensor.to(torch_dtype)
@staticmethod
def _create_strided_tensor(
shape, strides, torch_dtype, torch_device_str, mode, **kwargs
......
......@@ -42,7 +42,11 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
Debug function to compare two tensors and print differences
"""
if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
# Handle complex types by converting to real representation for comparison
if actual.is_complex() or desired.is_complex():
actual = torch.view_as_real(actual)
desired = torch.view_as_real(desired)
elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32)
desired = desired.to(torch.float32)
......@@ -162,8 +166,6 @@ def convert_infinicore_to_torch(infini_result):
Args:
infini_result: infinicore tensor result
dtype: infinicore data type
device_str: torch device string
Returns:
torch.Tensor: PyTorch tensor with infinicore data
......@@ -259,6 +261,24 @@ def compare_results(
if debug_mode and not result_equal:
print("Integer tensor comparison failed - requiring exact equality")
return result_equal
elif is_complex_dtype(torch_result_from_infini.dtype) or is_complex_dtype(
torch_result.dtype
):
# Complex number comparison - compare real and imaginary parts separately
real_close = torch.allclose(
torch_result_from_infini.real, torch_result.real, atol=atol, rtol=rtol
)
imag_close = torch.allclose(
torch_result_from_infini.imag, torch_result.imag, atol=atol, rtol=rtol
)
result_equal = real_close and imag_close
if debug_mode and not result_equal:
print("Complex tensor comparison failed")
if not real_close:
print(" Real parts don't match")
if not imag_close:
print(" Imaginary parts don't match")
return result_equal
else:
# Tolerance-based comparison for floating-point types
return torch.allclose(
......@@ -382,3 +402,18 @@ def is_integer_dtype(dtype):
torch.uint8,
torch.bool,
]
def is_complex_dtype(dtype):
"""Check if dtype is complex type"""
return dtype in [torch.complex64, torch.complex128]
def is_floating_dtype(dtype):
"""Check if dtype is floating-point type"""
return dtype in [
torch.float16,
torch.float32,
torch.float64,
torch.bfloat16,
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment