Commit 5b7ef9c5 authored by wooway777, committed by MaYuhang
Browse files

issue/540 - support more dtypes in test framework

parent a5e20fcf
......@@ -10,10 +10,16 @@ def to_torch_dtype(infini_dtype):
return torch.float32
elif infini_dtype == infinicore.bfloat16:
return torch.bfloat16
elif infini_dtype == infinicore.int8:
return torch.int8
elif infini_dtype == infinicore.int16:
return torch.int16
elif infini_dtype == infinicore.int32:
return torch.int32
elif infini_dtype == infinicore.int64:
return torch.int64
elif infini_dtype == infinicore.uint8:
return torch.uint8
else:
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")
......@@ -26,9 +32,15 @@ def to_infinicore_dtype(torch_dtype):
return infinicore.float16
elif torch_dtype == torch.bfloat16:
return infinicore.bfloat16
elif torch_dtype == torch.int8:
return infinicore.int8
elif torch_dtype == torch.int16:
return infinicore.int16
elif torch_dtype == torch.int32:
return infinicore.int32
elif torch_dtype == torch.int64:
return infinicore.int64
elif torch_dtype == torch.uint8:
return infinicore.uint8
else:
raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
import torch
import infinicore
from pathlib import Path
from .datatypes import to_torch_dtype
from .devices import torch_device_map
from .utils import is_integer_dtype
class TensorInitializer:
......@@ -38,6 +40,10 @@ class TensorInitializer:
torch_device_str = torch_device_map[device]
torch_dtype = to_torch_dtype(dtype)
# Handle integer types differently for random initialization
if mode == TensorInitializer.RANDOM and is_integer_dtype(dtype):
mode = TensorInitializer.RANDINT # Use randint for integer types
# Handle strided tensors - calculate required storage size
if strides is not None:
# Calculate the required storage size for strided tensor
......@@ -61,9 +67,22 @@ class TensorInitializer:
storage_size, dtype=torch_dtype, device=torch_device_str
)
elif mode == TensorInitializer.RANDINT:
# For integer types, use appropriate range
if is_integer_dtype(dtype):
if dtype == infinicore.uint8:
low, high = 0, 256
elif dtype == infinicore.int8:
low, high = -128, 128
elif dtype == infinicore.int16:
low, high = -32768, 32768
else: # int32, int64, uint32
low, high = -1000, 1000
else:
low, high = -1000, 1000
base_tensor = torch.randint(
-2000000000,
2000000000,
low,
high,
(storage_size,),
dtype=torch_dtype,
device=torch_device_str,
......@@ -92,9 +111,22 @@ class TensorInitializer:
elif mode == TensorInitializer.ONES:
tensor = torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.RANDINT:
# For integer types, use appropriate range
if is_integer_dtype(dtype):
if dtype == infinicore.uint8:
low, high = 0, 256
elif dtype == infinicore.int8:
low, high = -128, 128
elif dtype == infinicore.int16:
low, high = -32768, 32768
else: # int32, int64, uint32
low, high = -1000, 1000
else:
low, high = -1000, 1000
tensor = torch.randint(
-2000000000,
2000000000,
low,
high,
shape,
dtype=torch_dtype,
device=torch_device_str,
......
......@@ -37,16 +37,43 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
print(f" {desc} time: {elapsed * 1000 :6f} ms")
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
def is_integer_dtype(dtype):
    """Return True if ``dtype`` is one of the infinicore integer dtypes.

    The dtypes treated as integer here are int8, int16, int32, int64 and
    uint8; any other value yields False.
    """
    integer_dtypes = (
        infinicore.int8,
        infinicore.int16,
        infinicore.int32,
        infinicore.int64,
        infinicore.uint8,
    )
    return dtype in integer_dtypes
def is_float_dtype(dtype):
    """Return True if ``dtype`` is one of the infinicore floating point dtypes.

    The dtypes treated as floating point here are float16, float32 and
    bfloat16; any other value yields False.
    """
    float_dtypes = (
        infinicore.float16,
        infinicore.float32,
        infinicore.bfloat16,
    )
    return dtype in float_dtypes
def debug(
actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True, dtype=None
):
"""
Debug function to compare two tensors and print differences
"""
# Convert to float32 for bfloat16 comparison
if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32)
desired = desired.to(torch.float32)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose, dtype)
# Use appropriate comparison based on dtype
if dtype and is_integer_dtype(dtype):
# For integer types, require exact equality
import numpy as np
np.testing.assert_array_equal(actual.cpu(), desired.cpu())
else:
# For float types, use allclose
import numpy as np
np.testing.assert_allclose(
......@@ -55,7 +82,7 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
def print_discrepancy(
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True, dtype=None
):
"""Print detailed tensor differences"""
if actual.shape != expected.shape:
......@@ -69,13 +96,21 @@ def print_discrepancy(
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)
# Calculate difference mask
# Calculate difference mask based on dtype
if dtype and is_integer_dtype(dtype):
# For integer types, exact equality required
diff_mask = actual != expected
else:
# For float types, use tolerance-based comparison
nan_mismatch = (
actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
actual_isnan ^ expected_isnan
if equal_nan
else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected
......@@ -107,6 +142,7 @@ def print_discrepancy(
print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
if not (dtype and is_integer_dtype(dtype)):
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(
......@@ -130,6 +166,10 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
"""
Get tolerance settings based on data type
"""
# For integer types, return zero tolerance (exact match required)
if is_integer_dtype(tensor_dtype):
return 0, 0
tolerance = tolerance_map.get(
tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
)
......@@ -162,8 +202,6 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
Args:
infini_result: infinicore tensor result
torch_reference: PyTorch tensor reference (for shape and device)
dtype: infinicore data type
device_str: torch device string
Returns:
torch.Tensor: PyTorch tensor with infinicore data
......@@ -179,7 +217,7 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
def compare_results(
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False, dtype=None
):
"""
Generic function to compare infinicore result with PyTorch reference result
......@@ -190,6 +228,7 @@ def compare_results(
atol: absolute tolerance
rtol: relative tolerance
debug_mode: whether to enable debug output
dtype: infinicore data type for comparison logic
Returns:
bool: True if results match within tolerance
......@@ -197,12 +236,21 @@ def compare_results(
# Convert infinicore result to PyTorch tensor for comparison
torch_result_from_infini = convert_infinicore_to_torch(infini_result, torch_result)
# Choose comparison method based on dtype
if dtype and is_integer_dtype(dtype):
# For integer types, require exact equality
result = torch.equal(torch_result_from_infini, torch_result)
else:
# For float types, use tolerance-based comparison
result = torch.allclose(
torch_result_from_infini, torch_result, atol=atol, rtol=rtol
)
# Debug mode: detailed comparison
if debug_mode:
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol, dtype=dtype)
# Check if results match within tolerance
return torch.allclose(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
return result
def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
......@@ -227,7 +275,12 @@ def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
if config.debug and mode_name:
print(f"\n\033[94mDEBUG INFO - {mode_name}:\033[0m")
return compare_results(
infini_result, torch_result, atol=atol, rtol=rtol, debug_mode=config.debug
infini_result,
torch_result,
atol=atol,
rtol=rtol,
debug_mode=config.debug,
dtype=dtype,
)
return compare_test_results
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment