Commit a999ed68 authored by wooway777's avatar wooway777 Committed by MaYuhang
Browse files

issue/573 - support more data types

parent 7542c51d
...@@ -9,6 +9,9 @@ from .utils import ( ...@@ -9,6 +9,9 @@ from .utils import (
profile_operation, profile_operation,
rearrange_tensor, rearrange_tensor,
convert_infinicore_to_torch, convert_infinicore_to_torch,
is_integer_dtype,
is_complex_dtype,
is_floating_dtype,
) )
from .config import ( from .config import (
get_args, get_args,
...@@ -46,4 +49,8 @@ __all__ = [ ...@@ -46,4 +49,8 @@ __all__ = [
"to_infinicore_dtype", "to_infinicore_dtype",
"to_torch_dtype", "to_torch_dtype",
"torch_device_map", "torch_device_map",
# Type checking utilities
"is_integer_dtype",
"is_complex_dtype",
"is_floating_dtype",
] ]
...@@ -8,6 +8,8 @@ def to_torch_dtype(infini_dtype): ...@@ -8,6 +8,8 @@ def to_torch_dtype(infini_dtype):
return torch.float16 return torch.float16
elif infini_dtype == infinicore.float32: elif infini_dtype == infinicore.float32:
return torch.float32 return torch.float32
elif infini_dtype == infinicore.float64:
return torch.float64
elif infini_dtype == infinicore.bfloat16: elif infini_dtype == infinicore.bfloat16:
return torch.bfloat16 return torch.bfloat16
elif infini_dtype == infinicore.int8: elif infini_dtype == infinicore.int8:
...@@ -22,6 +24,10 @@ def to_torch_dtype(infini_dtype): ...@@ -22,6 +24,10 @@ def to_torch_dtype(infini_dtype):
return torch.uint8 return torch.uint8
elif infini_dtype == infinicore.bool: elif infini_dtype == infinicore.bool:
return torch.bool return torch.bool
elif infini_dtype == infinicore.complex64:
return torch.complex64
elif infini_dtype == infinicore.complex128:
return torch.complex128
else: else:
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}") raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")
...@@ -30,6 +36,8 @@ def to_infinicore_dtype(torch_dtype): ...@@ -30,6 +36,8 @@ def to_infinicore_dtype(torch_dtype):
"""Convert PyTorch data type to infinicore data type""" """Convert PyTorch data type to infinicore data type"""
if torch_dtype == torch.float32: if torch_dtype == torch.float32:
return infinicore.float32 return infinicore.float32
elif torch_dtype == torch.float64:
return infinicore.float64
elif torch_dtype == torch.float16: elif torch_dtype == torch.float16:
return infinicore.float16 return infinicore.float16
elif torch_dtype == torch.bfloat16: elif torch_dtype == torch.bfloat16:
...@@ -46,5 +54,9 @@ def to_infinicore_dtype(torch_dtype): ...@@ -46,5 +54,9 @@ def to_infinicore_dtype(torch_dtype):
return infinicore.uint8 return infinicore.uint8
elif torch_dtype == torch.bool: elif torch_dtype == torch.bool:
return infinicore.bool return infinicore.bool
elif torch_dtype == torch.complex64:
return infinicore.complex64
elif torch_dtype == torch.complex128:
return infinicore.complex128
else: else:
raise ValueError(f"Unsupported torch dtype: {torch_dtype}") raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
import torch import torch
import math
from pathlib import Path from pathlib import Path
from .datatypes import to_torch_dtype from .datatypes import to_torch_dtype
from .devices import torch_device_map from .devices import torch_device_map
from .utils import is_integer_dtype from .utils import is_integer_dtype, is_complex_dtype
class TensorInitializer: class TensorInitializer:
...@@ -52,7 +53,12 @@ class TensorInitializer: ...@@ -52,7 +53,12 @@ class TensorInitializer:
return TensorInitializer._create_integer_tensor( return TensorInitializer._create_integer_tensor(
shape, torch_dtype, torch_device_str, mode, **kwargs shape, torch_dtype, torch_device_str, mode, **kwargs
) )
elif is_complex_dtype(torch_dtype):
return TensorInitializer._create_complex_tensor(
shape, torch_dtype, torch_device_str, mode, **kwargs
)
# Handle real floating-point types
if mode == TensorInitializer.RANDOM: if mode == TensorInitializer.RANDOM:
return torch.rand(shape, dtype=torch_dtype, device=torch_device_str) return torch.rand(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ZEROS: elif mode == TensorInitializer.ZEROS:
...@@ -88,6 +94,7 @@ class TensorInitializer: ...@@ -88,6 +94,7 @@ class TensorInitializer:
@staticmethod @staticmethod
def _create_integer_tensor(shape, torch_dtype, torch_device_str, mode, **kwargs): def _create_integer_tensor(shape, torch_dtype, torch_device_str, mode, **kwargs):
"""Create integer tensor"""
if mode == TensorInitializer.RANDOM: if mode == TensorInitializer.RANDOM:
if torch_dtype == torch.bool: if torch_dtype == torch.bool:
return torch.randint( return torch.randint(
...@@ -135,6 +142,40 @@ class TensorInitializer: ...@@ -135,6 +142,40 @@ class TensorInitializer:
0, 100, shape, dtype=torch_dtype, device=torch_device_str 0, 100, shape, dtype=torch_dtype, device=torch_device_str
) )
@staticmethod
def _create_complex_tensor(shape, torch_dtype, torch_device_str, mode, **kwargs):
    """Create a complex tensor (complex64 or complex128).

    Args:
        shape: iterable of ints giving the tensor shape.
        torch_dtype: torch.complex64 or torch.complex128.
        torch_device_str: torch device string (e.g. "cpu", "cuda:0").
        mode: one of the TensorInitializer mode constants; unknown
            modes fall back to random complex values.
        **kwargs: ``set_tensor`` is required for MANUAL and BINARY modes.

    Returns:
        torch.Tensor on the requested device with the requested dtype.

    Raises:
        ValueError: if MANUAL/BINARY mode is missing ``set_tensor`` or
            MANUAL mode receives a tensor whose shape does not match.
    """

    def _random_complex():
        # Draw real/imag parts in the matching real precision so that
        # complex128 tensors get float64 components (the previous code
        # always drew float32 parts regardless of the target dtype).
        real_dtype = (
            torch.float64 if torch_dtype == torch.complex128 else torch.float32
        )
        real_part = torch.rand(shape, dtype=real_dtype, device=torch_device_str)
        imag_part = torch.rand(shape, dtype=real_dtype, device=torch_device_str)
        return torch.complex(real_part, imag_part).to(torch_dtype)

    if mode == TensorInitializer.RANDOM:
        return _random_complex()
    elif mode == TensorInitializer.ZEROS:
        return torch.zeros(shape, dtype=torch_dtype, device=torch_device_str)
    elif mode == TensorInitializer.ONES:
        return torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
    elif mode == TensorInitializer.MANUAL:
        tensor = kwargs.get("set_tensor")
        if tensor is None:
            raise ValueError("Manual mode requires set_tensor")
        if list(tensor.shape) != list(shape):
            raise ValueError(
                f"Shape mismatch: expected {shape}, got {tensor.shape}"
            )
        return tensor.to(torch_dtype).to(torch_device_str)
    elif mode == TensorInitializer.BINARY:
        tensor = kwargs.get("set_tensor")
        if tensor is None:
            raise ValueError("Binary mode requires set_tensor")
        return tensor.to(torch_dtype).to(torch_device_str)
    else:
        # Unrecognized mode: default to random complex values, matching
        # the original fallback behavior.
        return _random_complex()
@staticmethod @staticmethod
def _create_strided_tensor( def _create_strided_tensor(
shape, strides, torch_dtype, torch_device_str, mode, **kwargs shape, strides, torch_dtype, torch_device_str, mode, **kwargs
......
...@@ -42,7 +42,11 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): ...@@ -42,7 +42,11 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
""" """
Debug function to compare two tensors and print differences Debug function to compare two tensors and print differences
""" """
if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16: # Handle complex types by converting to real representation for comparison
if actual.is_complex() or desired.is_complex():
actual = torch.view_as_real(actual)
desired = torch.view_as_real(desired)
elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32) actual = actual.to(torch.float32)
desired = desired.to(torch.float32) desired = desired.to(torch.float32)
...@@ -162,8 +166,6 @@ def convert_infinicore_to_torch(infini_result): ...@@ -162,8 +166,6 @@ def convert_infinicore_to_torch(infini_result):
Args: Args:
infini_result: infinicore tensor result infini_result: infinicore tensor result
dtype: infinicore data type
device_str: torch device string
Returns: Returns:
torch.Tensor: PyTorch tensor with infinicore data torch.Tensor: PyTorch tensor with infinicore data
...@@ -259,6 +261,24 @@ def compare_results( ...@@ -259,6 +261,24 @@ def compare_results(
if debug_mode and not result_equal: if debug_mode and not result_equal:
print("Integer tensor comparison failed - requiring exact equality") print("Integer tensor comparison failed - requiring exact equality")
return result_equal return result_equal
elif is_complex_dtype(torch_result_from_infini.dtype) or is_complex_dtype(
torch_result.dtype
):
# Complex number comparison - compare real and imaginary parts separately
real_close = torch.allclose(
torch_result_from_infini.real, torch_result.real, atol=atol, rtol=rtol
)
imag_close = torch.allclose(
torch_result_from_infini.imag, torch_result.imag, atol=atol, rtol=rtol
)
result_equal = real_close and imag_close
if debug_mode and not result_equal:
print("Complex tensor comparison failed")
if not real_close:
print(" Real parts don't match")
if not imag_close:
print(" Imaginary parts don't match")
return result_equal
else: else:
# Tolerance-based comparison for floating-point types # Tolerance-based comparison for floating-point types
return torch.allclose( return torch.allclose(
...@@ -382,3 +402,18 @@ def is_integer_dtype(dtype): ...@@ -382,3 +402,18 @@ def is_integer_dtype(dtype):
torch.uint8, torch.uint8,
torch.bool, torch.bool,
] ]
def is_complex_dtype(dtype):
    """Return True if *dtype* is one of torch's complex dtypes."""
    return dtype in (torch.complex64, torch.complex128)
def is_floating_dtype(dtype):
    """Return True if *dtype* is a real (non-complex) floating-point dtype."""
    floating_dtypes = (
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    )
    return dtype in floating_dtypes
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment