Unverified Commit 39ec8f0e authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #558 from InfiniTensor/issue/556

parents 2e5b2342 6b8949ce
......@@ -7,3 +7,5 @@ def add(input, other, *, out=None):
return Tensor(_infinicore.add(input._underlying, other._underlying))
_infinicore.add_(out._underlying, input._underlying, other._underlying)
return out
......@@ -24,3 +24,5 @@ def attention(q, k, v, k_cache, v_cache, pos, *, out=None):
v_cache._underlying,
pos,
)
return out
......@@ -7,3 +7,5 @@ def causal_softmax(input, *, out=None):
return Tensor(_infinicore.causal_softmax(input._underlying))
_infinicore.causal_softmax_(out._underlying, input._underlying)
return out
......@@ -7,3 +7,5 @@ def matmul(input, other, *, out=None):
return Tensor(_infinicore.matmul(input._underlying, other._underlying))
_infinicore.matmul_(out._underlying, input._underlying, other._underlying)
return out
......@@ -7,3 +7,5 @@ def rearrange(input, other, *, out=None):
return Tensor(_infinicore.rearrange(input._underlying))
_infinicore.rearrange_(out._underlying, input._underlying)
return out
......@@ -11,3 +11,5 @@ def rms_norm(input, weight, epsilon=1e-5, *, out=None):
_infinicore.rms_norm_(
out._underlying, input._underlying, weight._underlying, epsilon
)
return out
......@@ -7,3 +7,5 @@ def silu(input, *, out=None):
return Tensor(_infinicore.silu(input._underlying))
_infinicore.silu_(out._underlying, input._underlying)
return out
......@@ -7,3 +7,5 @@ def swiglu(input, other, *, out=None):
return Tensor(_infinicore.swiglu(input._underlying, other._underlying))
_infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
return out
# [file name]: __init__.py
# [file content begin]
from .base import TestConfig, TestRunner, TestCase, BaseOperatorTest
from .tensor import TensorSpec, TensorInitializer
from .utils import (
......@@ -16,7 +14,6 @@ from .config import get_test_devices, get_args
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .runner import GenericTestRunner
from .templates import BinaryOperatorTest, UnaryOperatorTest
__all__ = [
"TensorSpec",
......@@ -41,6 +38,4 @@ __all__ = [
"to_torch_dtype",
"to_infinicore_dtype",
"GenericTestRunner",
"BinaryOperatorTest",
"UnaryOperatorTest",
]
......@@ -2,7 +2,7 @@ import torch
import infinicore
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple, Union, Callable, Optional
from typing import List, Dict, Any, Optional
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
......@@ -11,28 +11,37 @@ from .utils import (
create_test_comparator,
infinicore_tensor_from_torch,
profile_operation,
rearrange_tensor,
synchronize_device,
convert_infinicore_to_torch,
)
class TestCase:
"""Test case"""
"""Test case with all configuration included"""
OUT_OF_PLACE = "out_of_place"
IN_PLACE = "in_place"
BOTH = "both"
def __init__(self, operation_mode, inputs, output=None, **kwargs):
if operation_mode not in [self.IN_PLACE, self.OUT_OF_PLACE, self.BOTH]:
raise ValueError(f"Invalid operation_mode: {operation_mode}")
if operation_mode == self.IN_PLACE and output is None:
raise ValueError("IN_PLACE mode requires output specification")
self.operation_mode = operation_mode
def __init__(
self,
inputs,
kwargs=None,
output_spec=None,
comparison_target=None,
description="",
tolerance=None,
):
"""
Initialize a test case with complete configuration
Args:
inputs: List of TensorSpec objects or scalars
kwargs: Additional keyword arguments for the operator
output_spec: TensorSpec for output tensor (for in-place operations)
comparison_target: Target for comparison ('out', index, or None for return value)
description: Test case description
tolerance: Tolerance settings for this test case {'atol': float, 'rtol': float}
"""
self.inputs = []
# Process inputs
for inp in inputs:
if isinstance(inp, (list, tuple)):
self.inputs.append(TensorSpec.from_tensor(inp))
......@@ -41,34 +50,34 @@ class TestCase:
else:
self.inputs.append(inp)
if isinstance(output, (list, tuple)):
self.output = TensorSpec.from_tensor(output)
else:
self.output = output
self.kwargs = kwargs or {}
self.output_spec = output_spec
self.comparison_target = comparison_target
self.description = description
self.tolerance = tolerance or {"atol": 1e-5, "rtol": 1e-3}
self.kwargs = kwargs
self.description = kwargs.pop("description", "")
def get_tensor_input_count(self):
"""Count the number of tensor inputs (excluding scalars)"""
count = 0
for inp in self.inputs:
if isinstance(inp, TensorSpec) and not inp.is_scalar:
count += 1
return count
def __str__(self):
mode_str = self.operation_mode.upper()
input_strs = []
for inp in self.inputs:
if hasattr(inp, "is_scalar") and inp.is_scalar:
dtype_str = f", dtype={inp.dtype}" if inp.dtype else ""
input_strs.append(f"scalar({inp.value}{dtype_str})")
elif hasattr(inp, "shape"):
dtype_str = f", dtype={inp.dtype}" if inp.dtype else ""
dtype_str = f", {inp.dtype}" if inp.dtype else ""
init_str = (
f", init={inp.init_mode}"
if inp.init_mode != TensorInitializer.RANDOM
else ""
)
# Show shape and strides for non-contiguous tensors
if (
hasattr(inp, "is_contiguous")
and not inp.is_contiguous
and inp.strides
):
if hasattr(inp, "strides") and inp.strides:
strides_str = f", strides={inp.strides}"
input_strs.append(
f"tensor{inp.shape}{strides_str}{dtype_str}{init_str}"
......@@ -78,28 +87,38 @@ class TestCase:
else:
input_strs.append(str(inp))
base_str = f"TestCase(mode={mode_str}, inputs=[{'; '.join(input_strs)}]"
if self.output:
dtype_str = f", dtype={self.output.dtype}" if self.output.dtype else ""
init_str = (
f", init={self.output.init_mode}"
if self.output.init_mode != TensorInitializer.RANDOM
else ""
)
# Show shape and strides for non-contiguous output tensors
if (
hasattr(self.output, "is_contiguous")
and not self.output.is_contiguous
and self.output.strides
):
strides_str = f", strides={self.output.strides}"
base_str += f", output=tensor{self.output.shape}{strides_str}{dtype_str}{init_str}"
else:
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
if self.kwargs:
base_str += f", kwargs={self.kwargs}"
base_str = f"TestCase("
if self.description:
base_str += f", desc='{self.description}'"
base_str += f"{self.description}"
base_str += f" - inputs=[{', '.join(input_strs)}]"
if self.kwargs or self.output_spec:
kwargs_strs = []
for key, value in self.kwargs.items():
if key == "out" and isinstance(value, int):
kwargs_strs.append(f"{key}={value}")
else:
kwargs_strs.append(f"{key}={value}")
output_spec = self.output_spec
if output_spec and isinstance(output_spec, TensorSpec):
dtype_str = f", {output_spec.dtype}" if output_spec.dtype else ""
init_str = (
f", init={output_spec.init_mode}"
if output_spec.init_mode != TensorInitializer.RANDOM
else ""
)
if hasattr(output_spec, "strides") and output_spec.strides:
strides_str = f", strides={output_spec.strides}"
kwargs_strs.append(
f"out=tensor{output_spec.shape}{strides_str}{dtype_str}{init_str}"
)
else:
kwargs_strs.append(
f"out=tensor{output_spec.shape}{dtype_str}{init_str}"
)
base_str += f", kwargs={{{', '.join(kwargs_strs)}}}"
base_str += ")"
return base_str
......@@ -107,23 +126,11 @@ class TestCase:
class TestConfig:
"""Test configuration"""
def __init__(
self,
tensor_dtypes,
tolerance_map,
debug=False,
bench=False,
num_prerun=10,
num_iterations=1000,
dtype_combinations=None,
):
self.tensor_dtypes = tensor_dtypes
self.tolerance_map = tolerance_map
def __init__(self, debug=False, bench=False, num_prerun=10, num_iterations=1000):
self.debug = debug
self.bench = bench
self.num_prerun = num_prerun
self.num_iterations = num_iterations
self.dtype_combinations = dtype_combinations
class TestRunner:
......@@ -140,58 +147,21 @@ class TestRunner:
print(f"Testing {test_type} on {InfiniDeviceNames[device]}")
print(f"{'='*60}")
tensor_dtypes = self._filter_tensor_dtypes_by_device(
device, self.config.tensor_dtypes
)
for test_case in self.test_cases:
if self.config.dtype_combinations:
for dtype_combo in self.config.dtype_combinations:
try:
# Print test case info first
combo_str = self._format_dtype_combo(dtype_combo)
print(f"{test_case} with {combo_str}")
test_func(device, test_case, dtype_combo, self.config)
print(f"\033[92m✓\033[0m Passed")
except Exception as e:
combo_str = self._format_dtype_combo(dtype_combo)
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
else:
for dtype in tensor_dtypes:
try:
# Print test case info first
print(f"{test_case} with {dtype}")
test_func(device, test_case, dtype, self.config)
print(f"\033[92m✓\033[0m Passed")
except Exception as e:
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
try:
print(f"{test_case}")
test_func(device, test_case, self.config)
print(f"\033[92m✓\033[0m Passed")
except Exception as e:
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
return len(self.failed_tests) == 0
def _format_dtype_combo(self, dtype_combo):
if isinstance(dtype_combo, dict):
return f"dtypes({dtype_combo})"
elif isinstance(dtype_combo, (list, tuple)):
return f"dtypes{tuple(dtype_combo)}"
else:
return str(dtype_combo)
def _filter_tensor_dtypes_by_device(self, device, tensor_dtypes):
if device in ():
return [dt for dt in tensor_dtypes if dt != infinicore.bfloat16]
else:
return tensor_dtypes
def print_summary(self):
if self.failed_tests:
print(f"\n\033[91m{len(self.failed_tests)} tests failed:\033[0m")
......@@ -209,120 +179,100 @@ class BaseOperatorTest(ABC):
def __init__(self, operator_name):
self.operator_name = operator_name
self.test_cases = self.get_test_cases()
self.tensor_dtypes = self.get_tensor_dtypes()
self.tolerance_map = self.get_tolerance_map()
self.dtype_combinations = self.get_dtype_combinations()
@abstractmethod
def get_test_cases(self):
"""Return list of TestCase objects"""
"""Return list of TestCase objects with complete configuration"""
pass
@abstractmethod
def get_tensor_dtypes(self):
"""Return supported data types"""
pass
@abstractmethod
def get_tolerance_map(self):
"""Return tolerance configuration"""
pass
def get_dtype_combinations(self):
"""Return dtype combinations for mixed dtype tests"""
return None
def torch_operator(self, *inputs, out=None, **kwargs):
"""Unified PyTorch operator function - can be overridden or return None"""
def torch_operator(self, *args, **kwargs):
"""PyTorch operator function"""
raise NotImplementedError("torch_operator not implemented")
def infinicore_operator(self, *inputs, out=None, **kwargs):
"""Unified InfiniCore operator function - can be overridden or return None"""
def infinicore_operator(self, *args, **kwargs):
"""InfiniCore operator function"""
raise NotImplementedError("infinicore_operator not implemented")
def create_strided_tensor(
self, shape, strides, dtype, device, init_mode=TensorInitializer.RANDOM
):
"""Create a non-contiguous tensor with specific strides"""
spec = TensorSpec.from_strided_tensor(shape, strides, dtype, init_mode)
return spec.create_torch_tensor(device, dtype)
def prepare_inputs(self, test_case, device, dtype_config):
"""Prepare input data"""
def prepare_inputs_and_kwargs(self, test_case, device):
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors"""
inputs = []
kwargs = test_case.kwargs.copy()
# Prepare input tensors
for i, input_spec in enumerate(test_case.inputs):
if isinstance(input_spec, TensorSpec):
if input_spec.is_scalar:
inputs.append(input_spec.value)
else:
tensor = input_spec.create_torch_tensor(device, dtype_config, i)
tensor = input_spec.create_torch_tensor(device)
inputs.append(tensor)
else:
inputs.append(input_spec)
return inputs, test_case.kwargs
# Prepare output tensor if specified in output_spec
if test_case.output_spec is not None:
output_tensor = test_case.output_spec.create_torch_tensor(device)
kwargs["out"] = output_tensor
def get_output_dtype(self, test_case, dtype_config, torch_result=None):
"""Determine output dtype - returns infinicore dtype, not torch dtype"""
if test_case.output and test_case.output.dtype is not None:
return test_case.output.dtype
elif isinstance(dtype_config, dict) and "output" in dtype_config:
return dtype_config["output"]
elif torch_result is not None:
return to_infinicore_dtype(torch_result.dtype)
else:
if isinstance(dtype_config, (list, tuple)):
return dtype_config[0]
# Handle integer indices for in-place operations
if "out" in kwargs and isinstance(kwargs["out"], int):
input_idx = kwargs["out"]
if 0 <= input_idx < len(inputs) and isinstance(
inputs[input_idx], torch.Tensor
):
kwargs["out"] = inputs[input_idx]
else:
return dtype_config
def run_test(self, device, test_case, dtype_config, config):
"""Unified test execution flow"""
device_str = torch_device_map[device]
if test_case.operation_mode == TestCase.BOTH:
out_of_place_case = TestCase(
TestCase.OUT_OF_PLACE,
test_case.inputs,
test_case.output,
**test_case.kwargs,
)
self._run_single_test(
device, out_of_place_case, dtype_config, config, "OUT_OF_PLACE"
)
if test_case.output is not None:
in_place_case = TestCase(
TestCase.IN_PLACE,
test_case.inputs,
test_case.output,
**test_case.kwargs,
raise ValueError(
f"Invalid input index for in-place operation: {input_idx}"
)
self._run_single_test(
device, in_place_case, dtype_config, config, "IN_PLACE"
)
return
self._run_single_test(
device, test_case, dtype_config, config, test_case.operation_mode.upper()
)
return inputs, kwargs
def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
"""Run a single test with specified operation mode"""
def run_test(self, device, test_case, config):
"""Unified test execution flow"""
device_str = torch_device_map[device]
inputs, kwargs = self.prepare_inputs(test_case, device, dtype_config)
# Prepare inputs and kwargs with actual tensors
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
# For in-place operations on input tensors, we need to preserve the original state
original_inputs = []
if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor):
# This is an in-place operation on an input tensor
# Store original values for comparison
for inp in inputs:
if isinstance(inp, torch.Tensor):
original_inputs.append(inp.clone().detach())
else:
original_inputs.append(inp)
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
infini_inputs = []
torch_input_clones = []
for inp in inputs:
if isinstance(inp, torch.Tensor):
infini_tensor = infinicore_tensor_from_torch(inp)
cloned_inp = inp.clone().detach()
torch_input_clones.append(cloned_inp)
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
infini_inputs.append(infini_tensor)
else:
infini_inputs.append(inp)
# Check if operators are implemented
# Determine comparison target
comparison_target = test_case.comparison_target
# Handle infinicore output
infini_kwargs = kwargs.copy()
if "out" in infini_kwargs and isinstance(infini_kwargs["out"], torch.Tensor):
if isinstance(comparison_target, int):
infini_kwargs["out"] = infini_inputs[comparison_target]
else:
cloned_out = infini_kwargs["out"].clone().detach()
torch_input_clones.append(cloned_out)
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
# Check operator implementations
torch_implemented = True
infini_implemented = True
......@@ -335,19 +285,19 @@ class BaseOperatorTest(ABC):
torch_result = None
try:
infini_result = self.infinicore_operator(*infini_inputs, **kwargs)
infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
if infini_result is None:
infini_implemented = False
except NotImplementedError:
infini_implemented = False
infini_result = None
# If neither operator is implemented, skip the test
# Skip if neither operator is implemented
if not torch_implemented and not infini_implemented:
print(f"⚠ Both operators not implemented - test skipped")
return
# If only one operator is implemented, run it without comparison
# Single operator execution without comparison
if not torch_implemented or not infini_implemented:
missing_op = (
"torch_operator" if not torch_implemented else "infinicore_operator"
......@@ -356,14 +306,12 @@ class BaseOperatorTest(ABC):
f"⚠ {missing_op} not implemented - running single operator without comparison"
)
# Run the available operator for benchmarking if requested
if config.bench:
if torch_implemented:
def torch_op():
return self.torch_operator(*inputs, **kwargs)
print(f" {mode_name}:")
profile_operation(
"PyTorch ",
torch_op,
......@@ -374,9 +322,8 @@ class BaseOperatorTest(ABC):
if infini_implemented:
def infini_op():
return self.infinicore_operator(*infini_inputs, **kwargs)
return self.infinicore_operator(*infini_inputs, **infini_kwargs)
print(f" {mode_name}:")
profile_operation(
"InfiniCore",
infini_op,
......@@ -386,126 +333,79 @@ class BaseOperatorTest(ABC):
)
return
# Both operators are implemented - proceed with normal comparison
if test_case.operation_mode == TestCase.OUT_OF_PLACE:
def torch_op():
return self.torch_operator(*inputs, **kwargs)
torch_result = torch_op()
if (
isinstance(torch_result, torch.Tensor)
and not torch_result.is_contiguous()
):
torch_result = torch_result.contiguous()
def infini_op():
return self.infinicore_operator(*infini_inputs, **kwargs)
infini_result = infini_op()
# Get comparison dtype (infinicore dtype)
comparison_dtype = self.get_output_dtype(
test_case, dtype_config, torch_result
)
compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{mode_name}"
)
is_valid = compare_fn(infini_result, torch_result)
assert is_valid, f"{mode_name} result comparison failed"
if config.bench:
print(f" {mode_name}:")
profile_operation(
"PyTorch ",
torch_op,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
"InfiniCore",
infini_op,
device_str,
config.num_prerun,
config.num_iterations,
if comparison_target is None:
# Compare return values (out-of-place)
torch_comparison = torch_result
infini_comparison = infini_result
elif comparison_target == "out":
# Compare output tensor from kwargs (explicit output)
torch_comparison = kwargs.get("out")
infini_comparison = infini_kwargs.get("out")
elif isinstance(comparison_target, int):
# Compare specific input tensor (in-place operation on input)
# For in-place operations, we compare the modified input tensor
if 0 <= comparison_target < len(inputs):
torch_comparison = inputs[comparison_target]
infini_comparison = infini_inputs[comparison_target]
else:
raise ValueError(
f"Invalid comparison target index: {comparison_target}"
)
else:
if not test_case.output:
raise ValueError("IN_PLACE test requires output specification")
raise ValueError(f"Invalid comparison target: {comparison_target}")
# Get output dtype and create output tensor
output_dtype = self.get_output_dtype(test_case, dtype_config)
output_shape = test_case.output.shape
# Validate comparison targets
if torch_comparison is None or infini_comparison is None:
raise ValueError("Comparison targets cannot be None")
# Use TensorSpec to create output tensor with specified initialization mode
if test_case.output.is_contiguous or test_case.output.strides is None:
output_spec = TensorSpec.from_tensor(
output_shape, output_dtype, init_mode=test_case.output.init_mode
)
else:
output_spec = TensorSpec.from_strided_tensor(
output_shape,
test_case.output.strides,
output_dtype,
init_mode=test_case.output.init_mode,
)
# Perform comparison
atol = test_case.tolerance.get("atol", 1e-5)
rtol = test_case.tolerance.get("rtol", 1e-3)
torch_output = output_spec.create_torch_tensor(device, output_dtype)
compare_fn = create_test_comparator(config, atol, rtol, test_case.description)
# For non-contiguous tensors, we need to ensure zeros initialization
if (
not test_case.output.is_contiguous
and test_case.output.strides is not None
):
torch_output.zero_()
is_valid = compare_fn(infini_comparison, torch_comparison)
assert is_valid, f"Result comparison failed for {test_case}"
def torch_op_inplace():
self.torch_operator(*inputs, out=torch_output, **kwargs)
# Benchmarking
if config.bench:
if comparison_target is None:
# Out-of-place benchmarking
def torch_op():
return self.torch_operator(*inputs, **kwargs)
torch_op_inplace()
def infini_op():
return self.infinicore_operator(*infini_inputs, **infini_kwargs)
# Create infinicore output tensor
torch_dummy = torch.zeros(
output_shape, dtype=to_torch_dtype(output_dtype), device=device_str
)
if (
not test_case.output.is_contiguous
and not test_case.output.strides is None
):
rearrange_tensor(torch_dummy, list(torch_output.stride()))
infini_output = infinicore_tensor_from_torch(torch_dummy)
def infini_op_inplace():
self.infinicore_operator(*infini_inputs, out=infini_output, **kwargs)
else:
# In-place benchmarking
def torch_op():
self.torch_operator(*inputs, **kwargs)
return (
kwargs.get("out")
if "out" in kwargs
else inputs[comparison_target]
)
infini_op_inplace()
def infini_op():
self.infinicore_operator(*infini_inputs, **infini_kwargs)
return (
infini_kwargs.get("out")
if "out" in infini_kwargs
else infini_inputs[comparison_target]
)
comparison_dtype = self.get_output_dtype(
test_case, dtype_config, torch_output
profile_operation(
"PyTorch ",
torch_op,
device_str,
config.num_prerun,
config.num_iterations,
)
compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{mode_name}"
profile_operation(
"InfiniCore",
infini_op,
device_str,
config.num_prerun,
config.num_iterations,
)
is_valid = compare_fn(infini_output, torch_output)
assert is_valid, f"{mode_name} result comparison failed"
if config.bench:
print(f" {mode_name}:")
profile_operation(
"PyTorch ",
torch_op_inplace,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
"InfiniCore",
infini_op_inplace,
device_str,
config.num_prerun,
config.num_iterations,
)
......@@ -20,6 +20,8 @@ def to_torch_dtype(infini_dtype):
return torch.int64
elif infini_dtype == infinicore.uint8:
return torch.uint8
elif infini_dtype == infinicore.bool:
return torch.bool
else:
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")
......@@ -42,5 +44,7 @@ def to_infinicore_dtype(torch_dtype):
return infinicore.int64
elif torch_dtype == torch.uint8:
return infinicore.uint8
elif torch_dtype == torch.bool:
return infinicore.bool
else:
raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
......@@ -20,13 +20,10 @@ class GenericTestRunner:
def run(self):
"""Execute the complete test suite"""
config = TestConfig(
tensor_dtypes=self.operator_test.tensor_dtypes,
tolerance_map=self.operator_test.tolerance_map,
debug=self.args.debug,
bench=self.args.bench,
num_prerun=self.args.num_prerun,
num_iterations=self.args.num_iterations,
dtype_combinations=self.operator_test.dtype_combinations,
)
runner = TestRunner(self.operator_test.test_cases, config)
......
import torch
import infinicore
from pathlib import Path
from .datatypes import to_torch_dtype
from .devices import torch_device_map
......@@ -18,150 +17,147 @@ class TensorInitializer:
FROM_FILE = "from_file"
@staticmethod
def create_tensor(
shape, dtype, device, mode=RANDOM, strides=None, set_tensor=None, file_path=None
):
def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, **kwargs):
"""
Create a torch tensor with specified initialization mode
Unified tensor creation interface for both contiguous and non-contiguous tensors
Args:
shape: Tensor shape
dtype: infinicore dtype
device: InfiniDeviceEnum
mode: Initialization mode
strides: Optional strides for strided tensors
set_tensor: Pre-existing tensor for manual/binary mode
file_path: Path to file for FROM_FILE mode
strides: Optional strides for non-contiguous tensors
**kwargs: Additional arguments for specific modes
Returns:
torch.Tensor: Initialized tensor
"""
# Convert InfiniDeviceEnum to torch device string
torch_device_str = torch_device_map[device]
torch_dtype = to_torch_dtype(dtype)
# Handle integer types differently for random initialization
if mode == TensorInitializer.RANDOM and is_integer_dtype(dtype):
mode = TensorInitializer.RANDINT # Use randint for integer types
# Handle strided tensors - calculate required storage size
# Handle non-contiguous tensors
if strides is not None:
# Calculate the required storage size for strided tensor
storage_size = 0
for i in range(len(shape)):
if shape[i] > 0:
storage_size += (shape[i] - 1) * abs(strides[i])
storage_size += 1 # Add 1 for the base element
# Create base storage with sufficient size
if mode == TensorInitializer.RANDOM:
base_tensor = torch.rand(
storage_size, dtype=torch_dtype, device=torch_device_str
)
elif mode == TensorInitializer.ZEROS:
base_tensor = torch.zeros(
storage_size, dtype=torch_dtype, device=torch_device_str
)
elif mode == TensorInitializer.ONES:
base_tensor = torch.ones(
storage_size, dtype=torch_dtype, device=torch_device_str
)
elif mode == TensorInitializer.RANDINT:
# For integer types, use appropriate range
if is_integer_dtype(dtype):
if dtype == infinicore.uint8:
low, high = 0, 256
elif dtype == infinicore.int8:
low, high = -128, 128
elif dtype == infinicore.int16:
low, high = -32768, 32768
else: # int32, int64, uint32
low, high = -1000, 1000
else:
low, high = -1000, 1000
base_tensor = torch.randint(
low,
high,
(storage_size,),
dtype=torch_dtype,
device=torch_device_str,
return TensorInitializer._create_strided_tensor(
shape, strides, torch_dtype, torch_device_str, mode, **kwargs
)
else:
return TensorInitializer._create_contiguous_tensor(
shape, torch_dtype, torch_device_str, mode, **kwargs
)
@staticmethod
def _create_contiguous_tensor(shape, torch_dtype, torch_device_str, mode, **kwargs):
"""Create contiguous tensor"""
if is_integer_dtype(torch_dtype):
return TensorInitializer._create_integer_tensor(
shape, torch_dtype, torch_device_str, mode, **kwargs
)
if mode == TensorInitializer.RANDOM:
return torch.rand(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ZEROS:
return torch.zeros(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ONES:
return torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.RANDINT:
low = kwargs.get("low", -2000000000)
high = kwargs.get("high", 2000000000)
return torch.randint(
low, high, shape, dtype=torch_dtype, device=torch_device_str
)
elif mode == TensorInitializer.MANUAL:
tensor = kwargs.get("set_tensor")
if tensor is None:
raise ValueError("Manual mode requires set_tensor")
if list(tensor.shape) != list(shape):
raise ValueError(
f"Shape mismatch: expected {shape}, got {tensor.shape}"
)
elif mode == TensorInitializer.MANUAL:
assert set_tensor is not None, "Manual mode requires set_tensor"
base_tensor = set_tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.BINARY:
assert set_tensor is not None, "Binary mode requires set_tensor"
base_tensor = set_tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.FROM_FILE:
base_tensor = TensorInitializer._load_from_file(
file_path, storage_size, torch_dtype, torch_device_str
return tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.BINARY:
tensor = kwargs.get("set_tensor")
if tensor is None:
raise ValueError("Binary mode requires set_tensor")
return tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.FROM_FILE:
return TensorInitializer._load_from_file(
kwargs.get("file_path"), shape, torch_dtype, torch_device_str
)
else:
raise ValueError(f"Unsupported initialization mode: {mode}")
@staticmethod
def _create_integer_tensor(shape, torch_dtype, torch_device_str, mode, **kwargs):
if mode == TensorInitializer.RANDOM:
if torch_dtype == torch.bool:
return torch.randint(
0, 2, shape, dtype=torch_dtype, device=torch_device_str
).bool()
elif torch_dtype == torch.uint8:
return torch.randint(
0, 256, shape, dtype=torch_dtype, device=torch_device_str
)
else:
raise ValueError(f"Unsupported initialization mode: {mode}")
# Create strided view
tensor = torch.as_strided(base_tensor, shape, strides)
else:
# Contiguous tensor
if mode == TensorInitializer.RANDOM:
tensor = torch.rand(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ZEROS:
tensor = torch.zeros(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ONES:
tensor = torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.RANDINT:
# For integer types, use appropriate range
if is_integer_dtype(dtype):
if dtype == infinicore.uint8:
low, high = 0, 256
elif dtype == infinicore.int8:
low, high = -128, 128
elif dtype == infinicore.int16:
low, high = -32768, 32768
else: # int32, int64, uint32
low, high = -1000, 1000
else:
low, high = -1000, 1000
tensor = torch.randint(
low,
high,
dtype_info = torch.iinfo(torch_dtype)
return torch.randint(
dtype_info.min,
dtype_info.max,
shape,
dtype=torch_dtype,
device=torch_device_str,
)
elif mode == TensorInitializer.MANUAL:
assert set_tensor is not None, "Manual mode requires set_tensor"
assert shape == list(set_tensor.shape), "Shape mismatch in manual mode"
tensor = set_tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.BINARY:
assert set_tensor is not None, "Binary mode requires set_tensor"
assert shape == list(set_tensor.shape), "Shape mismatch in binary mode"
tensor = set_tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.FROM_FILE:
tensor = TensorInitializer._load_from_file(
file_path, shape, torch_dtype, torch_device_str
elif mode == TensorInitializer.ZEROS:
return torch.zeros(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ONES:
return torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.RANDINT:
low = kwargs.get("low", -100)
high = kwargs.get("high", 100)
return torch.randint(
low, high, shape, dtype=torch_dtype, device=torch_device_str
)
elif mode == TensorInitializer.MANUAL:
tensor = kwargs.get("set_tensor")
if tensor is None:
raise ValueError("Manual mode requires set_tensor")
if list(tensor.shape) != list(shape):
raise ValueError(
f"Shape mismatch: expected {shape}, got {tensor.shape}"
)
else:
raise ValueError(f"Unsupported initialization mode: {mode}")
return tensor
return tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.BINARY:
tensor = kwargs.get("set_tensor")
if tensor is None:
raise ValueError("Binary mode requires set_tensor")
return tensor.to(torch_dtype).to(torch_device_str)
else:
return torch.randint(
0, 100, shape, dtype=torch_dtype, device=torch_device_str
)
@staticmethod
def _load_from_file(file_path, shape_or_size, torch_dtype, torch_device_str):
"""
Load tensor data from file using PyTorch's native methods
def _create_strided_tensor(
shape, strides, torch_dtype, torch_device_str, mode, **kwargs
):
"""Create non-contiguous tensor with specific strides"""
# Calculate required storage size
storage_size = 0
for i in range(len(shape)):
if shape[i] > 0:
storage_size += (shape[i] - 1) * abs(strides[i])
storage_size += 1
# Create base storage
base_tensor = TensorInitializer._create_contiguous_tensor(
(storage_size,), torch_dtype, torch_device_str, mode, **kwargs
)
Args:
file_path: Path to the file
shape_or_size: Tensor shape for contiguous or size for strided
torch_dtype: Target torch dtype
torch_device_str: Target device string
# Create strided view
return torch.as_strided(base_tensor, shape, strides)
Returns:
torch.Tensor: Tensor with data loaded from file
"""
@staticmethod
def _load_from_file(file_path, shape_or_size, torch_dtype, torch_device_str):
"""Load tensor data from file"""
if file_path is None:
raise ValueError("FROM_FILE mode requires file_path")
......@@ -169,21 +165,15 @@ class TensorInitializer:
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
# Determine file type and load accordingly
file_extension = file_path.suffix.lower()
if file_extension in [".pt", ".pth"]:
# PyTorch native format
tensor = torch.load(file_path, map_location=torch_device_str)
elif file_extension in [".bin", ".dat", ".raw"]:
# Raw binary format - we need to know the expected shape
tensor = TensorInitializer._load_binary_file(
file_path, shape_or_size, torch_dtype, torch_device_str
)
elif file_extension in [".npy"]:
# NumPy format - fallback to numpy if needed
try:
import numpy as np
......@@ -193,125 +183,48 @@ class TensorInitializer:
)
except ImportError:
raise ImportError("NumPy is required to load .npy files")
else:
# Try to load as PyTorch format first, then fallback to binary
try:
tensor = torch.load(file_path, map_location=torch_device_str)
except:
# Fallback to binary loading
tensor = TensorInitializer._load_binary_file(
file_path, shape_or_size, torch_dtype, torch_device_str
)
# Ensure correct dtype and device
tensor = tensor.to(torch_dtype).to(torch_device_str)
# Validate shape/size
if isinstance(shape_or_size, (list, tuple)):
# Contiguous tensor - check shape
if list(tensor.shape) != list(shape_or_size):
raise ValueError(
f"Tensor shape mismatch: expected {shape_or_size}, got {tensor.shape}"
f"Shape mismatch: expected {shape_or_size}, got {tensor.shape}"
)
else:
# Strided tensor - check total size
if tensor.numel() != shape_or_size:
raise ValueError(
f"Tensor size mismatch: expected {shape_or_size} elements, got {tensor.numel()}"
f"Size mismatch: expected {shape_or_size} elements, got {tensor.numel()}"
)
return tensor
@staticmethod
def _load_binary_file(file_path, shape_or_size, torch_dtype, torch_device_str):
"""
Load tensor from raw binary file
Args:
file_path: Path to binary file
shape_or_size: Expected shape or size
torch_dtype: Target dtype
torch_device_str: Target device
Returns:
torch.Tensor: Loaded tensor
"""
# Read binary data
"""Load tensor from raw binary file"""
with open(file_path, "rb") as f:
binary_data = f.read()
# Create tensor from buffer
if isinstance(shape_or_size, (list, tuple)):
# Contiguous tensor with known shape
tensor = torch.frombuffer(binary_data, dtype=torch_dtype).reshape(
shape_or_size
)
else:
# Strided tensor - just 1D buffer
tensor = torch.frombuffer(binary_data, dtype=torch_dtype)
return tensor.to(torch_device_str)
@staticmethod
def save_to_file(tensor, file_path, format="auto"):
"""
Save tensor data to file using PyTorch's native methods
Args:
tensor: torch.Tensor to save
file_path: Path to save the file
format: File format ('auto', 'torch', 'binary', 'numpy')
"""
file_path = Path(file_path)
if format == "auto":
# Determine format from file extension
file_extension = file_path.suffix.lower()
if file_extension in [".pt", ".pth"]:
format = "torch"
elif file_extension in [".npy"]:
format = "numpy"
else:
format = "binary"
if format == "torch":
# PyTorch native format (preserves metadata)
torch.save(tensor, file_path)
elif format == "binary":
# Raw binary format
with open(file_path, "wb") as f:
f.write(tensor.cpu().numpy().tobytes())
elif format == "numpy":
# NumPy format
try:
import numpy as np
np.save(file_path, tensor.cpu().numpy())
except ImportError:
raise ImportError("NumPy is required to save .npy files")
else:
raise ValueError(f"Unsupported format: {format}")
print(
f"Tensor saved to {file_path} (shape: {tensor.shape}, dtype: {tensor.dtype}, format: {format})"
)
@staticmethod
def list_supported_formats():
"""Return list of supported file formats"""
return {
"torch": [".pt", ".pth"], # PyTorch native format
"binary": [".bin", ".dat", ".raw"], # Raw binary
"numpy": [".npy"], # NumPy format
}
class TensorSpec:
"""Tensor specification supporting various input types and per-tensor dtype"""
"""Unified tensor specification for both contiguous and non-contiguous tensors"""
def __init__(
self,
......@@ -320,43 +233,34 @@ class TensorSpec:
strides=None,
value=None,
is_scalar=False,
is_contiguous=True,
init_mode=TensorInitializer.RANDOM, # Default to random initialization
custom_tensor=None, # For manual/binary mode
file_path=None, # For FROM_FILE mode
file_format=None, # Optional file format hint
init_mode=TensorInitializer.RANDOM,
**kwargs,
):
self.shape = shape
self.dtype = dtype
self.strides = strides
self.value = value
self.is_scalar = is_scalar
self.is_contiguous = is_contiguous
self.init_mode = init_mode
self.custom_tensor = custom_tensor
self.file_path = file_path
self.file_format = file_format
self.kwargs = kwargs
@classmethod
def from_tensor(
cls,
shape,
dtype=None,
strides=None,
is_contiguous=True,
dtype=None,
init_mode=TensorInitializer.RANDOM,
custom_tensor=None,
file_path=None,
**kwargs,
):
"""Create tensor specification - unified interface for both contiguous and non-contiguous"""
return cls(
shape=shape,
dtype=dtype,
strides=strides,
is_scalar=False,
is_contiguous=is_contiguous,
init_mode=init_mode,
custom_tensor=custom_tensor,
file_path=file_path,
**kwargs,
)
@classmethod
......@@ -365,84 +269,46 @@ class TensorSpec:
@classmethod
def from_strided_tensor(
cls,
shape,
strides,
dtype=None,
init_mode=TensorInitializer.RANDOM,
custom_tensor=None,
file_path=None,
):
return cls(
shape=shape,
dtype=dtype,
strides=strides,
is_scalar=False,
is_contiguous=False,
init_mode=init_mode,
custom_tensor=custom_tensor,
file_path=file_path,
)
@classmethod
def from_file(
cls,
file_path,
shape,
dtype=None,
strides=None,
is_contiguous=True,
file_format=None,
cls, shape, strides, dtype=None, init_mode=TensorInitializer.RANDOM, **kwargs
):
"""
Create TensorSpec that loads data from file
"""Alias for from_tensor with explicit strides (for backward compatibility)"""
return cls.from_tensor(shape, strides, dtype, init_mode, **kwargs)
Args:
file_path: Path to file
shape: Tensor shape
dtype: infinicore dtype (inferred from file if None)
strides: Optional strides for strided tensors
is_contiguous: Whether tensor is contiguous
file_format: Optional file format hint
Returns:
TensorSpec: Configured for file loading
"""
return cls(
shape=shape,
def with_dtype(self, dtype):
"""Create a new TensorSpec with the specified dtype"""
return TensorSpec(
shape=self.shape,
dtype=dtype,
strides=strides,
is_scalar=False,
is_contiguous=is_contiguous,
init_mode=TensorInitializer.FROM_FILE,
file_path=file_path,
file_format=file_format,
strides=self.strides,
value=self.value,
is_scalar=self.is_scalar,
init_mode=self.init_mode,
**self.kwargs,
)
def create_torch_tensor(self, device, dtype_config, tensor_index=0):
def create_torch_tensor(self, device):
"""Create a torch tensor based on this specification"""
if self.is_scalar:
return self.value
# Determine dtype - ensure we're using infinicore dtype, not torch dtype
if self.dtype is not None:
tensor_dtype = self.dtype
elif isinstance(dtype_config, dict) and f"input_{tensor_index}" in dtype_config:
tensor_dtype = dtype_config[f"input_{tensor_index}"]
elif isinstance(dtype_config, (list, tuple)) and tensor_index < len(
dtype_config
):
tensor_dtype = dtype_config[tensor_index]
else:
tensor_dtype = dtype_config
# Create tensor using the specified initialization mode
# Create tensor using unified interface
return TensorInitializer.create_tensor(
shape=self.shape,
dtype=tensor_dtype,
dtype=self.dtype, # Use the dtype from the spec
device=device,
mode=self.init_mode,
strides=self.strides,
set_tensor=self.custom_tensor,
file_path=self.file_path,
**self.kwargs,
)
def is_tensor_input(self):
    """Return True when this spec describes a tensor input (not a scalar)."""
    return not self.is_scalar
def __str__(self):
    # Human-readable summary: either "scalar(<value>)" or
    # "tensor<shape>" with optional ", strides=..." / ", dtype=..." parts
    # (omitted when unset/falsy).
    if self.is_scalar:
        return f"scalar({self.value})"
    else:
        strides_str = f", strides={self.strides}" if self.strides else ""
        dtype_str = f", dtype={self.dtype}" if self.dtype else ""
        return f"tensor{self.shape}{strides_str}{dtype_str}"
......@@ -37,52 +37,25 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
print(f" {desc} time: {elapsed * 1000 :6f} ms")
def is_integer_dtype(dtype):
    """Check if dtype is integer type"""
    integer_dtypes = (
        infinicore.int8,
        infinicore.int16,
        infinicore.int32,
        infinicore.int64,
        infinicore.uint8,
    )
    return dtype in integer_dtypes
def is_float_dtype(dtype):
    """Check if dtype is floating point type"""
    float_dtypes = (infinicore.float16, infinicore.float32, infinicore.bfloat16)
    return dtype in float_dtypes
def debug(
actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True, dtype=None
):
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
Debug function to compare two tensors and print differences
"""
# Convert to float32 for bfloat16 comparison
if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32)
desired = desired.to(torch.float32)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose, dtype)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
# Use appropriate comparison based on dtype
if dtype and is_integer_dtype(dtype):
# For integer types, require exact equality
import numpy as np
import numpy as np
np.testing.assert_array_equal(actual.cpu(), desired.cpu())
else:
# For float types, use allclose
import numpy as np
np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)
np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)
def print_discrepancy(
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True, dtype=None
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
):
"""Print detailed tensor differences"""
if actual.shape != expected.shape:
......@@ -96,21 +69,13 @@ def print_discrepancy(
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)
# Calculate difference mask based on dtype
if dtype and is_integer_dtype(dtype):
# For integer types, exact equality required
diff_mask = actual != expected
else:
# For float types, use tolerance-based comparison
nan_mismatch = (
actual_isnan ^ expected_isnan
if equal_nan
else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)
# Calculate difference mask
nan_mismatch = (
actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected
......@@ -142,9 +107,8 @@ def print_discrepancy(
print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
if not (dtype and is_integer_dtype(dtype)):
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
)
......@@ -166,10 +130,6 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
"""
Get tolerance settings based on data type
"""
# For integer types, return zero tolerance (exact match required)
if is_integer_dtype(tensor_dtype):
return 0, 0
tolerance = tolerance_map.get(
tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
)
......@@ -202,6 +162,8 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
Args:
infini_result: infinicore tensor result
torch_reference: PyTorch tensor reference (for shape and device)
dtype: infinicore data type
device_str: torch device string
Returns:
torch.Tensor: PyTorch tensor with infinicore data
......@@ -217,70 +179,103 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
def compare_results(
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False, dtype=None
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False
):
"""
Generic function to compare infinicore result with PyTorch reference result
Supports both floating-point (with tolerance) and integer (exact) comparison
Args:
infini_result: infinicore tensor result
torch_result: PyTorch tensor reference result
atol: absolute tolerance
rtol: relative tolerance
atol: absolute tolerance (for floating-point only)
rtol: relative tolerance (for floating-point only)
debug_mode: whether to enable debug output
dtype: infinicore data type for comparison logic
Returns:
bool: True if results match within tolerance
bool: True if results match within tolerance (FP) or exactly (integer)
"""
# Convert infinicore result to PyTorch tensor for comparison
torch_result_from_infini = convert_infinicore_to_torch(infini_result, torch_result)
# Choose comparison method based on dtype
if dtype and is_integer_dtype(dtype):
# For integer types, require exact equality
result = torch.equal(torch_result_from_infini, torch_result)
else:
# For float types, use tolerance-based comparison
result = torch.allclose(
torch_result_from_infini, torch_result, atol=atol, rtol=rtol
)
# Handle scalar integer comparison
if isinstance(torch_result_from_infini, (int, float)) and isinstance(
torch_result, (int, float)
):
if isinstance(torch_result_from_infini, int) and isinstance(torch_result, int):
# Exact integer scalar comparison
result_equal = torch_result_from_infini == torch_result
if debug_mode and not result_equal:
print(
f"Integer scalar mismatch: {torch_result_from_infini} != {torch_result}"
)
return result_equal
else:
# Floating-point scalar comparison with tolerance
return abs(torch_result_from_infini - torch_result) <= atol + rtol * abs(
torch_result
)
# Debug mode: detailed comparison
if debug_mode:
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol, dtype=dtype)
return result
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
# Choose comparison method based on data type
if is_integer_dtype(torch_result_from_infini.dtype) or is_integer_dtype(
torch_result.dtype
):
# Exact equality for integer types
result_equal = torch.equal(torch_result_from_infini, torch_result)
if debug_mode and not result_equal:
print("Integer tensor comparison failed - requiring exact equality")
return result_equal
else:
# Tolerance-based comparison for floating-point types
return torch.allclose(
torch_result_from_infini, torch_result, atol=atol, rtol=rtol
)
def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
def create_test_comparator(config, atol, rtol, mode_name=""):
"""
Create a test-specific comparison function that handles test configuration
Create a test-specific comparison function
Args:
config: test configuration
dtype: infinicore data type
tolerance_map: optional tolerance map (defaults to config's tolerance_map)
atol: absolute tolerance (for floating-point only)
rtol: relative tolerance (for floating-point only)
mode_name: operation mode name for debug output
Returns:
callable: function that takes (infini_result, torch_result) and returns bool
"""
if tolerance_map is None:
tolerance_map = config.tolerance_map
atol, rtol = get_tolerance(tolerance_map, dtype)
def compare_test_results(infini_result, torch_result):
if config.debug and mode_name:
print(f"\033[94mDEBUG INFO - {mode_name}:\033[0m")
# For integer types, override tolerance to require exact equality
actual_atol = atol
actual_rtol = rtol
# Check if we're dealing with integer types
try:
# Try to get dtype from infinicore tensor
if hasattr(infini_result, "dtype"):
infini_dtype = infini_result.dtype
torch_dtype = to_torch_dtype(infini_dtype)
if is_integer_dtype(torch_dtype):
actual_atol = 0
actual_rtol = 0
except:
pass
return compare_results(
infini_result,
torch_result,
atol=atol,
rtol=rtol,
atol=actual_atol,
rtol=actual_rtol,
debug_mode=config.debug,
dtype=dtype,
)
return compare_test_results
......@@ -330,3 +325,30 @@ def rearrange_tensor(tensor, new_strides):
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
return new_tensor
def is_broadcast(strides):
    """
    Check if strides indicate a broadcasted tensor

    Args:
        strides: Tensor strides or None

    Returns:
        bool: True if the tensor is broadcasted (has zero strides)
    """
    if strides is None:
        return False
    # A zero stride means the dimension is a broadcast view that repeats
    # the same memory, so it must not be written to in-place.
    for stride in strides:
        if stride == 0:
            return True
    return False
def is_integer_dtype(dtype):
    """Check if dtype is integer type"""
    # Set membership instead of a list scan; covers bool as well, since
    # boolean tensors require exact comparison just like integers.
    return dtype in {
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
        torch.bool,
    }
......@@ -7,93 +7,138 @@ import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, shape, a_strides, b_strides, c_strides)
# Test cases format: (shape, a_strides, b_strides, c_strides)
_TEST_CASES_DATA = [
(TestCase.BOTH, (13, 4), None, None, None),
(TestCase.BOTH, (13, 4), (10, 1), (10, 1), (10, 1)),
(TestCase.BOTH, (13, 4), (0, 1), None, None),
(TestCase.BOTH, (13, 4, 4), None, None, None),
(TestCase.BOTH, (13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
(TestCase.BOTH, (13, 4, 4), (4, 0, 1), (0, 4, 1), None),
(TestCase.BOTH, (16, 5632), None, None, None),
(TestCase.BOTH, (16, 5632), (13312, 1), (13312, 1), (13312, 1)),
# Basic cases
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), None),
# Strided cases
((13, 4), None, None, (10, 1)),
((13, 4), (10, 1), (10, 1), (10, 1)),
# 3D cases
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), None),
# Broadcast cases
((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
# Large tensors
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), None),
]
def parse_test_cases(data):
"""
Parse add test case data according to format:
(operation_mode, shape, a_strides, b_strides, c_strides)
"""
operation_mode = data[0]
shape = data[1]
a_strides = data[2] if len(data) > 2 else None
b_strides = data[3] if len(data) > 3 else None
c_strides = data[4] if len(data) > 4 else None
# Create input specifications
inputs = []
# Input tensor a
if a_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, a_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Input tensor b (same shape as a)
if b_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, b_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Output tensor
if c_strides is not None:
output = TensorSpec.from_strided_tensor(shape, c_strides)
else:
output = TensorSpec.from_tensor(shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance
# Tolerance configuration
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 0, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}
# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """
    Parse test case data and return list of TestCase objects for all operation types.

    For every entry in _TEST_CASES_DATA and every dtype in _TENSOR_DTYPES this
    builds up to four TestCase objects: out-of-place, explicit out= tensor,
    and in-place on either input. In-place variants are skipped when the
    corresponding strides describe a broadcast view (zero strides), which
    cannot safely be written to.
    """
    test_cases = []

    for data in _TEST_CASES_DATA:
        # Test case format: (shape, a_strides, b_strides, c_strides);
        # trailing stride entries are optional and default to contiguous.
        shape = data[0]
        a_strides = data[1] if len(data) > 1 else None
        b_strides = data[2] if len(data) > 2 else None
        c_strides = data[3] if len(data) > 3 else None

        # Check if tensors support in-place operations
        # (broadcast views with zero strides must not be written to)
        a_supports_inplace = not is_broadcast(a_strides)
        b_supports_inplace = not is_broadcast(b_strides)
        c_supports_inplace = not is_broadcast(c_strides)

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            # Create typed tensor specs
            a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
            b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
            c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)

            # Test Case 1: Out-of-place (result returned as a new tensor)
            test_cases.append(
                TestCase(
                    inputs=[a_spec, b_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    description=f"Add - OUT_OF_PLACE",
                )
            )

            # Test Case 2: In-place with explicit output tensor (add(a, b, out=c))
            if c_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs=None,
                        output_spec=c_spec,  # Specify the output tensor spec
                        comparison_target="out",
                        tolerance=tolerance,
                        description=f"Add - INPLACE(out)",
                    )
                )

            # Test Case 3: In-place on first input (add(a, b, out=a))
            if a_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs={"out": 0},  # Use index 0 for first input
                        output_spec=None,
                        comparison_target=0,  # Compare first input
                        tolerance=tolerance,
                        description=f"Add - INPLACE(a)",
                    )
                )

            # Test Case 4: In-place on second input (add(a, b, out=b))
            if b_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs={"out": 1},  # Use index 1 for second input
                        output_spec=None,
                        comparison_target=1,  # Compare second input
                        tolerance=tolerance,
                        description=f"Add - INPLACE(b)",
                    )
                )

    return test_cases
class OpTest(BaseOperatorTest):
"""Add test with simplified test case parsing"""
"""Add operator test with simplified implementation"""
def __init__(self):
super().__init__("Add")
def get_test_cases(self):
return _TEST_CASES
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self):
return _TOLERANCE_MAP
return parse_test_cases()
def torch_operator(self, a, b, out=None, **kwargs):
return torch.add(a, b, out=out)
def torch_operator(self, *args, **kwargs):
"""PyTorch add implementation"""
return torch.add(*args, **kwargs)
def infinicore_operator(self, a, b, out=None, **kwargs):
return infinicore.add(a, b, out=out)
def infinicore_operator(self, *args, **kwargs):
"""InfiniCore add implementation"""
return infinicore.add(*args, **kwargs)
def main():
......
......@@ -11,18 +11,17 @@ import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, n_q_head, n_kv_head, seq_len, head_dim, pos,
# k_cache_buf_len, v_cache_buf_len, q_strides, k_strides, v_strides,
# k_cache_strides, v_cache_strides)
# Test cases format: (n_q_head, n_kv_head, seq_len, head_dim, pos, k_cache_buf_len, v_cache_buf_len,
# q_strides, k_strides, v_strides, k_cache_strides, v_cache_strides)
_TEST_CASES_DATA = [
# Prefill stage
(
TestCase.OUT_OF_PLACE,
32,
4,
5,
......@@ -38,7 +37,6 @@ _TEST_CASES_DATA = [
),
# Decode stage
(
TestCase.OUT_OF_PLACE,
32,
4,
1,
......@@ -53,10 +51,9 @@ _TEST_CASES_DATA = [
[64, 11264, 1],
),
# Small test case
(TestCase.OUT_OF_PLACE, 8, 4, 2, 16, 1, 8, 8, None, None, None, None, None),
(8, 4, 2, 16, 1, 8, 8, None, None, None, None, None),
# Another prefill case
(
TestCase.OUT_OF_PLACE,
28,
28,
15,
......@@ -137,124 +134,114 @@ def torch_attention(q, k, v, k_cache, v_cache, pos):
return attn_output
def parse_test_cases(data):
def parse_test_cases():
"""
Parse attention test case data according to format:
(operation_mode, n_q_head, n_kv_head, seq_len, head_dim, pos,
k_cache_buf_len, v_cache_buf_len, q_strides, k_strides, v_strides,
k_cache_strides, v_cache_strides)
(n_q_head, n_kv_head, seq_len, head_dim, pos, k_cache_buf_len, v_cache_buf_len,
q_strides, k_strides, v_strides, k_cache_strides, v_cache_strides)
"""
operation_mode = data[0]
n_q_head, n_kv_head, seq_len, head_dim, pos = (
data[1],
data[2],
data[3],
data[4],
data[5],
)
k_cache_buf_len, v_cache_buf_len = data[6], data[7]
q_strides = data[8] if len(data) > 8 else None
k_strides = data[9] if len(data) > 9 else None
v_strides = data[10] if len(data) > 10 else None
k_cache_strides = data[11] if len(data) > 11 else None
v_cache_strides = data[12] if len(data) > 12 else None
# Create input specifications
inputs = []
# Query tensor: (n_q_head, seq_len, head_dim)
if q_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor((n_q_head, seq_len, head_dim), q_strides)
test_cases = []
for data in _TEST_CASES_DATA:
n_q_head, n_kv_head, seq_len, head_dim, pos = (
data[0],
data[1],
data[2],
data[3],
data[4],
)
else:
inputs.append(TensorSpec.from_tensor((n_q_head, seq_len, head_dim)))
# Key tensor: (n_kv_head, seq_len, head_dim)
if k_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor((n_kv_head, seq_len, head_dim), k_strides)
)
else:
inputs.append(TensorSpec.from_tensor((n_kv_head, seq_len, head_dim)))
# Value tensor: (n_kv_head, seq_len, head_dim)
if v_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor((n_kv_head, seq_len, head_dim), v_strides)
)
else:
inputs.append(TensorSpec.from_tensor((n_kv_head, seq_len, head_dim)))
# Key cache: (n_kv_head, k_cache_buf_len, head_dim)
if k_cache_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor(
(n_kv_head, k_cache_buf_len, head_dim), k_cache_strides
k_cache_buf_len, v_cache_buf_len = data[5], data[6]
q_strides = data[7] if len(data) > 7 else None
k_strides = data[8] if len(data) > 8 else None
v_strides = data[9] if len(data) > 9 else None
k_cache_strides = data[10] if len(data) > 10 else None
v_cache_strides = data[11] if len(data) > 11 else None
# Check if output tensor supports in-place operations
# For attention, output shape is (seq_len, n_q_head, head_dim)
output_shape = (seq_len, n_q_head, head_dim)
output_supports_inplace = True # Output is always contiguous for attention
# Generate test cases for all data types
for dtype in [infinicore.float16, infinicore.bfloat16, infinicore.float32]:
tolerance = {
infinicore.float16: {"atol": 1e-4, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-5, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 1e-3, "rtol": 5e-2},
}.get(dtype, {"atol": 1e-5, "rtol": 1e-4})
# Create typed tensor specs
q_spec = TensorSpec.from_tensor(
(n_q_head, seq_len, head_dim), q_strides, dtype
)
)
else:
inputs.append(TensorSpec.from_tensor((n_kv_head, k_cache_buf_len, head_dim)))
# Value cache: (n_kv_head, v_cache_buf_len, head_dim)
if v_cache_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor(
(n_kv_head, v_cache_buf_len, head_dim), v_cache_strides
k_spec = TensorSpec.from_tensor(
(n_kv_head, seq_len, head_dim), k_strides, dtype
)
v_spec = TensorSpec.from_tensor(
(n_kv_head, seq_len, head_dim), v_strides, dtype
)
k_cache_spec = TensorSpec.from_tensor(
(n_kv_head, k_cache_buf_len, head_dim), k_cache_strides, dtype
)
v_cache_spec = TensorSpec.from_tensor(
(n_kv_head, v_cache_buf_len, head_dim), v_cache_strides, dtype
)
pos_spec = TensorSpec.from_scalar(pos)
output_spec = TensorSpec.from_tensor(
output_shape, None, dtype
) # Output is always contiguous
# Inputs list
inputs = [q_spec, k_spec, v_spec, k_cache_spec, v_cache_spec, pos_spec]
# Test Case 1: Out-of-place (return value)
test_cases.append(
TestCase(
inputs=inputs,
kwargs={},
output_spec=None,
comparison_target=None,
tolerance=tolerance,
description=f"Attention - OUT_OF_PLACE",
)
)
)
else:
inputs.append(TensorSpec.from_tensor((n_kv_head, v_cache_buf_len, head_dim)))
# Position (scalar)
inputs.append(TensorSpec.from_scalar(pos))
# Output tensor: (seq_len, n_q_head, head_dim)
output_shape = (seq_len, n_q_head, head_dim)
output = TensorSpec.from_tensor(output_shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Test Case 2: In-place with explicit output tensor (attention(q, k, v, k_cache, v_cache, pos, out=output))
if output_supports_inplace:
test_cases.append(
TestCase(
inputs=inputs,
kwargs=None,
output_spec=output_spec, # Specify the output tensor spec
comparison_target="out",
tolerance=tolerance,
description=f"Attention - INPLACE(out)",
)
)
# Tolerance
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-4, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-5, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 1e-3, "rtol": 5e-2},
}
return test_cases
class OpTest(BaseOperatorTest):
"""Attention test with simplified test case parsing"""
"""Attention operator test with simplified implementation"""
def __init__(self):
super().__init__("Attention")
def get_test_cases(self):
return _TEST_CASES
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self):
return _TOLERANCE_MAP
return parse_test_cases()
def torch_operator(self, q, k, v, k_cache, v_cache, pos, out=None, **kwargs):
"""PyTorch attention implementation"""
result = torch_attention(q, k, v, k_cache, v_cache, pos)
if out is not None:
out.set_(result)
out.copy_(result)
return out
else:
return result
return result
def infinicore_operator(self, q, k, v, k_cache, v_cache, pos, out=None, **kwargs):
"""InfiniCore attention implementation"""
return infinicore.attention(q, k, v, k_cache, v_cache, pos, out=out)
......
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (shape, a_strides, b_strides, c_strides)
_TEST_CASES_DATA = [
# Basic cases
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), None),
# Strided cases
((13, 4), None, None, (10, 1)),
((13, 4), (10, 1), (10, 1), (10, 1)),
# 3D cases
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), None),
# Broadcast cases
((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
# Large tensors
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), None),
]
# Tolerance configuration - exact match required for bitwise operations
_TOLERANCE_MAP = {
infinicore.int8: {"atol": 0, "rtol": 0},
infinicore.int16: {"atol": 0, "rtol": 0},
infinicore.int32: {"atol": 0, "rtol": 0},
infinicore.int64: {"atol": 0, "rtol": 0},
infinicore.uint8: {"atol": 0, "rtol": 0},
infinicore.bool: {"atol": 0, "rtol": 0},
}
# Data types to test - integer types for bitwise operations
_TENSOR_DTYPES = [
infinicore.int8,
infinicore.int16,
infinicore.int32,
infinicore.int64,
infinicore.uint8,
infinicore.bool, # XOR also supports boolean tensors
]
def parse_test_cases():
    """
    Parse test case data and return list of TestCase objects for all operation types.

    For every (shape, a_strides, b_strides, c_strides) entry and every dtype,
    emits up to four cases: out-of-place, explicit-out, and the two
    input-aliasing in-place variants. In-place variants are skipped when the
    target tensor is a broadcast (zero-stride) view, which is not writable.
    """
    test_cases = []
    for data in _TEST_CASES_DATA:
        shape = data[0]
        a_strides = data[1] if len(data) > 1 else None
        b_strides = data[2] if len(data) > 2 else None
        c_strides = data[3] if len(data) > 3 else None

        # Check if tensors support in-place operations
        a_supports_inplace = not is_broadcast(a_strides)
        b_supports_inplace = not is_broadcast(b_strides)
        c_supports_inplace = not is_broadcast(c_strides)

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0})

            # Create typed tensor specs
            a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
            b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
            c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)

            # Test Case 1: Out-of-place (return value)
            test_cases.append(
                TestCase(
                    inputs=[a_spec, b_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    # Plain strings: these had pointless f-prefixes (no placeholders).
                    description="BitwiseXor - OUT_OF_PLACE",
                )
            )
            # Test Case 2: In-place with explicit output tensor (bitwise_xor(a, b, out=c))
            if c_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs=None,
                        output_spec=c_spec,  # Specify the output tensor spec
                        comparison_target="out",
                        tolerance=tolerance,
                        description="BitwiseXor - INPLACE(out)",
                    )
                )
            # Test Case 3: In-place on first input (bitwise_xor(a, b, out=a))
            if a_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs={"out": 0},  # Use index 0 for first input
                        output_spec=None,
                        comparison_target=0,  # Compare first input
                        tolerance=tolerance,
                        description="BitwiseXor - INPLACE(a)",
                    )
                )
            # Test Case 4: In-place on second input (bitwise_xor(a, b, out=b))
            if b_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs={"out": 1},  # Use index 1 for second input
                        output_spec=None,
                        comparison_target=1,  # Compare second input
                        tolerance=tolerance,
                        description="BitwiseXor - INPLACE(b)",
                    )
                )
    return test_cases
class OpTest(BaseOperatorTest):
    """Bitwise XOR operator test with simplified implementation"""

    def __init__(self):
        super().__init__("BitwiseXor")

    def get_test_cases(self):
        # All cases are generated up front by the module-level parser.
        cases = parse_test_cases()
        return cases

    def torch_operator(self, *args, **kwargs):
        """PyTorch bitwise_xor implementation"""
        reference_op = torch.bitwise_xor
        return reference_op(*args, **kwargs)

    # def infinicore_operator(self, *args, **kwargs):
    #     """InfiniCore bitwise_xor implementation"""
    #     return infinicore.bitwise_xor(*args, **kwargs)
def main():
    """Main entry point"""
    # Build the generic runner for this operator test and let it drive
    # execution, exiting the process with the appropriate status code.
    GenericTestRunner(OpTest).run_and_exit()
# Script entry point: run the test suite when executed directly.
if __name__ == "__main__":
    main()
......@@ -7,6 +7,7 @@ import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
......@@ -16,52 +17,17 @@ from framework.runner import GenericTestRunner
# Causal softmax is a single-input function that applies causal masking before softmax
# Test cases format: (shape, input_strides, output_strides)
#
# Merge-residue fix: the list previously also contained stale old-format rows
# of the shape (TestCase.BOTH, shape, ...); the new zero-argument
# parse_test_cases() reads data[0] as the shape, so those rows would be
# misparsed. Only the new 3-tuple format is kept.
_TEST_CASES_DATA = [
    # Basic 2D causal softmax
    ((3, 3), None, None),
    ((32, 512), None, None),
    # Strided tensors
    ((32, 512), (1024, 1), (1024, 1)),
    # 3D causal softmax
    ((32, 5, 5), None, None),
    ((32, 20, 512), None, None),
    ((32, 20, 512), (20480, 512, 1), None),
    ((28, 15, 15), None, None),
]
# NOTE(review): stale pre-merge implementation. A newer zero-argument
# parse_test_cases() is defined later in this module and shadows this one,
# but the _TEST_CASES list built below still calls this version against the
# new 3-tuple _TEST_CASES_DATA format, so data[0] (the shape) would be
# misread as an operation mode. This block appears to be merge residue —
# confirm and remove.
def parse_test_cases(data):
    """
    Parse causal_softmax test case data according to format:
    (operation_mode, shape, input_strides, output_strides)
    """
    operation_mode = data[0]
    shape = data[1]
    input_strides = data[2] if len(data) > 2 else None
    output_strides = data[3] if len(data) > 3 else None
    # Create input specifications
    inputs = []
    # Tensor input
    if input_strides is not None:
        inputs.append(TensorSpec.from_strided_tensor(shape, input_strides))
    else:
        inputs.append(TensorSpec.from_tensor(shape))
    # Output tensor
    if output_strides is not None:
        output = TensorSpec.from_strided_tensor(shape, output_strides)
    else:
        output = TensorSpec.from_tensor(shape)
    return TestCase(operation_mode, inputs, output)


# Parse test cases
# NOTE(review): eagerly materializes cases via the legacy parser above — see note.
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Tolerance configuration
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-3, "rtol": 1e-2},
    # NOTE(review): the float32 entry was lost in a bad merge (a diff hunk
    # marker was left inside this dict literal, which is a syntax error).
    # Until it is restored, float32 lookups fall back to the .get() default
    # used in parse_test_cases — confirm the intended float32 tolerance.
    infinicore.bfloat16: {"atol": 5e-3, "rtol": 5e-2},
}

# Data types to test
# (merge residue fix: this list was previously defined twice)
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """
    Parse causal_softmax test case data according to format:
    (shape, input_strides, output_strides)

    For each entry and dtype, emits an out-of-place case, an explicit-out
    case, and an input-aliasing in-place case; the latter two are skipped
    when the target tensor is a broadcast (zero-stride) view, which is
    not writable.
    """
    test_cases = []
    for data in _TEST_CASES_DATA:
        shape = data[0]
        input_strides = data[1] if len(data) > 1 else None
        output_strides = data[2] if len(data) > 2 else None

        # Check if tensors support in-place operations
        input_supports_inplace = not is_broadcast(input_strides)
        output_supports_inplace = not is_broadcast(output_strides)

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            # Create typed tensor specs
            input_spec = TensorSpec.from_tensor(shape, input_strides, dtype)
            output_spec = TensorSpec.from_tensor(shape, output_strides, dtype)

            # Test Case 1: Out-of-place (return value)
            test_cases.append(
                TestCase(
                    inputs=[input_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    # Plain strings: these had pointless f-prefixes (no placeholders).
                    description="Causal Softmax - OUT_OF_PLACE",
                )
            )
            # Test Case 2: In-place with explicit output tensor (causal_softmax(input, out=output))
            if output_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[input_spec],
                        kwargs=None,
                        output_spec=output_spec,  # Specify the output tensor spec
                        comparison_target="out",
                        tolerance=tolerance,
                        description="Causal Softmax - INPLACE(out)",
                    )
                )
            # Test Case 3: In-place on first input (causal_softmax(input, out=input))
            if input_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[input_spec],
                        kwargs={"out": 0},  # Use index 0 for first input
                        output_spec=None,
                        comparison_target=0,  # Compare first input
                        tolerance=tolerance,
                        description="Causal Softmax - INPLACE(input)",
                    )
                )
    return test_cases
class OpTest(BaseOperatorTest):
"""CausalSoftmax test with simplified test case parsing"""
......@@ -77,31 +111,28 @@ class OpTest(BaseOperatorTest):
super().__init__("CausalSoftmax")
def get_test_cases(self):
return _TEST_CASES
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
return parse_test_cases()
def get_tolerance_map(self):
return _TOLERANCE_MAP
def torch_operator(self, input, out=None, **kwargs):
def torch_causal_softmax(self, input, out=None, **kwargs):
    # Causal softmax implementation: apply causal mask then softmax.
    #
    # The mask is built by taking the strictly-lower triangle
    # (tril with diagonal=-1) and flipping both trailing dims, which
    # places the disallowed region in the upper triangle aligned to the
    # bottom-right corner. For square trailing dims this is standard
    # causal masking; presumably the bottom-right alignment is intended
    # for non-square (seq_len < total_len) inputs too — TODO confirm.
    dtype = input.dtype
    # Create causal mask
    mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1])
    # Masked positions get -inf in float32 so softmax assigns them zero
    # weight; compute in float32 for stability, emit the original dtype.
    masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32))
    result = torch.nn.functional.softmax(masked, dim=-1, dtype=dtype)
    if out is not None:
        # In-place variant: copy into the caller's tensor and return it.
        out.copy_(result)
        return out
    return result
# NOTE(review): this first definition is shadowed by the *args/**kwargs
# override further down (the later definition wins at class-creation time);
# it looks like merge residue — confirm and remove.
def infinicore_operator(self, input, out=None, **kwargs):
    return infinicore.causal_softmax(input, out=out)

def torch_operator(self, *args, **kwargs):
    # Framework hook: delegates to the torch_causal_softmax helper.
    return self.torch_causal_softmax(*args, **kwargs)

def infinicore_operator(self, *args, **kwargs):
    # Framework hook: forwards directly to the InfiniCore implementation.
    return infinicore.causal_softmax(*args, **kwargs)
def main():
......
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Test cases format: (shape, input_strides, alpha)
# alpha=None means the default ELU alpha (1.0) is used; a stride of 0 in any
# dimension marks a read-only broadcast view.
_TEST_CASES_DATA = [
    # Basic ELU tests without alpha (default alpha=1.0)
    ((13, 4), None, None),
    ((13, 4), (10, 1), None),
    ((13, 4), (0, 1), None),
    # 3D tensor tests
    ((13, 4, 4), None, None),
    ((13, 4, 4), (20, 4, 1), None),
    ((13, 4, 4), (4, 0, 1), None),
    # Large tensor tests
    ((16, 5632), None, None),
    ((16, 5632), (13312, 1), None),
    # ELU with different alpha values
    ((8, 4), None, 0.5),
    ((8, 4), (10, 1), 0.5),
    ((8, 4), None, 1.5),
    ((8, 4), (10, 1), 1.5),
    ((16, 8), None, 2.0),
    ((16, 8), (20, 1), 2.0),
    ((16, 8), None, 0.3),
    ((16, 8), (20, 1), 0.3),
    ((32, 16), None, 1.0),
    ((32, 16), (40, 1), 1.0),
    ((32, 16), None, 1.8),
    ((32, 16), (40, 1), 1.8),
]

# Tolerance configuration
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-3, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-5, "rtol": 1e-4},
    infinicore.bfloat16: {"atol": 1e-2, "rtol": 5e-2},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """
    Parse ELU test case data according to format:
    (shape, input_strides, alpha)
    ELU only supports out-of-place and in-place modes via PyTorch's inplace parameter
    """
    cases = []
    for entry in _TEST_CASES_DATA:
        shape = entry[0]
        strides = entry[1] if len(entry) > 1 else None
        alpha = entry[2] if len(entry) > 2 else None

        # Broadcast (zero-stride) inputs are read-only views and cannot
        # be mutated in place.
        can_write_input = not is_broadcast(strides)

        for dtype in _TENSOR_DTYPES:
            tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4})
            spec = TensorSpec.from_tensor(shape, strides, dtype)

            # Human-readable label for this configuration.
            label_bits = ["ELU"]
            if alpha is not None:
                label_bits.append(f"alpha={alpha}")
            if strides is not None:
                label_bits.append(f"input_strides={strides}")
            label = " - ".join(label_bits)

            # Out-of-place: the result is the return value.
            oop_kwargs = {} if alpha is None else {"alpha": alpha}
            cases.append(
                TestCase(
                    inputs=[spec],
                    kwargs=oop_kwargs,
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tol,
                    description=f"{label} - OUT_OF_PLACE",
                )
            )

            # In-place: PyTorch's inplace=True mutates the input tensor.
            if can_write_input:
                ip_kwargs = {"inplace": True}
                if alpha is not None:
                    ip_kwargs["alpha"] = alpha
                cases.append(
                    TestCase(
                        inputs=[spec],
                        kwargs=ip_kwargs,
                        output_spec=None,
                        comparison_target=0,  # compare the mutated input
                        tolerance=tol,
                        description=f"{label} - INPLACE",
                    )
                )
    return cases
class OpTest(BaseOperatorTest):
    """ELU operator test with PyTorch-compatible implementation"""

    def __init__(self):
        super().__init__("ELU")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch ELU implementation"""
        return torch.nn.functional.elu(*args, **kwargs)

    def infinicore_operator(self, x, alpha=1.0, out=None, **kwargs):
        """InfiniCore ELU implementation"""
        # NOTE(review): returns None — presumably a placeholder while the
        # InfiniCore ELU kernel is unimplemented. Sibling tests comment out
        # the override entirely instead; confirm the framework treats a
        # None return as "skip comparison" rather than a failure.
        return None
def main():
    """Main entry point"""
    # Hand control to the generic runner, which executes all cases and
    # exits the process with a status reflecting the results.
    GenericTestRunner(OpTest).run_and_exit()
# Script entry point: run the test suite when executed directly.
if __name__ == "__main__":
    main()
......@@ -7,109 +7,120 @@ import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (nbatch, m, n, k, a_strides, b_strides, c_strides)
# If nbatch is None: a_shape=(m, k), b_shape=(k, n), c_shape=(m, n)
# If nbatch is provided: a_shape=(nbatch, m, k), b_shape=(nbatch, k, n), c_shape=(nbatch, m, n)
#
# Merge-residue fix: stale old-format rows (TestCase.BOTH, ...) duplicated
# each entry; the new parse_test_cases() reads data[0] as nbatch, so those
# rows would be misparsed. Only the new 7-tuple format is kept.
_TEST_CASES_DATA = [
    # Basic 2D matmul
    (None, 2, 4, 3, None, None, None),
    (None, 128, 64, 256, None, None, None),
    # Batched matmul
    (2, 4, 2048, 2048, None, None, None),
    (4, 48, 6, 64, None, None, None),
    # Strided tensors
    (None, 1, 2048, 2048, (4096, 1), (4096, 1), (4096, 1)),
    (None, 6, 2560, 2048, (2048, 1), (1, 2048), (2560, 1)),
    # Mixed cases
    (8, 16, 32, 16, None, None, None),
]
# NOTE(review): stale pre-merge implementation. A newer zero-argument
# parse_test_cases() is defined later in this module and shadows this one,
# but the _TEST_CASES list built below still calls this version against the
# new 7-tuple _TEST_CASES_DATA format, so data[0] (nbatch) would be misread
# as an operation mode. This block appears to be merge residue — confirm
# and remove.
def parse_test_cases(data):
    """
    Parse matmul test case data according to format:
    (operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
    """
    operation_mode = data[0]
    nbatch = data[1]
    m, n, k = data[2], data[3], data[4]
    a_strides = data[5] if len(data) > 5 else None
    b_strides = data[6] if len(data) > 6 else None
    c_strides = data[7] if len(data) > 7 else None
    # Determine shapes based on batch dimension
    if nbatch is None:
        a_shape = (m, k)
        b_shape = (k, n)
        c_shape = (m, n)
    else:
        a_shape = (nbatch, m, k)
        b_shape = (nbatch, k, n)
        c_shape = (nbatch, m, n)
    # Create input specifications
    inputs = []
    # Tensor a
    if a_strides is not None:
        inputs.append(TensorSpec.from_strided_tensor(a_shape, a_strides))
    else:
        inputs.append(TensorSpec.from_tensor(a_shape))
    # Tensor b
    if b_strides is not None:
        inputs.append(TensorSpec.from_strided_tensor(b_shape, b_strides))
    else:
        inputs.append(TensorSpec.from_tensor(b_shape))
    # Output tensor
    if c_strides is not None:
        output = TensorSpec.from_strided_tensor(c_shape, c_strides)
    else:
        output = TensorSpec.from_tensor(c_shape)
    return TestCase(operation_mode, inputs, output)


# Parse test cases
# NOTE(review): eagerly materializes cases via the legacy parser above — see note.
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Tolerance configuration
# (merge residue fix: _TENSOR_DTYPES and the tolerance comment were each
# defined twice; one copy of each is kept)
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 1e-2},
    infinicore.float32: {"atol": 0, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """
    Parse test case data and return list of TestCase objects for matmul operation.

    For every (nbatch, m, n, k, a_strides, b_strides, c_strides) entry and
    every dtype, emits an out-of-place case and, when the output tensor is
    writable (not a broadcast view), an explicit-out case.
    """
    test_cases = []
    for data in _TEST_CASES_DATA:
        nbatch = data[0]
        m, n, k = data[1], data[2], data[3]
        a_strides = data[4] if len(data) > 4 else None
        b_strides = data[5] if len(data) > 5 else None
        c_strides = data[6] if len(data) > 6 else None

        # Determine shapes based on batch dimension
        if nbatch is None:
            a_shape = (m, k)
            b_shape = (k, n)
            c_shape = (m, n)
        else:
            a_shape = (nbatch, m, k)
            b_shape = (nbatch, k, n)
            c_shape = (nbatch, m, n)

        # Check if tensors support in-place operations
        c_supports_inplace = not is_broadcast(c_strides)

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            # Create typed tensor specs
            a_spec = TensorSpec.from_tensor(a_shape, a_strides, dtype)
            b_spec = TensorSpec.from_tensor(b_shape, b_strides, dtype)
            c_spec = TensorSpec.from_tensor(c_shape, c_strides, dtype)

            # Test Case 1: Out-of-place (return value)
            test_cases.append(
                TestCase(
                    inputs=[a_spec, b_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    # Plain strings: these had pointless f-prefixes (no placeholders).
                    description="Matmul - OUT_OF_PLACE",
                )
            )
            # Test Case 2: In-place with explicit output tensor (matmul(a, b, out=c))
            if c_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs=None,
                        output_spec=c_spec,  # Specify the output tensor spec
                        comparison_target="out",
                        tolerance=tolerance,
                        description="Matmul - INPLACE(out)",
                    )
                )
    return test_cases
class OpTest(BaseOperatorTest):
    """Matmul operator test with simplified implementation"""

    # Merge-residue fix: the class previously carried a duplicate docstring
    # and both the pre-merge and post-merge versions of get_test_cases,
    # torch_operator and infinicore_operator interleaved. The new
    # *args/**kwargs forms are kept; the stale get_tensor_dtypes and
    # get_tolerance_map overrides are dropped with them.

    def __init__(self):
        super().__init__("Matmul")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch matmul implementation"""
        return torch.matmul(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        """InfiniCore matmul implementation"""
        return infinicore.matmul(*args, **kwargs)
def main():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment