"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "04c8f4b56e59704558755865a5f4c2f0a94e5c96"
Unverified Commit 39ec8f0e authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #558 from InfiniTensor/issue/556

parents 2e5b2342 6b8949ce
...@@ -7,3 +7,5 @@ def add(input, other, *, out=None): ...@@ -7,3 +7,5 @@ def add(input, other, *, out=None):
return Tensor(_infinicore.add(input._underlying, other._underlying)) return Tensor(_infinicore.add(input._underlying, other._underlying))
_infinicore.add_(out._underlying, input._underlying, other._underlying) _infinicore.add_(out._underlying, input._underlying, other._underlying)
return out
...@@ -24,3 +24,5 @@ def attention(q, k, v, k_cache, v_cache, pos, *, out=None): ...@@ -24,3 +24,5 @@ def attention(q, k, v, k_cache, v_cache, pos, *, out=None):
v_cache._underlying, v_cache._underlying,
pos, pos,
) )
return out
...@@ -7,3 +7,5 @@ def causal_softmax(input, *, out=None): ...@@ -7,3 +7,5 @@ def causal_softmax(input, *, out=None):
return Tensor(_infinicore.causal_softmax(input._underlying)) return Tensor(_infinicore.causal_softmax(input._underlying))
_infinicore.causal_softmax_(out._underlying, input._underlying) _infinicore.causal_softmax_(out._underlying, input._underlying)
return out
...@@ -7,3 +7,5 @@ def matmul(input, other, *, out=None): ...@@ -7,3 +7,5 @@ def matmul(input, other, *, out=None):
return Tensor(_infinicore.matmul(input._underlying, other._underlying)) return Tensor(_infinicore.matmul(input._underlying, other._underlying))
_infinicore.matmul_(out._underlying, input._underlying, other._underlying) _infinicore.matmul_(out._underlying, input._underlying, other._underlying)
return out
...@@ -7,3 +7,5 @@ def rearrange(input, other, *, out=None): ...@@ -7,3 +7,5 @@ def rearrange(input, other, *, out=None):
return Tensor(_infinicore.rearrange(input._underlying)) return Tensor(_infinicore.rearrange(input._underlying))
_infinicore.rearrange_(out._underlying, input._underlying) _infinicore.rearrange_(out._underlying, input._underlying)
return out
...@@ -11,3 +11,5 @@ def rms_norm(input, weight, epsilon=1e-5, *, out=None): ...@@ -11,3 +11,5 @@ def rms_norm(input, weight, epsilon=1e-5, *, out=None):
_infinicore.rms_norm_( _infinicore.rms_norm_(
out._underlying, input._underlying, weight._underlying, epsilon out._underlying, input._underlying, weight._underlying, epsilon
) )
return out
...@@ -7,3 +7,5 @@ def silu(input, *, out=None): ...@@ -7,3 +7,5 @@ def silu(input, *, out=None):
return Tensor(_infinicore.silu(input._underlying)) return Tensor(_infinicore.silu(input._underlying))
_infinicore.silu_(out._underlying, input._underlying) _infinicore.silu_(out._underlying, input._underlying)
return out
...@@ -7,3 +7,5 @@ def swiglu(input, other, *, out=None): ...@@ -7,3 +7,5 @@ def swiglu(input, other, *, out=None):
return Tensor(_infinicore.swiglu(input._underlying, other._underlying)) return Tensor(_infinicore.swiglu(input._underlying, other._underlying))
_infinicore.swiglu_(out._underlying, input._underlying, other._underlying) _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
return out
# [file name]: __init__.py
# [file content begin]
from .base import TestConfig, TestRunner, TestCase, BaseOperatorTest from .base import TestConfig, TestRunner, TestCase, BaseOperatorTest
from .tensor import TensorSpec, TensorInitializer from .tensor import TensorSpec, TensorInitializer
from .utils import ( from .utils import (
...@@ -16,7 +14,6 @@ from .config import get_test_devices, get_args ...@@ -16,7 +14,6 @@ from .config import get_test_devices, get_args
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .datatypes import to_torch_dtype, to_infinicore_dtype from .datatypes import to_torch_dtype, to_infinicore_dtype
from .runner import GenericTestRunner from .runner import GenericTestRunner
from .templates import BinaryOperatorTest, UnaryOperatorTest
__all__ = [ __all__ = [
"TensorSpec", "TensorSpec",
...@@ -41,6 +38,4 @@ __all__ = [ ...@@ -41,6 +38,4 @@ __all__ = [
"to_torch_dtype", "to_torch_dtype",
"to_infinicore_dtype", "to_infinicore_dtype",
"GenericTestRunner", "GenericTestRunner",
"BinaryOperatorTest",
"UnaryOperatorTest",
] ]
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
import infinicore import infinicore
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple, Union, Callable, Optional from typing import List, Dict, Any, Optional
from .datatypes import to_torch_dtype, to_infinicore_dtype from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map from .devices import InfiniDeviceNames, torch_device_map
...@@ -11,28 +11,37 @@ from .utils import ( ...@@ -11,28 +11,37 @@ from .utils import (
create_test_comparator, create_test_comparator,
infinicore_tensor_from_torch, infinicore_tensor_from_torch,
profile_operation, profile_operation,
rearrange_tensor,
synchronize_device, synchronize_device,
convert_infinicore_to_torch,
) )
class TestCase: class TestCase:
"""Test case""" """Test case with all configuration included"""
OUT_OF_PLACE = "out_of_place" def __init__(
IN_PLACE = "in_place" self,
BOTH = "both" inputs,
kwargs=None,
def __init__(self, operation_mode, inputs, output=None, **kwargs): output_spec=None,
if operation_mode not in [self.IN_PLACE, self.OUT_OF_PLACE, self.BOTH]: comparison_target=None,
raise ValueError(f"Invalid operation_mode: {operation_mode}") description="",
tolerance=None,
if operation_mode == self.IN_PLACE and output is None: ):
raise ValueError("IN_PLACE mode requires output specification") """
Initialize a test case with complete configuration
self.operation_mode = operation_mode
Args:
inputs: List of TensorSpec objects or scalars
kwargs: Additional keyword arguments for the operator
output_spec: TensorSpec for output tensor (for in-place operations)
comparison_target: Target for comparison ('out', index, or None for return value)
description: Test case description
tolerance: Tolerance settings for this test case {'atol': float, 'rtol': float}
"""
self.inputs = [] self.inputs = []
# Process inputs
for inp in inputs: for inp in inputs:
if isinstance(inp, (list, tuple)): if isinstance(inp, (list, tuple)):
self.inputs.append(TensorSpec.from_tensor(inp)) self.inputs.append(TensorSpec.from_tensor(inp))
...@@ -41,34 +50,34 @@ class TestCase: ...@@ -41,34 +50,34 @@ class TestCase:
else: else:
self.inputs.append(inp) self.inputs.append(inp)
if isinstance(output, (list, tuple)): self.kwargs = kwargs or {}
self.output = TensorSpec.from_tensor(output) self.output_spec = output_spec
else: self.comparison_target = comparison_target
self.output = output self.description = description
self.tolerance = tolerance or {"atol": 1e-5, "rtol": 1e-3}
self.kwargs = kwargs def get_tensor_input_count(self):
self.description = kwargs.pop("description", "") """Count the number of tensor inputs (excluding scalars)"""
count = 0
for inp in self.inputs:
if isinstance(inp, TensorSpec) and not inp.is_scalar:
count += 1
return count
def __str__(self): def __str__(self):
mode_str = self.operation_mode.upper()
input_strs = [] input_strs = []
for inp in self.inputs: for inp in self.inputs:
if hasattr(inp, "is_scalar") and inp.is_scalar: if hasattr(inp, "is_scalar") and inp.is_scalar:
dtype_str = f", dtype={inp.dtype}" if inp.dtype else "" dtype_str = f", dtype={inp.dtype}" if inp.dtype else ""
input_strs.append(f"scalar({inp.value}{dtype_str})") input_strs.append(f"scalar({inp.value}{dtype_str})")
elif hasattr(inp, "shape"): elif hasattr(inp, "shape"):
dtype_str = f", dtype={inp.dtype}" if inp.dtype else "" dtype_str = f", {inp.dtype}" if inp.dtype else ""
init_str = ( init_str = (
f", init={inp.init_mode}" f", init={inp.init_mode}"
if inp.init_mode != TensorInitializer.RANDOM if inp.init_mode != TensorInitializer.RANDOM
else "" else ""
) )
# Show shape and strides for non-contiguous tensors if hasattr(inp, "strides") and inp.strides:
if (
hasattr(inp, "is_contiguous")
and not inp.is_contiguous
and inp.strides
):
strides_str = f", strides={inp.strides}" strides_str = f", strides={inp.strides}"
input_strs.append( input_strs.append(
f"tensor{inp.shape}{strides_str}{dtype_str}{init_str}" f"tensor{inp.shape}{strides_str}{dtype_str}{init_str}"
...@@ -78,28 +87,38 @@ class TestCase: ...@@ -78,28 +87,38 @@ class TestCase:
else: else:
input_strs.append(str(inp)) input_strs.append(str(inp))
base_str = f"TestCase(mode={mode_str}, inputs=[{'; '.join(input_strs)}]" base_str = f"TestCase("
if self.output:
dtype_str = f", dtype={self.output.dtype}" if self.output.dtype else ""
init_str = (
f", init={self.output.init_mode}"
if self.output.init_mode != TensorInitializer.RANDOM
else ""
)
# Show shape and strides for non-contiguous output tensors
if (
hasattr(self.output, "is_contiguous")
and not self.output.is_contiguous
and self.output.strides
):
strides_str = f", strides={self.output.strides}"
base_str += f", output=tensor{self.output.shape}{strides_str}{dtype_str}{init_str}"
else:
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
if self.kwargs:
base_str += f", kwargs={self.kwargs}"
if self.description: if self.description:
base_str += f", desc='{self.description}'" base_str += f"{self.description}"
base_str += f" - inputs=[{', '.join(input_strs)}]"
if self.kwargs or self.output_spec:
kwargs_strs = []
for key, value in self.kwargs.items():
if key == "out" and isinstance(value, int):
kwargs_strs.append(f"{key}={value}")
else:
kwargs_strs.append(f"{key}={value}")
output_spec = self.output_spec
if output_spec and isinstance(output_spec, TensorSpec):
dtype_str = f", {output_spec.dtype}" if output_spec.dtype else ""
init_str = (
f", init={output_spec.init_mode}"
if output_spec.init_mode != TensorInitializer.RANDOM
else ""
)
if hasattr(output_spec, "strides") and output_spec.strides:
strides_str = f", strides={output_spec.strides}"
kwargs_strs.append(
f"out=tensor{output_spec.shape}{strides_str}{dtype_str}{init_str}"
)
else:
kwargs_strs.append(
f"out=tensor{output_spec.shape}{dtype_str}{init_str}"
)
base_str += f", kwargs={{{', '.join(kwargs_strs)}}}"
base_str += ")" base_str += ")"
return base_str return base_str
...@@ -107,23 +126,11 @@ class TestCase: ...@@ -107,23 +126,11 @@ class TestCase:
class TestConfig: class TestConfig:
"""Test configuration""" """Test configuration"""
def __init__( def __init__(self, debug=False, bench=False, num_prerun=10, num_iterations=1000):
self,
tensor_dtypes,
tolerance_map,
debug=False,
bench=False,
num_prerun=10,
num_iterations=1000,
dtype_combinations=None,
):
self.tensor_dtypes = tensor_dtypes
self.tolerance_map = tolerance_map
self.debug = debug self.debug = debug
self.bench = bench self.bench = bench
self.num_prerun = num_prerun self.num_prerun = num_prerun
self.num_iterations = num_iterations self.num_iterations = num_iterations
self.dtype_combinations = dtype_combinations
class TestRunner: class TestRunner:
...@@ -140,58 +147,21 @@ class TestRunner: ...@@ -140,58 +147,21 @@ class TestRunner:
print(f"Testing {test_type} on {InfiniDeviceNames[device]}") print(f"Testing {test_type} on {InfiniDeviceNames[device]}")
print(f"{'='*60}") print(f"{'='*60}")
tensor_dtypes = self._filter_tensor_dtypes_by_device(
device, self.config.tensor_dtypes
)
for test_case in self.test_cases: for test_case in self.test_cases:
if self.config.dtype_combinations: try:
for dtype_combo in self.config.dtype_combinations: print(f"{test_case}")
try:
# Print test case info first test_func(device, test_case, self.config)
combo_str = self._format_dtype_combo(dtype_combo) print(f"\033[92m✓\033[0m Passed")
print(f"{test_case} with {combo_str}") except Exception as e:
error_msg = f"Error: {e}"
test_func(device, test_case, dtype_combo, self.config) print(f"\033[91m✗\033[0m {error_msg}")
print(f"\033[92m✓\033[0m Passed") self.failed_tests.append(error_msg)
except Exception as e: if self.config.debug:
combo_str = self._format_dtype_combo(dtype_combo) raise
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
else:
for dtype in tensor_dtypes:
try:
# Print test case info first
print(f"{test_case} with {dtype}")
test_func(device, test_case, dtype, self.config)
print(f"\033[92m✓\033[0m Passed")
except Exception as e:
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
return len(self.failed_tests) == 0 return len(self.failed_tests) == 0
def _format_dtype_combo(self, dtype_combo):
if isinstance(dtype_combo, dict):
return f"dtypes({dtype_combo})"
elif isinstance(dtype_combo, (list, tuple)):
return f"dtypes{tuple(dtype_combo)}"
else:
return str(dtype_combo)
def _filter_tensor_dtypes_by_device(self, device, tensor_dtypes):
if device in ():
return [dt for dt in tensor_dtypes if dt != infinicore.bfloat16]
else:
return tensor_dtypes
def print_summary(self): def print_summary(self):
if self.failed_tests: if self.failed_tests:
print(f"\n\033[91m{len(self.failed_tests)} tests failed:\033[0m") print(f"\n\033[91m{len(self.failed_tests)} tests failed:\033[0m")
...@@ -209,120 +179,100 @@ class BaseOperatorTest(ABC): ...@@ -209,120 +179,100 @@ class BaseOperatorTest(ABC):
def __init__(self, operator_name): def __init__(self, operator_name):
self.operator_name = operator_name self.operator_name = operator_name
self.test_cases = self.get_test_cases() self.test_cases = self.get_test_cases()
self.tensor_dtypes = self.get_tensor_dtypes()
self.tolerance_map = self.get_tolerance_map()
self.dtype_combinations = self.get_dtype_combinations()
@abstractmethod @abstractmethod
def get_test_cases(self): def get_test_cases(self):
"""Return list of TestCase objects""" """Return list of TestCase objects with complete configuration"""
pass pass
@abstractmethod def torch_operator(self, *args, **kwargs):
def get_tensor_dtypes(self): """PyTorch operator function"""
"""Return supported data types"""
pass
@abstractmethod
def get_tolerance_map(self):
"""Return tolerance configuration"""
pass
def get_dtype_combinations(self):
"""Return dtype combinations for mixed dtype tests"""
return None
def torch_operator(self, *inputs, out=None, **kwargs):
"""Unified PyTorch operator function - can be overridden or return None"""
raise NotImplementedError("torch_operator not implemented") raise NotImplementedError("torch_operator not implemented")
def infinicore_operator(self, *inputs, out=None, **kwargs): def infinicore_operator(self, *args, **kwargs):
"""Unified InfiniCore operator function - can be overridden or return None""" """InfiniCore operator function"""
raise NotImplementedError("infinicore_operator not implemented") raise NotImplementedError("infinicore_operator not implemented")
def create_strided_tensor( def prepare_inputs_and_kwargs(self, test_case, device):
self, shape, strides, dtype, device, init_mode=TensorInitializer.RANDOM """Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors"""
):
"""Create a non-contiguous tensor with specific strides"""
spec = TensorSpec.from_strided_tensor(shape, strides, dtype, init_mode)
return spec.create_torch_tensor(device, dtype)
def prepare_inputs(self, test_case, device, dtype_config):
"""Prepare input data"""
inputs = [] inputs = []
kwargs = test_case.kwargs.copy()
# Prepare input tensors
for i, input_spec in enumerate(test_case.inputs): for i, input_spec in enumerate(test_case.inputs):
if isinstance(input_spec, TensorSpec): if isinstance(input_spec, TensorSpec):
if input_spec.is_scalar: if input_spec.is_scalar:
inputs.append(input_spec.value) inputs.append(input_spec.value)
else: else:
tensor = input_spec.create_torch_tensor(device, dtype_config, i) tensor = input_spec.create_torch_tensor(device)
inputs.append(tensor) inputs.append(tensor)
else: else:
inputs.append(input_spec) inputs.append(input_spec)
return inputs, test_case.kwargs # Prepare output tensor if specified in output_spec
if test_case.output_spec is not None:
output_tensor = test_case.output_spec.create_torch_tensor(device)
kwargs["out"] = output_tensor
def get_output_dtype(self, test_case, dtype_config, torch_result=None): # Handle integer indices for in-place operations
"""Determine output dtype - returns infinicore dtype, not torch dtype""" if "out" in kwargs and isinstance(kwargs["out"], int):
if test_case.output and test_case.output.dtype is not None: input_idx = kwargs["out"]
return test_case.output.dtype if 0 <= input_idx < len(inputs) and isinstance(
elif isinstance(dtype_config, dict) and "output" in dtype_config: inputs[input_idx], torch.Tensor
return dtype_config["output"] ):
elif torch_result is not None: kwargs["out"] = inputs[input_idx]
return to_infinicore_dtype(torch_result.dtype)
else:
if isinstance(dtype_config, (list, tuple)):
return dtype_config[0]
else: else:
return dtype_config raise ValueError(
f"Invalid input index for in-place operation: {input_idx}"
def run_test(self, device, test_case, dtype_config, config):
"""Unified test execution flow"""
device_str = torch_device_map[device]
if test_case.operation_mode == TestCase.BOTH:
out_of_place_case = TestCase(
TestCase.OUT_OF_PLACE,
test_case.inputs,
test_case.output,
**test_case.kwargs,
)
self._run_single_test(
device, out_of_place_case, dtype_config, config, "OUT_OF_PLACE"
)
if test_case.output is not None:
in_place_case = TestCase(
TestCase.IN_PLACE,
test_case.inputs,
test_case.output,
**test_case.kwargs,
) )
self._run_single_test(
device, in_place_case, dtype_config, config, "IN_PLACE"
)
return
self._run_single_test( return inputs, kwargs
device, test_case, dtype_config, config, test_case.operation_mode.upper()
)
def _run_single_test(self, device, test_case, dtype_config, config, mode_name): def run_test(self, device, test_case, config):
"""Run a single test with specified operation mode""" """Unified test execution flow"""
device_str = torch_device_map[device] device_str = torch_device_map[device]
inputs, kwargs = self.prepare_inputs(test_case, device, dtype_config) # Prepare inputs and kwargs with actual tensors
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
# For in-place operations on input tensors, we need to preserve the original state
original_inputs = []
if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor):
# This is an in-place operation on an input tensor
# Store original values for comparison
for inp in inputs:
if isinstance(inp, torch.Tensor):
original_inputs.append(inp.clone().detach())
else:
original_inputs.append(inp)
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
infini_inputs = [] infini_inputs = []
torch_input_clones = []
for inp in inputs: for inp in inputs:
if isinstance(inp, torch.Tensor): if isinstance(inp, torch.Tensor):
infini_tensor = infinicore_tensor_from_torch(inp) cloned_inp = inp.clone().detach()
torch_input_clones.append(cloned_inp)
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
infini_inputs.append(infini_tensor) infini_inputs.append(infini_tensor)
else: else:
infini_inputs.append(inp) infini_inputs.append(inp)
# Check if operators are implemented # Determine comparison target
comparison_target = test_case.comparison_target
# Handle infinicore output
infini_kwargs = kwargs.copy()
if "out" in infini_kwargs and isinstance(infini_kwargs["out"], torch.Tensor):
if isinstance(comparison_target, int):
infini_kwargs["out"] = infini_inputs[comparison_target]
else:
cloned_out = infini_kwargs["out"].clone().detach()
torch_input_clones.append(cloned_out)
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
# Check operator implementations
torch_implemented = True torch_implemented = True
infini_implemented = True infini_implemented = True
...@@ -335,19 +285,19 @@ class BaseOperatorTest(ABC): ...@@ -335,19 +285,19 @@ class BaseOperatorTest(ABC):
torch_result = None torch_result = None
try: try:
infini_result = self.infinicore_operator(*infini_inputs, **kwargs) infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
if infini_result is None: if infini_result is None:
infini_implemented = False infini_implemented = False
except NotImplementedError: except NotImplementedError:
infini_implemented = False infini_implemented = False
infini_result = None infini_result = None
# If neither operator is implemented, skip the test # Skip if neither operator is implemented
if not torch_implemented and not infini_implemented: if not torch_implemented and not infini_implemented:
print(f"⚠ Both operators not implemented - test skipped") print(f"⚠ Both operators not implemented - test skipped")
return return
# If only one operator is implemented, run it without comparison # Single operator execution without comparison
if not torch_implemented or not infini_implemented: if not torch_implemented or not infini_implemented:
missing_op = ( missing_op = (
"torch_operator" if not torch_implemented else "infinicore_operator" "torch_operator" if not torch_implemented else "infinicore_operator"
...@@ -356,14 +306,12 @@ class BaseOperatorTest(ABC): ...@@ -356,14 +306,12 @@ class BaseOperatorTest(ABC):
f"⚠ {missing_op} not implemented - running single operator without comparison" f"⚠ {missing_op} not implemented - running single operator without comparison"
) )
# Run the available operator for benchmarking if requested
if config.bench: if config.bench:
if torch_implemented: if torch_implemented:
def torch_op(): def torch_op():
return self.torch_operator(*inputs, **kwargs) return self.torch_operator(*inputs, **kwargs)
print(f" {mode_name}:")
profile_operation( profile_operation(
"PyTorch ", "PyTorch ",
torch_op, torch_op,
...@@ -374,9 +322,8 @@ class BaseOperatorTest(ABC): ...@@ -374,9 +322,8 @@ class BaseOperatorTest(ABC):
if infini_implemented: if infini_implemented:
def infini_op(): def infini_op():
return self.infinicore_operator(*infini_inputs, **kwargs) return self.infinicore_operator(*infini_inputs, **infini_kwargs)
print(f" {mode_name}:")
profile_operation( profile_operation(
"InfiniCore", "InfiniCore",
infini_op, infini_op,
...@@ -386,126 +333,79 @@ class BaseOperatorTest(ABC): ...@@ -386,126 +333,79 @@ class BaseOperatorTest(ABC):
) )
return return
# Both operators are implemented - proceed with normal comparison if comparison_target is None:
if test_case.operation_mode == TestCase.OUT_OF_PLACE: # Compare return values (out-of-place)
torch_comparison = torch_result
def torch_op(): infini_comparison = infini_result
return self.torch_operator(*inputs, **kwargs) elif comparison_target == "out":
# Compare output tensor from kwargs (explicit output)
torch_result = torch_op() torch_comparison = kwargs.get("out")
infini_comparison = infini_kwargs.get("out")
if ( elif isinstance(comparison_target, int):
isinstance(torch_result, torch.Tensor) # Compare specific input tensor (in-place operation on input)
and not torch_result.is_contiguous() # For in-place operations, we compare the modified input tensor
): if 0 <= comparison_target < len(inputs):
torch_result = torch_result.contiguous() torch_comparison = inputs[comparison_target]
infini_comparison = infini_inputs[comparison_target]
def infini_op(): else:
return self.infinicore_operator(*infini_inputs, **kwargs) raise ValueError(
f"Invalid comparison target index: {comparison_target}"
infini_result = infini_op()
# Get comparison dtype (infinicore dtype)
comparison_dtype = self.get_output_dtype(
test_case, dtype_config, torch_result
)
compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{mode_name}"
)
is_valid = compare_fn(infini_result, torch_result)
assert is_valid, f"{mode_name} result comparison failed"
if config.bench:
print(f" {mode_name}:")
profile_operation(
"PyTorch ",
torch_op,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
"InfiniCore",
infini_op,
device_str,
config.num_prerun,
config.num_iterations,
) )
else: else:
if not test_case.output: raise ValueError(f"Invalid comparison target: {comparison_target}")
raise ValueError("IN_PLACE test requires output specification")
# Get output dtype and create output tensor # Validate comparison targets
output_dtype = self.get_output_dtype(test_case, dtype_config) if torch_comparison is None or infini_comparison is None:
output_shape = test_case.output.shape raise ValueError("Comparison targets cannot be None")
# Use TensorSpec to create output tensor with specified initialization mode # Perform comparison
if test_case.output.is_contiguous or test_case.output.strides is None: atol = test_case.tolerance.get("atol", 1e-5)
output_spec = TensorSpec.from_tensor( rtol = test_case.tolerance.get("rtol", 1e-3)
output_shape, output_dtype, init_mode=test_case.output.init_mode
)
else:
output_spec = TensorSpec.from_strided_tensor(
output_shape,
test_case.output.strides,
output_dtype,
init_mode=test_case.output.init_mode,
)
torch_output = output_spec.create_torch_tensor(device, output_dtype) compare_fn = create_test_comparator(config, atol, rtol, test_case.description)
# For non-contiguous tensors, we need to ensure zeros initialization is_valid = compare_fn(infini_comparison, torch_comparison)
if ( assert is_valid, f"Result comparison failed for {test_case}"
not test_case.output.is_contiguous
and test_case.output.strides is not None
):
torch_output.zero_()
def torch_op_inplace(): # Benchmarking
self.torch_operator(*inputs, out=torch_output, **kwargs) if config.bench:
if comparison_target is None:
# Out-of-place benchmarking
def torch_op():
return self.torch_operator(*inputs, **kwargs)
torch_op_inplace() def infini_op():
return self.infinicore_operator(*infini_inputs, **infini_kwargs)
# Create infinicore output tensor else:
torch_dummy = torch.zeros( # In-place benchmarking
output_shape, dtype=to_torch_dtype(output_dtype), device=device_str def torch_op():
) self.torch_operator(*inputs, **kwargs)
if ( return (
not test_case.output.is_contiguous kwargs.get("out")
and not test_case.output.strides is None if "out" in kwargs
): else inputs[comparison_target]
rearrange_tensor(torch_dummy, list(torch_output.stride())) )
infini_output = infinicore_tensor_from_torch(torch_dummy)
def infini_op_inplace():
self.infinicore_operator(*infini_inputs, out=infini_output, **kwargs)
infini_op_inplace() def infini_op():
self.infinicore_operator(*infini_inputs, **infini_kwargs)
return (
infini_kwargs.get("out")
if "out" in infini_kwargs
else infini_inputs[comparison_target]
)
comparison_dtype = self.get_output_dtype( profile_operation(
test_case, dtype_config, torch_output "PyTorch ",
torch_op,
device_str,
config.num_prerun,
config.num_iterations,
) )
compare_fn = create_test_comparator( profile_operation(
config, comparison_dtype, mode_name=f"{mode_name}" "InfiniCore",
infini_op,
device_str,
config.num_prerun,
config.num_iterations,
) )
is_valid = compare_fn(infini_output, torch_output)
assert is_valid, f"{mode_name} result comparison failed"
if config.bench:
print(f" {mode_name}:")
profile_operation(
"PyTorch ",
torch_op_inplace,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
"InfiniCore",
infini_op_inplace,
device_str,
config.num_prerun,
config.num_iterations,
)
...@@ -20,6 +20,8 @@ def to_torch_dtype(infini_dtype): ...@@ -20,6 +20,8 @@ def to_torch_dtype(infini_dtype):
return torch.int64 return torch.int64
elif infini_dtype == infinicore.uint8: elif infini_dtype == infinicore.uint8:
return torch.uint8 return torch.uint8
elif infini_dtype == infinicore.bool:
return torch.bool
else: else:
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}") raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")
...@@ -42,5 +44,7 @@ def to_infinicore_dtype(torch_dtype): ...@@ -42,5 +44,7 @@ def to_infinicore_dtype(torch_dtype):
return infinicore.int64 return infinicore.int64
elif torch_dtype == torch.uint8: elif torch_dtype == torch.uint8:
return infinicore.uint8 return infinicore.uint8
elif torch_dtype == torch.bool:
return infinicore.bool
else: else:
raise ValueError(f"Unsupported torch dtype: {torch_dtype}") raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
...@@ -20,13 +20,10 @@ class GenericTestRunner: ...@@ -20,13 +20,10 @@ class GenericTestRunner:
def run(self): def run(self):
"""Execute the complete test suite""" """Execute the complete test suite"""
config = TestConfig( config = TestConfig(
tensor_dtypes=self.operator_test.tensor_dtypes,
tolerance_map=self.operator_test.tolerance_map,
debug=self.args.debug, debug=self.args.debug,
bench=self.args.bench, bench=self.args.bench,
num_prerun=self.args.num_prerun, num_prerun=self.args.num_prerun,
num_iterations=self.args.num_iterations, num_iterations=self.args.num_iterations,
dtype_combinations=self.operator_test.dtype_combinations,
) )
runner = TestRunner(self.operator_test.test_cases, config) runner = TestRunner(self.operator_test.test_cases, config)
......
import torch import torch
import infinicore
from pathlib import Path from pathlib import Path
from .datatypes import to_torch_dtype from .datatypes import to_torch_dtype
from .devices import torch_device_map from .devices import torch_device_map
...@@ -18,150 +17,147 @@ class TensorInitializer: ...@@ -18,150 +17,147 @@ class TensorInitializer:
FROM_FILE = "from_file" FROM_FILE = "from_file"
@staticmethod @staticmethod
def create_tensor( def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, **kwargs):
shape, dtype, device, mode=RANDOM, strides=None, set_tensor=None, file_path=None
):
""" """
Create a torch tensor with specified initialization mode Unified tensor creation interface for both contiguous and non-contiguous tensors
Args: Args:
shape: Tensor shape shape: Tensor shape
dtype: infinicore dtype dtype: infinicore dtype
device: InfiniDeviceEnum device: InfiniDeviceEnum
mode: Initialization mode mode: Initialization mode
strides: Optional strides for strided tensors strides: Optional strides for non-contiguous tensors
set_tensor: Pre-existing tensor for manual/binary mode **kwargs: Additional arguments for specific modes
file_path: Path to file for FROM_FILE mode
Returns: Returns:
torch.Tensor: Initialized tensor torch.Tensor: Initialized tensor
""" """
# Convert InfiniDeviceEnum to torch device string
torch_device_str = torch_device_map[device] torch_device_str = torch_device_map[device]
torch_dtype = to_torch_dtype(dtype) torch_dtype = to_torch_dtype(dtype)
# Handle integer types differently for random initialization # Handle non-contiguous tensors
if mode == TensorInitializer.RANDOM and is_integer_dtype(dtype):
mode = TensorInitializer.RANDINT # Use randint for integer types
# Handle strided tensors - calculate required storage size
if strides is not None: if strides is not None:
# Calculate the required storage size for strided tensor return TensorInitializer._create_strided_tensor(
storage_size = 0 shape, strides, torch_dtype, torch_device_str, mode, **kwargs
for i in range(len(shape)): )
if shape[i] > 0: else:
storage_size += (shape[i] - 1) * abs(strides[i]) return TensorInitializer._create_contiguous_tensor(
storage_size += 1 # Add 1 for the base element shape, torch_dtype, torch_device_str, mode, **kwargs
)
# Create base storage with sufficient size
if mode == TensorInitializer.RANDOM: @staticmethod
base_tensor = torch.rand( def _create_contiguous_tensor(shape, torch_dtype, torch_device_str, mode, **kwargs):
storage_size, dtype=torch_dtype, device=torch_device_str """Create contiguous tensor"""
) if is_integer_dtype(torch_dtype):
elif mode == TensorInitializer.ZEROS: return TensorInitializer._create_integer_tensor(
base_tensor = torch.zeros( shape, torch_dtype, torch_device_str, mode, **kwargs
storage_size, dtype=torch_dtype, device=torch_device_str )
)
elif mode == TensorInitializer.ONES: if mode == TensorInitializer.RANDOM:
base_tensor = torch.ones( return torch.rand(shape, dtype=torch_dtype, device=torch_device_str)
storage_size, dtype=torch_dtype, device=torch_device_str elif mode == TensorInitializer.ZEROS:
) return torch.zeros(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.RANDINT: elif mode == TensorInitializer.ONES:
# For integer types, use appropriate range return torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
if is_integer_dtype(dtype): elif mode == TensorInitializer.RANDINT:
if dtype == infinicore.uint8: low = kwargs.get("low", -2000000000)
low, high = 0, 256 high = kwargs.get("high", 2000000000)
elif dtype == infinicore.int8: return torch.randint(
low, high = -128, 128 low, high, shape, dtype=torch_dtype, device=torch_device_str
elif dtype == infinicore.int16: )
low, high = -32768, 32768 elif mode == TensorInitializer.MANUAL:
else: # int32, int64, uint32 tensor = kwargs.get("set_tensor")
low, high = -1000, 1000 if tensor is None:
else: raise ValueError("Manual mode requires set_tensor")
low, high = -1000, 1000 if list(tensor.shape) != list(shape):
raise ValueError(
base_tensor = torch.randint( f"Shape mismatch: expected {shape}, got {tensor.shape}"
low,
high,
(storage_size,),
dtype=torch_dtype,
device=torch_device_str,
) )
elif mode == TensorInitializer.MANUAL: return tensor.to(torch_dtype).to(torch_device_str)
assert set_tensor is not None, "Manual mode requires set_tensor" elif mode == TensorInitializer.BINARY:
base_tensor = set_tensor.to(torch_dtype).to(torch_device_str) tensor = kwargs.get("set_tensor")
elif mode == TensorInitializer.BINARY: if tensor is None:
assert set_tensor is not None, "Binary mode requires set_tensor" raise ValueError("Binary mode requires set_tensor")
base_tensor = set_tensor.to(torch_dtype).to(torch_device_str) return tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.FROM_FILE: elif mode == TensorInitializer.FROM_FILE:
base_tensor = TensorInitializer._load_from_file( return TensorInitializer._load_from_file(
file_path, storage_size, torch_dtype, torch_device_str kwargs.get("file_path"), shape, torch_dtype, torch_device_str
)
else:
raise ValueError(f"Unsupported initialization mode: {mode}")
@staticmethod
def _create_integer_tensor(shape, torch_dtype, torch_device_str, mode, **kwargs):
if mode == TensorInitializer.RANDOM:
if torch_dtype == torch.bool:
return torch.randint(
0, 2, shape, dtype=torch_dtype, device=torch_device_str
).bool()
elif torch_dtype == torch.uint8:
return torch.randint(
0, 256, shape, dtype=torch_dtype, device=torch_device_str
) )
else: else:
raise ValueError(f"Unsupported initialization mode: {mode}") dtype_info = torch.iinfo(torch_dtype)
return torch.randint(
# Create strided view dtype_info.min,
tensor = torch.as_strided(base_tensor, shape, strides) dtype_info.max,
else:
# Contiguous tensor
if mode == TensorInitializer.RANDOM:
tensor = torch.rand(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ZEROS:
tensor = torch.zeros(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ONES:
tensor = torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.RANDINT:
# For integer types, use appropriate range
if is_integer_dtype(dtype):
if dtype == infinicore.uint8:
low, high = 0, 256
elif dtype == infinicore.int8:
low, high = -128, 128
elif dtype == infinicore.int16:
low, high = -32768, 32768
else: # int32, int64, uint32
low, high = -1000, 1000
else:
low, high = -1000, 1000
tensor = torch.randint(
low,
high,
shape, shape,
dtype=torch_dtype, dtype=torch_dtype,
device=torch_device_str, device=torch_device_str,
) )
elif mode == TensorInitializer.MANUAL: elif mode == TensorInitializer.ZEROS:
assert set_tensor is not None, "Manual mode requires set_tensor" return torch.zeros(shape, dtype=torch_dtype, device=torch_device_str)
assert shape == list(set_tensor.shape), "Shape mismatch in manual mode" elif mode == TensorInitializer.ONES:
tensor = set_tensor.to(torch_dtype).to(torch_device_str) return torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.BINARY: elif mode == TensorInitializer.RANDINT:
assert set_tensor is not None, "Binary mode requires set_tensor" low = kwargs.get("low", -100)
assert shape == list(set_tensor.shape), "Shape mismatch in binary mode" high = kwargs.get("high", 100)
tensor = set_tensor.to(torch_dtype).to(torch_device_str) return torch.randint(
elif mode == TensorInitializer.FROM_FILE: low, high, shape, dtype=torch_dtype, device=torch_device_str
tensor = TensorInitializer._load_from_file( )
file_path, shape, torch_dtype, torch_device_str elif mode == TensorInitializer.MANUAL:
tensor = kwargs.get("set_tensor")
if tensor is None:
raise ValueError("Manual mode requires set_tensor")
if list(tensor.shape) != list(shape):
raise ValueError(
f"Shape mismatch: expected {shape}, got {tensor.shape}"
) )
else: return tensor.to(torch_dtype).to(torch_device_str)
raise ValueError(f"Unsupported initialization mode: {mode}") elif mode == TensorInitializer.BINARY:
tensor = kwargs.get("set_tensor")
return tensor if tensor is None:
raise ValueError("Binary mode requires set_tensor")
return tensor.to(torch_dtype).to(torch_device_str)
else:
return torch.randint(
0, 100, shape, dtype=torch_dtype, device=torch_device_str
)
@staticmethod @staticmethod
def _load_from_file(file_path, shape_or_size, torch_dtype, torch_device_str): def _create_strided_tensor(
""" shape, strides, torch_dtype, torch_device_str, mode, **kwargs
Load tensor data from file using PyTorch's native methods ):
"""Create non-contiguous tensor with specific strides"""
# Calculate required storage size
storage_size = 0
for i in range(len(shape)):
if shape[i] > 0:
storage_size += (shape[i] - 1) * abs(strides[i])
storage_size += 1
# Create base storage
base_tensor = TensorInitializer._create_contiguous_tensor(
(storage_size,), torch_dtype, torch_device_str, mode, **kwargs
)
Args: # Create strided view
file_path: Path to the file return torch.as_strided(base_tensor, shape, strides)
shape_or_size: Tensor shape for contiguous or size for strided
torch_dtype: Target torch dtype
torch_device_str: Target device string
Returns: @staticmethod
torch.Tensor: Tensor with data loaded from file def _load_from_file(file_path, shape_or_size, torch_dtype, torch_device_str):
""" """Load tensor data from file"""
if file_path is None: if file_path is None:
raise ValueError("FROM_FILE mode requires file_path") raise ValueError("FROM_FILE mode requires file_path")
...@@ -169,21 +165,15 @@ class TensorInitializer: ...@@ -169,21 +165,15 @@ class TensorInitializer:
if not file_path.exists(): if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}") raise FileNotFoundError(f"File not found: {file_path}")
# Determine file type and load accordingly
file_extension = file_path.suffix.lower() file_extension = file_path.suffix.lower()
if file_extension in [".pt", ".pth"]: if file_extension in [".pt", ".pth"]:
# PyTorch native format
tensor = torch.load(file_path, map_location=torch_device_str) tensor = torch.load(file_path, map_location=torch_device_str)
elif file_extension in [".bin", ".dat", ".raw"]: elif file_extension in [".bin", ".dat", ".raw"]:
# Raw binary format - we need to know the expected shape
tensor = TensorInitializer._load_binary_file( tensor = TensorInitializer._load_binary_file(
file_path, shape_or_size, torch_dtype, torch_device_str file_path, shape_or_size, torch_dtype, torch_device_str
) )
elif file_extension in [".npy"]: elif file_extension in [".npy"]:
# NumPy format - fallback to numpy if needed
try: try:
import numpy as np import numpy as np
...@@ -193,125 +183,48 @@ class TensorInitializer: ...@@ -193,125 +183,48 @@ class TensorInitializer:
) )
except ImportError: except ImportError:
raise ImportError("NumPy is required to load .npy files") raise ImportError("NumPy is required to load .npy files")
else: else:
# Try to load as PyTorch format first, then fallback to binary
try: try:
tensor = torch.load(file_path, map_location=torch_device_str) tensor = torch.load(file_path, map_location=torch_device_str)
except: except:
# Fallback to binary loading
tensor = TensorInitializer._load_binary_file( tensor = TensorInitializer._load_binary_file(
file_path, shape_or_size, torch_dtype, torch_device_str file_path, shape_or_size, torch_dtype, torch_device_str
) )
# Ensure correct dtype and device
tensor = tensor.to(torch_dtype).to(torch_device_str) tensor = tensor.to(torch_dtype).to(torch_device_str)
# Validate shape/size # Validate shape/size
if isinstance(shape_or_size, (list, tuple)): if isinstance(shape_or_size, (list, tuple)):
# Contiguous tensor - check shape
if list(tensor.shape) != list(shape_or_size): if list(tensor.shape) != list(shape_or_size):
raise ValueError( raise ValueError(
f"Tensor shape mismatch: expected {shape_or_size}, got {tensor.shape}" f"Shape mismatch: expected {shape_or_size}, got {tensor.shape}"
) )
else: else:
# Strided tensor - check total size
if tensor.numel() != shape_or_size: if tensor.numel() != shape_or_size:
raise ValueError( raise ValueError(
f"Tensor size mismatch: expected {shape_or_size} elements, got {tensor.numel()}" f"Size mismatch: expected {shape_or_size} elements, got {tensor.numel()}"
) )
return tensor return tensor
@staticmethod @staticmethod
def _load_binary_file(file_path, shape_or_size, torch_dtype, torch_device_str): def _load_binary_file(file_path, shape_or_size, torch_dtype, torch_device_str):
""" """Load tensor from raw binary file"""
Load tensor from raw binary file
Args:
file_path: Path to binary file
shape_or_size: Expected shape or size
torch_dtype: Target dtype
torch_device_str: Target device
Returns:
torch.Tensor: Loaded tensor
"""
# Read binary data
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
binary_data = f.read() binary_data = f.read()
# Create tensor from buffer
if isinstance(shape_or_size, (list, tuple)): if isinstance(shape_or_size, (list, tuple)):
# Contiguous tensor with known shape
tensor = torch.frombuffer(binary_data, dtype=torch_dtype).reshape( tensor = torch.frombuffer(binary_data, dtype=torch_dtype).reshape(
shape_or_size shape_or_size
) )
else: else:
# Strided tensor - just 1D buffer
tensor = torch.frombuffer(binary_data, dtype=torch_dtype) tensor = torch.frombuffer(binary_data, dtype=torch_dtype)
return tensor.to(torch_device_str) return tensor.to(torch_device_str)
@staticmethod
def save_to_file(tensor, file_path, format="auto"):
"""
Save tensor data to file using PyTorch's native methods
Args:
tensor: torch.Tensor to save
file_path: Path to save the file
format: File format ('auto', 'torch', 'binary', 'numpy')
"""
file_path = Path(file_path)
if format == "auto":
# Determine format from file extension
file_extension = file_path.suffix.lower()
if file_extension in [".pt", ".pth"]:
format = "torch"
elif file_extension in [".npy"]:
format = "numpy"
else:
format = "binary"
if format == "torch":
# PyTorch native format (preserves metadata)
torch.save(tensor, file_path)
elif format == "binary":
# Raw binary format
with open(file_path, "wb") as f:
f.write(tensor.cpu().numpy().tobytes())
elif format == "numpy":
# NumPy format
try:
import numpy as np
np.save(file_path, tensor.cpu().numpy())
except ImportError:
raise ImportError("NumPy is required to save .npy files")
else:
raise ValueError(f"Unsupported format: {format}")
print(
f"Tensor saved to {file_path} (shape: {tensor.shape}, dtype: {tensor.dtype}, format: {format})"
)
@staticmethod
def list_supported_formats():
"""Return list of supported file formats"""
return {
"torch": [".pt", ".pth"], # PyTorch native format
"binary": [".bin", ".dat", ".raw"], # Raw binary
"numpy": [".npy"], # NumPy format
}
class TensorSpec: class TensorSpec:
"""Tensor specification supporting various input types and per-tensor dtype""" """Unified tensor specification for both contiguous and non-contiguous tensors"""
def __init__( def __init__(
self, self,
...@@ -320,43 +233,34 @@ class TensorSpec: ...@@ -320,43 +233,34 @@ class TensorSpec:
strides=None, strides=None,
value=None, value=None,
is_scalar=False, is_scalar=False,
is_contiguous=True, init_mode=TensorInitializer.RANDOM,
init_mode=TensorInitializer.RANDOM, # Default to random initialization **kwargs,
custom_tensor=None, # For manual/binary mode
file_path=None, # For FROM_FILE mode
file_format=None, # Optional file format hint
): ):
self.shape = shape self.shape = shape
self.dtype = dtype self.dtype = dtype
self.strides = strides self.strides = strides
self.value = value self.value = value
self.is_scalar = is_scalar self.is_scalar = is_scalar
self.is_contiguous = is_contiguous
self.init_mode = init_mode self.init_mode = init_mode
self.custom_tensor = custom_tensor self.kwargs = kwargs
self.file_path = file_path
self.file_format = file_format
@classmethod @classmethod
def from_tensor( def from_tensor(
cls, cls,
shape, shape,
dtype=None,
strides=None, strides=None,
is_contiguous=True, dtype=None,
init_mode=TensorInitializer.RANDOM, init_mode=TensorInitializer.RANDOM,
custom_tensor=None, **kwargs,
file_path=None,
): ):
"""Create tensor specification - unified interface for both contiguous and non-contiguous"""
return cls( return cls(
shape=shape, shape=shape,
dtype=dtype, dtype=dtype,
strides=strides, strides=strides,
is_scalar=False, is_scalar=False,
is_contiguous=is_contiguous,
init_mode=init_mode, init_mode=init_mode,
custom_tensor=custom_tensor, **kwargs,
file_path=file_path,
) )
@classmethod @classmethod
...@@ -365,84 +269,46 @@ class TensorSpec: ...@@ -365,84 +269,46 @@ class TensorSpec:
@classmethod @classmethod
def from_strided_tensor( def from_strided_tensor(
cls, cls, shape, strides, dtype=None, init_mode=TensorInitializer.RANDOM, **kwargs
shape,
strides,
dtype=None,
init_mode=TensorInitializer.RANDOM,
custom_tensor=None,
file_path=None,
):
return cls(
shape=shape,
dtype=dtype,
strides=strides,
is_scalar=False,
is_contiguous=False,
init_mode=init_mode,
custom_tensor=custom_tensor,
file_path=file_path,
)
@classmethod
def from_file(
cls,
file_path,
shape,
dtype=None,
strides=None,
is_contiguous=True,
file_format=None,
): ):
""" """Alias for from_tensor with explicit strides (for backward compatibility)"""
Create TensorSpec that loads data from file return cls.from_tensor(shape, strides, dtype, init_mode, **kwargs)
Args: def with_dtype(self, dtype):
file_path: Path to file """Create a new TensorSpec with the specified dtype"""
shape: Tensor shape return TensorSpec(
dtype: infinicore dtype (inferred from file if None) shape=self.shape,
strides: Optional strides for strided tensors
is_contiguous: Whether tensor is contiguous
file_format: Optional file format hint
Returns:
TensorSpec: Configured for file loading
"""
return cls(
shape=shape,
dtype=dtype, dtype=dtype,
strides=strides, strides=self.strides,
is_scalar=False, value=self.value,
is_contiguous=is_contiguous, is_scalar=self.is_scalar,
init_mode=TensorInitializer.FROM_FILE, init_mode=self.init_mode,
file_path=file_path, **self.kwargs,
file_format=file_format,
) )
def create_torch_tensor(self, device, dtype_config, tensor_index=0): def create_torch_tensor(self, device):
"""Create a torch tensor based on this specification""" """Create a torch tensor based on this specification"""
if self.is_scalar: if self.is_scalar:
return self.value return self.value
# Determine dtype - ensure we're using infinicore dtype, not torch dtype # Create tensor using unified interface
if self.dtype is not None:
tensor_dtype = self.dtype
elif isinstance(dtype_config, dict) and f"input_{tensor_index}" in dtype_config:
tensor_dtype = dtype_config[f"input_{tensor_index}"]
elif isinstance(dtype_config, (list, tuple)) and tensor_index < len(
dtype_config
):
tensor_dtype = dtype_config[tensor_index]
else:
tensor_dtype = dtype_config
# Create tensor using the specified initialization mode
return TensorInitializer.create_tensor( return TensorInitializer.create_tensor(
shape=self.shape, shape=self.shape,
dtype=tensor_dtype, dtype=self.dtype, # Use the dtype from the spec
device=device, device=device,
mode=self.init_mode, mode=self.init_mode,
strides=self.strides, strides=self.strides,
set_tensor=self.custom_tensor, **self.kwargs,
file_path=self.file_path,
) )
def is_tensor_input(self):
"""Check if this spec represents a tensor input (not scalar)"""
return not self.is_scalar
def __str__(self):
if self.is_scalar:
return f"scalar({self.value})"
else:
strides_str = f", strides={self.strides}" if self.strides else ""
dtype_str = f", dtype={self.dtype}" if self.dtype else ""
return f"tensor{self.shape}{strides_str}{dtype_str}"
...@@ -37,52 +37,25 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations): ...@@ -37,52 +37,25 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
print(f" {desc} time: {elapsed * 1000 :6f} ms") print(f" {desc} time: {elapsed * 1000 :6f} ms")
def is_integer_dtype(dtype): def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""Check if dtype is integer type"""
return dtype in [
infinicore.int8,
infinicore.int16,
infinicore.int32,
infinicore.int64,
infinicore.uint8,
]
def is_float_dtype(dtype):
"""Check if dtype is floating point type"""
return dtype in [infinicore.float16, infinicore.float32, infinicore.bfloat16]
def debug(
actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True, dtype=None
):
""" """
Debug function to compare two tensors and print differences Debug function to compare two tensors and print differences
""" """
# Convert to float32 for bfloat16 comparison
if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16: if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32) actual = actual.to(torch.float32)
desired = desired.to(torch.float32) desired = desired.to(torch.float32)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose, dtype) print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
# Use appropriate comparison based on dtype import numpy as np
if dtype and is_integer_dtype(dtype):
# For integer types, require exact equality
import numpy as np
np.testing.assert_array_equal(actual.cpu(), desired.cpu()) np.testing.assert_allclose(
else: actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
# For float types, use allclose )
import numpy as np
np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)
def print_discrepancy( def print_discrepancy(
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True, dtype=None actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
): ):
"""Print detailed tensor differences""" """Print detailed tensor differences"""
if actual.shape != expected.shape: if actual.shape != expected.shape:
...@@ -96,21 +69,13 @@ def print_discrepancy( ...@@ -96,21 +69,13 @@ def print_discrepancy(
actual_isnan = torch.isnan(actual) actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected) expected_isnan = torch.isnan(expected)
# Calculate difference mask based on dtype # Calculate difference mask
if dtype and is_integer_dtype(dtype): nan_mismatch = (
# For integer types, exact equality required actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
diff_mask = actual != expected )
else: diff_mask = nan_mismatch | (
# For float types, use tolerance-based comparison torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
nan_mismatch = ( )
actual_isnan ^ expected_isnan
if equal_nan
else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)
diff_indices = torch.nonzero(diff_mask, as_tuple=False) diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected delta = actual - expected
...@@ -142,9 +107,8 @@ def print_discrepancy( ...@@ -142,9 +107,8 @@ def print_discrepancy(
print(f" - Actual dtype: {actual.dtype}") print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}") print(f" - Desired dtype: {expected.dtype}")
if not (dtype and is_integer_dtype(dtype)): print(f" - Atol: {atol}")
print(f" - Atol: {atol}") print(f" - Rtol: {rtol}")
print(f" - Rtol: {rtol}")
print( print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)" f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
) )
...@@ -166,10 +130,6 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3 ...@@ -166,10 +130,6 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
""" """
Get tolerance settings based on data type Get tolerance settings based on data type
""" """
# For integer types, return zero tolerance (exact match required)
if is_integer_dtype(tensor_dtype):
return 0, 0
tolerance = tolerance_map.get( tolerance = tolerance_map.get(
tensor_dtype, {"atol": default_atol, "rtol": default_rtol} tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
) )
...@@ -202,6 +162,8 @@ def convert_infinicore_to_torch(infini_result, torch_reference): ...@@ -202,6 +162,8 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
Args: Args:
infini_result: infinicore tensor result infini_result: infinicore tensor result
torch_reference: PyTorch tensor reference (for shape and device) torch_reference: PyTorch tensor reference (for shape and device)
dtype: infinicore data type
device_str: torch device string
Returns: Returns:
torch.Tensor: PyTorch tensor with infinicore data torch.Tensor: PyTorch tensor with infinicore data
...@@ -217,70 +179,103 @@ def convert_infinicore_to_torch(infini_result, torch_reference): ...@@ -217,70 +179,103 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
def compare_results( def compare_results(
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False, dtype=None infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False
): ):
""" """
Generic function to compare infinicore result with PyTorch reference result Generic function to compare infinicore result with PyTorch reference result
Supports both floating-point (with tolerance) and integer (exact) comparison
Args: Args:
infini_result: infinicore tensor result infini_result: infinicore tensor result
torch_result: PyTorch tensor reference result torch_result: PyTorch tensor reference result
atol: absolute tolerance atol: absolute tolerance (for floating-point only)
rtol: relative tolerance rtol: relative tolerance (for floating-point only)
debug_mode: whether to enable debug output debug_mode: whether to enable debug output
dtype: infinicore data type for comparison logic
Returns: Returns:
bool: True if results match within tolerance bool: True if results match within tolerance (FP) or exactly (integer)
""" """
# Convert infinicore result to PyTorch tensor for comparison # Convert infinicore result to PyTorch tensor for comparison
torch_result_from_infini = convert_infinicore_to_torch(infini_result, torch_result) torch_result_from_infini = convert_infinicore_to_torch(infini_result, torch_result)
# Choose comparison method based on dtype # Handle scalar integer comparison
if dtype and is_integer_dtype(dtype): if isinstance(torch_result_from_infini, (int, float)) and isinstance(
# For integer types, require exact equality torch_result, (int, float)
result = torch.equal(torch_result_from_infini, torch_result) ):
else: if isinstance(torch_result_from_infini, int) and isinstance(torch_result, int):
# For float types, use tolerance-based comparison # Exact integer scalar comparison
result = torch.allclose( result_equal = torch_result_from_infini == torch_result
torch_result_from_infini, torch_result, atol=atol, rtol=rtol if debug_mode and not result_equal:
) print(
f"Integer scalar mismatch: {torch_result_from_infini} != {torch_result}"
)
return result_equal
else:
# Floating-point scalar comparison with tolerance
return abs(torch_result_from_infini - torch_result) <= atol + rtol * abs(
torch_result
)
# Debug mode: detailed comparison # Debug mode: detailed comparison
if debug_mode: if debug_mode:
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol, dtype=dtype) debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
return result # Choose comparison method based on data type
if is_integer_dtype(torch_result_from_infini.dtype) or is_integer_dtype(
torch_result.dtype
):
# Exact equality for integer types
result_equal = torch.equal(torch_result_from_infini, torch_result)
if debug_mode and not result_equal:
print("Integer tensor comparison failed - requiring exact equality")
return result_equal
else:
# Tolerance-based comparison for floating-point types
return torch.allclose(
torch_result_from_infini, torch_result, atol=atol, rtol=rtol
)
def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""): def create_test_comparator(config, atol, rtol, mode_name=""):
""" """
Create a test-specific comparison function that handles test configuration Create a test-specific comparison function
Args: Args:
config: test configuration config: test configuration
dtype: infinicore data type atol: absolute tolerance (for floating-point only)
tolerance_map: optional tolerance map (defaults to config's tolerance_map) rtol: relative tolerance (for floating-point only)
mode_name: operation mode name for debug output mode_name: operation mode name for debug output
Returns: Returns:
callable: function that takes (infini_result, torch_result) and returns bool callable: function that takes (infini_result, torch_result) and returns bool
""" """
if tolerance_map is None:
tolerance_map = config.tolerance_map
atol, rtol = get_tolerance(tolerance_map, dtype)
def compare_test_results(infini_result, torch_result): def compare_test_results(infini_result, torch_result):
if config.debug and mode_name: if config.debug and mode_name:
print(f"\033[94mDEBUG INFO - {mode_name}:\033[0m") print(f"\033[94mDEBUG INFO - {mode_name}:\033[0m")
# For integer types, override tolerance to require exact equality
actual_atol = atol
actual_rtol = rtol
# Check if we're dealing with integer types
try:
# Try to get dtype from infinicore tensor
if hasattr(infini_result, "dtype"):
infini_dtype = infini_result.dtype
torch_dtype = to_torch_dtype(infini_dtype)
if is_integer_dtype(torch_dtype):
actual_atol = 0
actual_rtol = 0
except:
pass
return compare_results( return compare_results(
infini_result, infini_result,
torch_result, torch_result,
atol=atol, atol=actual_atol,
rtol=rtol, rtol=actual_rtol,
debug_mode=config.debug, debug_mode=config.debug,
dtype=dtype,
) )
return compare_test_results return compare_test_results
...@@ -330,3 +325,30 @@ def rearrange_tensor(tensor, new_strides): ...@@ -330,3 +325,30 @@ def rearrange_tensor(tensor, new_strides):
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides)) new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
return new_tensor return new_tensor
def is_broadcast(strides):
"""
Check if strides indicate a broadcasted tensor
Args:
strides: Tensor strides or None
Returns:
bool: True if the tensor is broadcasted (has zero strides)
"""
if strides is None:
return False
return any(s == 0 for s in strides)
def is_integer_dtype(dtype):
"""Check if dtype is integer type"""
return dtype in [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.bool,
]
...@@ -7,93 +7,138 @@ import torch ...@@ -7,93 +7,138 @@ import torch
import infinicore import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ============================================================================== # ==============================================================================
# Operator-specific configuration # Operator-specific configuration
# ============================================================================== # ==============================================================================
# Test cases format: (operation_mode, shape, a_strides, b_strides, c_strides) # Test cases format: (shape, a_strides, b_strides, c_strides)
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
(TestCase.BOTH, (13, 4), None, None, None), # Basic cases
(TestCase.BOTH, (13, 4), (10, 1), (10, 1), (10, 1)), ((13, 4), None, None, None),
(TestCase.BOTH, (13, 4), (0, 1), None, None), ((13, 4), (10, 1), (10, 1), None),
(TestCase.BOTH, (13, 4, 4), None, None, None), # Strided cases
(TestCase.BOTH, (13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), ((13, 4), None, None, (10, 1)),
(TestCase.BOTH, (13, 4, 4), (4, 0, 1), (0, 4, 1), None), ((13, 4), (10, 1), (10, 1), (10, 1)),
(TestCase.BOTH, (16, 5632), None, None, None), # 3D cases
(TestCase.BOTH, (16, 5632), (13312, 1), (13312, 1), (13312, 1)), ((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), None),
# Broadcast cases
((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
# Large tensors
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), None),
] ]
# Tolerance configuration
def parse_test_cases(data):
"""
Parse add test case data according to format:
(operation_mode, shape, a_strides, b_strides, c_strides)
"""
operation_mode = data[0]
shape = data[1]
a_strides = data[2] if len(data) > 2 else None
b_strides = data[3] if len(data) > 3 else None
c_strides = data[4] if len(data) > 4 else None
# Create input specifications
inputs = []
# Input tensor a
if a_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, a_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Input tensor b (same shape as a)
if b_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, b_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Output tensor
if c_strides is not None:
output = TensorSpec.from_strided_tensor(shape, c_strides)
else:
output = TensorSpec.from_tensor(shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2}, infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 0, "rtol": 1e-3}, infinicore.float32: {"atol": 0, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2}, infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
} }
# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
"""
Parse test case data and return list of TestCase objects for all operation types.
Each test case contains all necessary information for execution and validation.
"""
test_cases = []
for data in _TEST_CASES_DATA:
shape = data[0]
a_strides = data[1] if len(data) > 1 else None
b_strides = data[2] if len(data) > 2 else None
c_strides = data[3] if len(data) > 3 else None
# Check if tensors support in-place operations
a_supports_inplace = not is_broadcast(a_strides)
b_supports_inplace = not is_broadcast(b_strides)
c_supports_inplace = not is_broadcast(c_strides)
# Generate test cases for all data types
for dtype in _TENSOR_DTYPES:
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
# Create typed tensor specs
a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)
# Test Case 1: Out-of-place (return value)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={},
output_spec=None,
comparison_target=None,
tolerance=tolerance,
description=f"Add - OUT_OF_PLACE",
)
)
# Test Case 2: In-place with explicit output tensor (add(a, b, out=c))
if c_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs=None,
output_spec=c_spec, # Specify the output tensor spec
comparison_target="out",
tolerance=tolerance,
description=f"Add - INPLACE(out)",
)
)
# Test Case 3: In-place on first input (add(a, b, out=a))
if a_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={"out": 0}, # Use index 0 for first input
output_spec=None,
comparison_target=0, # Compare first input
tolerance=tolerance,
description=f"Add - INPLACE(a)",
)
)
# Test Case 4: In-place on second input (add(a, b, out=b))
if b_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={"out": 1}, # Use index 1 for second input
output_spec=None,
comparison_target=1, # Compare second input
tolerance=tolerance,
description=f"Add - INPLACE(b)",
)
)
return test_cases
class OpTest(BaseOperatorTest): class OpTest(BaseOperatorTest):
"""Add test with simplified test case parsing""" """Add operator test with simplified implementation"""
def __init__(self): def __init__(self):
super().__init__("Add") super().__init__("Add")
def get_test_cases(self): def get_test_cases(self):
return _TEST_CASES return parse_test_cases()
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self):
return _TOLERANCE_MAP
def torch_operator(self, a, b, out=None, **kwargs): def torch_operator(self, *args, **kwargs):
return torch.add(a, b, out=out) """PyTorch add implementation"""
return torch.add(*args, **kwargs)
def infinicore_operator(self, a, b, out=None, **kwargs): def infinicore_operator(self, *args, **kwargs):
return infinicore.add(a, b, out=out) """InfiniCore add implementation"""
return infinicore.add(*args, **kwargs)
def main(): def main():
......
...@@ -11,18 +11,17 @@ import torch ...@@ -11,18 +11,17 @@ import torch
import infinicore import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ============================================================================== # ==============================================================================
# Operator-specific configuration # Operator-specific configuration
# ============================================================================== # ==============================================================================
# Test cases format: (operation_mode, n_q_head, n_kv_head, seq_len, head_dim, pos, # Test cases format: (n_q_head, n_kv_head, seq_len, head_dim, pos, k_cache_buf_len, v_cache_buf_len,
# k_cache_buf_len, v_cache_buf_len, q_strides, k_strides, v_strides, # q_strides, k_strides, v_strides, k_cache_strides, v_cache_strides)
# k_cache_strides, v_cache_strides)
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
# Prefill stage # Prefill stage
( (
TestCase.OUT_OF_PLACE,
32, 32,
4, 4,
5, 5,
...@@ -38,7 +37,6 @@ _TEST_CASES_DATA = [ ...@@ -38,7 +37,6 @@ _TEST_CASES_DATA = [
), ),
# Decode stage # Decode stage
( (
TestCase.OUT_OF_PLACE,
32, 32,
4, 4,
1, 1,
...@@ -53,10 +51,9 @@ _TEST_CASES_DATA = [ ...@@ -53,10 +51,9 @@ _TEST_CASES_DATA = [
[64, 11264, 1], [64, 11264, 1],
), ),
# Small test case # Small test case
(TestCase.OUT_OF_PLACE, 8, 4, 2, 16, 1, 8, 8, None, None, None, None, None), (8, 4, 2, 16, 1, 8, 8, None, None, None, None, None),
# Another prefill case # Another prefill case
( (
TestCase.OUT_OF_PLACE,
28, 28,
28, 28,
15, 15,
...@@ -137,124 +134,114 @@ def torch_attention(q, k, v, k_cache, v_cache, pos): ...@@ -137,124 +134,114 @@ def torch_attention(q, k, v, k_cache, v_cache, pos):
return attn_output return attn_output
def parse_test_cases(data): def parse_test_cases():
""" """
Parse attention test case data according to format: Parse attention test case data according to format:
(operation_mode, n_q_head, n_kv_head, seq_len, head_dim, pos, (n_q_head, n_kv_head, seq_len, head_dim, pos, k_cache_buf_len, v_cache_buf_len,
k_cache_buf_len, v_cache_buf_len, q_strides, k_strides, v_strides, q_strides, k_strides, v_strides, k_cache_strides, v_cache_strides)
k_cache_strides, v_cache_strides)
""" """
operation_mode = data[0] test_cases = []
n_q_head, n_kv_head, seq_len, head_dim, pos = (
data[1], for data in _TEST_CASES_DATA:
data[2], n_q_head, n_kv_head, seq_len, head_dim, pos = (
data[3], data[0],
data[4], data[1],
data[5], data[2],
) data[3],
k_cache_buf_len, v_cache_buf_len = data[6], data[7] data[4],
q_strides = data[8] if len(data) > 8 else None
k_strides = data[9] if len(data) > 9 else None
v_strides = data[10] if len(data) > 10 else None
k_cache_strides = data[11] if len(data) > 11 else None
v_cache_strides = data[12] if len(data) > 12 else None
# Create input specifications
inputs = []
# Query tensor: (n_q_head, seq_len, head_dim)
if q_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor((n_q_head, seq_len, head_dim), q_strides)
) )
else: k_cache_buf_len, v_cache_buf_len = data[5], data[6]
inputs.append(TensorSpec.from_tensor((n_q_head, seq_len, head_dim))) q_strides = data[7] if len(data) > 7 else None
k_strides = data[8] if len(data) > 8 else None
# Key tensor: (n_kv_head, seq_len, head_dim) v_strides = data[9] if len(data) > 9 else None
if k_strides is not None: k_cache_strides = data[10] if len(data) > 10 else None
inputs.append( v_cache_strides = data[11] if len(data) > 11 else None
TensorSpec.from_strided_tensor((n_kv_head, seq_len, head_dim), k_strides)
) # Check if output tensor supports in-place operations
else: # For attention, output shape is (seq_len, n_q_head, head_dim)
inputs.append(TensorSpec.from_tensor((n_kv_head, seq_len, head_dim))) output_shape = (seq_len, n_q_head, head_dim)
output_supports_inplace = True # Output is always contiguous for attention
# Value tensor: (n_kv_head, seq_len, head_dim)
if v_strides is not None: # Generate test cases for all data types
inputs.append( for dtype in [infinicore.float16, infinicore.bfloat16, infinicore.float32]:
TensorSpec.from_strided_tensor((n_kv_head, seq_len, head_dim), v_strides) tolerance = {
) infinicore.float16: {"atol": 1e-4, "rtol": 1e-2},
else: infinicore.float32: {"atol": 1e-5, "rtol": 1e-3},
inputs.append(TensorSpec.from_tensor((n_kv_head, seq_len, head_dim))) infinicore.bfloat16: {"atol": 1e-3, "rtol": 5e-2},
}.get(dtype, {"atol": 1e-5, "rtol": 1e-4})
# Key cache: (n_kv_head, k_cache_buf_len, head_dim)
if k_cache_strides is not None: # Create typed tensor specs
inputs.append( q_spec = TensorSpec.from_tensor(
TensorSpec.from_strided_tensor( (n_q_head, seq_len, head_dim), q_strides, dtype
(n_kv_head, k_cache_buf_len, head_dim), k_cache_strides
) )
) k_spec = TensorSpec.from_tensor(
else: (n_kv_head, seq_len, head_dim), k_strides, dtype
inputs.append(TensorSpec.from_tensor((n_kv_head, k_cache_buf_len, head_dim))) )
v_spec = TensorSpec.from_tensor(
# Value cache: (n_kv_head, v_cache_buf_len, head_dim) (n_kv_head, seq_len, head_dim), v_strides, dtype
if v_cache_strides is not None: )
inputs.append( k_cache_spec = TensorSpec.from_tensor(
TensorSpec.from_strided_tensor( (n_kv_head, k_cache_buf_len, head_dim), k_cache_strides, dtype
(n_kv_head, v_cache_buf_len, head_dim), v_cache_strides )
v_cache_spec = TensorSpec.from_tensor(
(n_kv_head, v_cache_buf_len, head_dim), v_cache_strides, dtype
)
pos_spec = TensorSpec.from_scalar(pos)
output_spec = TensorSpec.from_tensor(
output_shape, None, dtype
) # Output is always contiguous
# Inputs list
inputs = [q_spec, k_spec, v_spec, k_cache_spec, v_cache_spec, pos_spec]
# Test Case 1: Out-of-place (return value)
test_cases.append(
TestCase(
inputs=inputs,
kwargs={},
output_spec=None,
comparison_target=None,
tolerance=tolerance,
description=f"Attention - OUT_OF_PLACE",
)
) )
)
else:
inputs.append(TensorSpec.from_tensor((n_kv_head, v_cache_buf_len, head_dim)))
# Position (scalar)
inputs.append(TensorSpec.from_scalar(pos))
# Output tensor: (seq_len, n_q_head, head_dim)
output_shape = (seq_len, n_q_head, head_dim)
output = TensorSpec.from_tensor(output_shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types # Test Case 2: In-place with explicit output tensor (attention(q, k, v, k_cache, v_cache, pos, out=output))
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] if output_supports_inplace:
test_cases.append(
TestCase(
inputs=inputs,
kwargs=None,
output_spec=output_spec, # Specify the output tensor spec
comparison_target="out",
tolerance=tolerance,
description=f"Attention - INPLACE(out)",
)
)
# Tolerance return test_cases
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-4, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-5, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 1e-3, "rtol": 5e-2},
}
class OpTest(BaseOperatorTest): class OpTest(BaseOperatorTest):
"""Attention test with simplified test case parsing""" """Attention operator test with simplified implementation"""
def __init__(self): def __init__(self):
super().__init__("Attention") super().__init__("Attention")
def get_test_cases(self): def get_test_cases(self):
return _TEST_CASES return parse_test_cases()
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self):
return _TOLERANCE_MAP
def torch_operator(self, q, k, v, k_cache, v_cache, pos, out=None, **kwargs): def torch_operator(self, q, k, v, k_cache, v_cache, pos, out=None, **kwargs):
"""PyTorch attention implementation"""
result = torch_attention(q, k, v, k_cache, v_cache, pos) result = torch_attention(q, k, v, k_cache, v_cache, pos)
if out is not None: if out is not None:
out.set_(result) out.copy_(result)
return out return out
else: return result
return result
def infinicore_operator(self, q, k, v, k_cache, v_cache, pos, out=None, **kwargs): def infinicore_operator(self, q, k, v, k_cache, v_cache, pos, out=None, **kwargs):
"""InfiniCore attention implementation"""
return infinicore.attention(q, k, v, k_cache, v_cache, pos, out=out) return infinicore.attention(q, k, v, k_cache, v_cache, pos, out=out)
......
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (shape, a_strides, b_strides, c_strides)
# Strides of None mean a contiguous tensor; a 0 stride marks a broadcast axis.
_TEST_CASES_DATA = [
    # Basic cases
    ((13, 4), None, None, None),
    ((13, 4), (10, 1), (10, 1), None),
    # Strided cases
    ((13, 4), None, None, (10, 1)),
    ((13, 4), (10, 1), (10, 1), (10, 1)),
    # 3D cases
    ((13, 4, 4), None, None, None),
    ((13, 4, 4), (20, 4, 1), (20, 4, 1), None),
    # Broadcast cases
    ((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
    # Large tensors
    ((16, 5632), None, None, None),
    ((16, 5632), (13312, 1), (13312, 1), None),
]

# Tolerance configuration - exact match required for bitwise operations
_TOLERANCE_MAP = {
    infinicore.int8: {"atol": 0, "rtol": 0},
    infinicore.int16: {"atol": 0, "rtol": 0},
    infinicore.int32: {"atol": 0, "rtol": 0},
    infinicore.int64: {"atol": 0, "rtol": 0},
    infinicore.uint8: {"atol": 0, "rtol": 0},
    infinicore.bool: {"atol": 0, "rtol": 0},
}

# Data types to test - integer types for bitwise operations
_TENSOR_DTYPES = [
    infinicore.int8,
    infinicore.int16,
    infinicore.int32,
    infinicore.int64,
    infinicore.uint8,
    infinicore.bool,  # XOR also supports boolean tensors
]
def parse_test_cases():
    """
    Parse test case data and return list of TestCase objects for all operation types.

    For every entry in _TEST_CASES_DATA and every dtype in _TENSOR_DTYPES this
    generates up to four cases: out-of-place, explicit ``out=`` tensor, and
    in-place on each input. In-place variants are skipped when the target
    tensor has broadcast (zero) strides, since it cannot be written to.
    """
    test_cases = []

    for data in _TEST_CASES_DATA:
        shape = data[0]
        a_strides = data[1] if len(data) > 1 else None
        b_strides = data[2] if len(data) > 2 else None
        c_strides = data[3] if len(data) > 3 else None

        # Broadcast (zero-stride) tensors cannot be used as in-place targets.
        a_supports_inplace = not is_broadcast(a_strides)
        b_supports_inplace = not is_broadcast(b_strides)
        c_supports_inplace = not is_broadcast(c_strides)

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0})

            # Create typed tensor specs
            a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
            b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
            c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)

            def make_case(kwargs, output_spec, comparison_target, description):
                # All four variants share the same inputs and tolerance.
                return TestCase(
                    inputs=[a_spec, b_spec],
                    kwargs=kwargs,
                    output_spec=output_spec,
                    comparison_target=comparison_target,
                    tolerance=tolerance,
                    description=description,
                )

            # Test Case 1: Out-of-place (return value)
            test_cases.append(
                make_case({}, None, None, "BitwiseXor - OUT_OF_PLACE")
            )

            # Test Case 2: In-place with explicit output tensor (bitwise_xor(a, b, out=c))
            if c_supports_inplace:
                test_cases.append(
                    make_case(None, c_spec, "out", "BitwiseXor - INPLACE(out)")
                )

            # Test Case 3: In-place on first input (bitwise_xor(a, b, out=a))
            if a_supports_inplace:
                test_cases.append(
                    make_case({"out": 0}, None, 0, "BitwiseXor - INPLACE(a)")
                )

            # Test Case 4: In-place on second input (bitwise_xor(a, b, out=b))
            if b_supports_inplace:
                test_cases.append(
                    make_case({"out": 1}, None, 1, "BitwiseXor - INPLACE(b)")
                )

    return test_cases
class OpTest(BaseOperatorTest):
    """Test driver for the bitwise XOR operator."""

    def __init__(self):
        super().__init__("BitwiseXor")

    def get_test_cases(self):
        """Build the full list of test cases on demand."""
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """Reference implementation backed by torch.bitwise_xor."""
        return torch.bitwise_xor(*args, **kwargs)

    # InfiniCore does not expose bitwise_xor here; re-enable once it does.
    # def infinicore_operator(self, *args, **kwargs):
    #     return infinicore.bitwise_xor(*args, **kwargs)
def main():
    """Entry point: run the BitwiseXor suite and exit with its status code."""
    GenericTestRunner(OpTest).run_and_exit()
# Run the suite only when executed as a script (not on import).
if __name__ == "__main__":
    main()
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import infinicore import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ============================================================================== # ==============================================================================
# Operator-specific configuration # Operator-specific configuration
...@@ -16,52 +17,17 @@ from framework.runner import GenericTestRunner ...@@ -16,52 +17,17 @@ from framework.runner import GenericTestRunner
# Causal softmax is a single-input function that applies causal masking before softmax # Causal softmax is a single-input function that applies causal masking before softmax
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
# Basic 2D causal softmax # Basic 2D causal softmax
(TestCase.BOTH, (3, 3), None, None), ((3, 3), None, None),
(TestCase.BOTH, (32, 512), None, None), ((32, 512), None, None),
# Strided tensors # Strided tensors
(TestCase.BOTH, (32, 512), (1024, 1), (1024, 1)), ((32, 512), (1024, 1), (1024, 1)),
# 3D causal softmax # 3D causal softmax
(TestCase.BOTH, (32, 5, 5), None, None), ((32, 5, 5), None, None),
(TestCase.BOTH, (32, 20, 512), None, None), ((32, 20, 512), None, None),
(TestCase.BOTH, (32, 20, 512), (20480, 512, 1), None), ((32, 20, 512), (20480, 512, 1), None),
(TestCase.BOTH, (28, 15, 15), None, None), ((28, 15, 15), None, None),
] ]
def parse_test_cases(data):
"""
Parse causal_softmax test case data according to format:
(operation_mode, shape, input_strides, output_strides)
"""
operation_mode = data[0]
shape = data[1]
input_strides = data[2] if len(data) > 2 else None
output_strides = data[3] if len(data) > 3 else None
# Create input specifications
inputs = []
# Tensor input
if input_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, input_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Output tensor
if output_strides is not None:
output = TensorSpec.from_strided_tensor(shape, output_strides)
else:
output = TensorSpec.from_tensor(shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance # Tolerance
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-3, "rtol": 1e-2}, infinicore.float16: {"atol": 1e-3, "rtol": 1e-2},
...@@ -69,6 +35,74 @@ _TOLERANCE_MAP = { ...@@ -69,6 +35,74 @@ _TOLERANCE_MAP = {
infinicore.bfloat16: {"atol": 5e-3, "rtol": 5e-2}, infinicore.bfloat16: {"atol": 5e-3, "rtol": 5e-2},
} }
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
"""
Parse causal_softmax test case data according to format:
(shape, input_strides, output_strides)
"""
test_cases = []
for data in _TEST_CASES_DATA:
shape = data[0]
input_strides = data[1] if len(data) > 1 else None
output_strides = data[2] if len(data) > 2 else None
# Check if tensors support in-place operations
input_supports_inplace = not is_broadcast(input_strides)
output_supports_inplace = not is_broadcast(output_strides)
# Generate test cases for all data types
for dtype in _TENSOR_DTYPES:
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
# Create typed tensor specs
input_spec = TensorSpec.from_tensor(shape, input_strides, dtype)
output_spec = TensorSpec.from_tensor(shape, output_strides, dtype)
# Test Case 1: Out-of-place (return value)
test_cases.append(
TestCase(
inputs=[input_spec],
kwargs={},
output_spec=None,
comparison_target=None,
tolerance=tolerance,
description=f"Causal Softmax - OUT_OF_PLACE",
)
)
# Test Case 2: In-place with explicit output tensor (causal_softmax(input, out=output))
if output_supports_inplace:
test_cases.append(
TestCase(
inputs=[input_spec],
kwargs=None,
output_spec=output_spec, # Specify the output tensor spec
comparison_target="out",
tolerance=tolerance,
description=f"Causal Softmax - INPLACE(out)",
)
)
# Test Case 3: In-place on first input (causal_softmax(input, out=input))
if input_supports_inplace:
test_cases.append(
TestCase(
inputs=[input_spec],
kwargs={"out": 0}, # Use index 0 for first input
output_spec=None,
comparison_target=0, # Compare first input
tolerance=tolerance,
description=f"Causal Softmax - INPLACE(input)",
)
)
return test_cases
class OpTest(BaseOperatorTest): class OpTest(BaseOperatorTest):
"""CausalSoftmax test with simplified test case parsing""" """CausalSoftmax test with simplified test case parsing"""
...@@ -77,31 +111,28 @@ class OpTest(BaseOperatorTest): ...@@ -77,31 +111,28 @@ class OpTest(BaseOperatorTest):
super().__init__("CausalSoftmax") super().__init__("CausalSoftmax")
def get_test_cases(self): def get_test_cases(self):
return _TEST_CASES return parse_test_cases()
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self): def torch_causal_softmax(self, input, out=None, **kwargs):
return _TOLERANCE_MAP
def torch_operator(self, input, out=None, **kwargs):
# Causal softmax implementation: apply causal mask then softmax # Causal softmax implementation: apply causal mask then softmax
dtype = input.dtype dtype = input.dtype
# Create causal mask # Create causal mask
mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1]) mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1])
masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32)) masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32))
result = torch.nn.functional.softmax(masked, dim=-1, dtype=dtype) result = torch.nn.functional.softmax(masked, dim=-1, dtype=dtype)
if out is not None: if out is not None:
out.copy_(result) out.copy_(result)
return out return out
return result return result
def infinicore_operator(self, input, out=None, **kwargs): def torch_operator(self, *args, **kwargs):
return infinicore.causal_softmax(input, out=out) return self.torch_causal_softmax(*args, **kwargs)
def infinicore_operator(self, *args, **kwargs):
return infinicore.causal_softmax(*args, **kwargs)
def main(): def main():
......
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (shape, input_strides, alpha)
# input_strides of None means a contiguous tensor; a 0 stride marks a
# broadcast axis. alpha of None exercises the operator's default (1.0).
_TEST_CASES_DATA = [
    # Basic ELU tests without alpha (default alpha=1.0)
    ((13, 4), None, None),
    ((13, 4), (10, 1), None),
    ((13, 4), (0, 1), None),
    # 3D tensor tests
    ((13, 4, 4), None, None),
    ((13, 4, 4), (20, 4, 1), None),
    ((13, 4, 4), (4, 0, 1), None),
    # Large tensor tests
    ((16, 5632), None, None),
    ((16, 5632), (13312, 1), None),
    # ELU with different alpha values
    ((8, 4), None, 0.5),
    ((8, 4), (10, 1), 0.5),
    ((8, 4), None, 1.5),
    ((8, 4), (10, 1), 1.5),
    ((16, 8), None, 2.0),
    ((16, 8), (20, 1), 2.0),
    ((16, 8), None, 0.3),
    ((16, 8), (20, 1), 0.3),
    ((32, 16), None, 1.0),
    ((32, 16), (40, 1), 1.0),
    ((32, 16), None, 1.8),
    ((32, 16), (40, 1), 1.8),
]

# Tolerance configuration (per dtype; looser for low-precision floats)
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-3, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-5, "rtol": 1e-4},
    infinicore.bfloat16: {"atol": 1e-2, "rtol": 5e-2},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """
    Parse ELU test case data according to format:
    (shape, input_strides, alpha)

    ELU only supports out-of-place and in-place modes via PyTorch's
    ``inplace`` parameter; in-place cases are skipped for broadcast inputs.
    """
    cases = []

    for entry in _TEST_CASES_DATA:
        shape = entry[0]
        input_strides = entry[1] if len(entry) > 1 else None
        alpha = entry[2] if len(entry) > 2 else None

        # Broadcast (zero-stride) inputs cannot be modified in place.
        can_run_inplace = not is_broadcast(input_strides)

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4})
            spec = TensorSpec.from_tensor(shape, input_strides, dtype)

            # Compose a human-readable label from the optional parameters.
            label_bits = ["ELU"]
            if alpha is not None:
                label_bits.append(f"alpha={alpha}")
            if input_strides is not None:
                label_bits.append(f"input_strides={input_strides}")
            label = " - ".join(label_bits)

            # Only pass alpha through when the case specifies one.
            alpha_kwargs = {} if alpha is None else {"alpha": alpha}

            # Test Case 1: Out-of-place (result is the return value)
            cases.append(
                TestCase(
                    inputs=[spec],
                    kwargs=dict(alpha_kwargs),
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tol,
                    description=f"{label} - OUT_OF_PLACE",
                )
            )

            # Test Case 2: In-place via PyTorch's inplace=True parameter
            if can_run_inplace:
                cases.append(
                    TestCase(
                        inputs=[spec],
                        kwargs={**alpha_kwargs, "inplace": True},
                        output_spec=None,
                        comparison_target=0,  # compare the mutated input
                        tolerance=tol,
                        description=f"{label} - INPLACE",
                    )
                )

    return cases
class OpTest(BaseOperatorTest):
    """ELU operator test with PyTorch-compatible implementation"""

    def __init__(self):
        super().__init__("ELU")

    def get_test_cases(self):
        # Cases are rebuilt on every call; see parse_test_cases().
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch ELU implementation (reference)."""
        return torch.nn.functional.elu(*args, **kwargs)

    def infinicore_operator(self, x, alpha=1.0, out=None, **kwargs):
        """InfiniCore ELU implementation"""
        # Stub: no InfiniCore ELU kernel is wired up here. Returning None
        # presumably signals the framework to skip the comparison for this
        # backend — NOTE(review): confirm the runner treats None that way.
        return None
def main():
    """Entry point: run the ELU suite and exit with its status code."""
    GenericTestRunner(OpTest).run_and_exit()
# Run the suite only when executed as a script (not on import).
if __name__ == "__main__":
    main()
...@@ -7,109 +7,120 @@ import torch ...@@ -7,109 +7,120 @@ import torch
import infinicore import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ============================================================================== # ==============================================================================
# Operator-specific configuration # Operator-specific configuration
# ============================================================================== # ==============================================================================
# Test cases format: (operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides) # Test cases format: (nbatch, m, n, k, a_strides, b_strides, c_strides)
# If nbatch is None: a_shape=(m, k), b_shape=(k, n), c_shape=(m, n) # If nbatch is None: a_shape=(m, k), b_shape=(k, n), c_shape=(m, n)
# If nbatch is provided: a_shape=(nbatch, m, k), b_shape=(nbatch, k, n), c_shape=(nbatch, m, n) # If nbatch is provided: a_shape=(nbatch, m, k), b_shape=(nbatch, k, n), c_shape=(nbatch, m, n)
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
# Basic 2D matmul # Basic 2D matmul
(TestCase.BOTH, None, 2, 4, 3, None, None, None), (None, 2, 4, 3, None, None, None),
(TestCase.BOTH, None, 128, 64, 256, None, None, None), (None, 128, 64, 256, None, None, None),
# Batched matmul # Batched matmul
(TestCase.BOTH, 2, 4, 2048, 2048, None, None, None), (2, 4, 2048, 2048, None, None, None),
(TestCase.BOTH, 4, 48, 6, 64, None, None, None), (4, 48, 6, 64, None, None, None),
# Strided tensors # Strided tensors
(TestCase.BOTH, None, 1, 2048, 2048, (4096, 1), (4096, 1), (4096, 1)), (None, 1, 2048, 2048, (4096, 1), (4096, 1), (4096, 1)),
(TestCase.BOTH, None, 6, 2560, 2048, (2048, 1), (1, 2048), (2560, 1)), (None, 6, 2560, 2048, (2048, 1), (1, 2048), (2560, 1)),
# Mixed cases # Mixed cases
(TestCase.BOTH, 8, 16, 32, 16, None, None, None), (8, 16, 32, 16, None, None, None),
] ]
# Tolerance configuration
def parse_test_cases(data):
"""
Parse matmul test case data according to format:
(operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
"""
operation_mode = data[0]
nbatch = data[1]
m, n, k = data[2], data[3], data[4]
a_strides = data[5] if len(data) > 5 else None
b_strides = data[6] if len(data) > 6 else None
c_strides = data[7] if len(data) > 7 else None
# Determine shapes based on batch dimension
if nbatch is None:
a_shape = (m, k)
b_shape = (k, n)
c_shape = (m, n)
else:
a_shape = (nbatch, m, k)
b_shape = (nbatch, k, n)
c_shape = (nbatch, m, n)
# Create input specifications
inputs = []
# Tensor a
if a_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(a_shape, a_strides))
else:
inputs.append(TensorSpec.from_tensor(a_shape))
# Tensor b
if b_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(b_shape, b_strides))
else:
inputs.append(TensorSpec.from_tensor(b_shape))
# Output tensor
if c_strides is not None:
output = TensorSpec.from_strided_tensor(c_shape, c_strides)
else:
output = TensorSpec.from_tensor(c_shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2}, infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 0, "rtol": 1e-3}, infinicore.float32: {"atol": 0, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2}, infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
} }
# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
"""
Parse test case data and return list of TestCase objects for matmul operation.
Each test case contains all necessary information for execution and validation.
"""
test_cases = []
for data in _TEST_CASES_DATA:
nbatch = data[0]
m, n, k = data[1], data[2], data[3]
a_strides = data[4] if len(data) > 4 else None
b_strides = data[5] if len(data) > 5 else None
c_strides = data[6] if len(data) > 6 else None
# Determine shapes based on batch dimension
if nbatch is None:
a_shape = (m, k)
b_shape = (k, n)
c_shape = (m, n)
else:
a_shape = (nbatch, m, k)
b_shape = (nbatch, k, n)
c_shape = (nbatch, m, n)
# Check if tensors support in-place operations
c_supports_inplace = not is_broadcast(c_strides)
# Generate test cases for all data types
for dtype in _TENSOR_DTYPES:
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
# Create typed tensor specs
a_spec = TensorSpec.from_tensor(a_shape, a_strides, dtype)
b_spec = TensorSpec.from_tensor(b_shape, b_strides, dtype)
c_spec = TensorSpec.from_tensor(c_shape, c_strides, dtype)
# Test Case 1: Out-of-place (return value)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={},
output_spec=None,
comparison_target=None,
tolerance=tolerance,
description=f"Matmul - OUT_OF_PLACE",
)
)
# Test Case 2: In-place with explicit output tensor (matmul(a, b, out=c))
if c_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs=None,
output_spec=c_spec, # Specify the output tensor spec
comparison_target="out",
tolerance=tolerance,
description=f"Matmul - INPLACE(out)",
)
)
return test_cases
class OpTest(BaseOperatorTest): class OpTest(BaseOperatorTest):
"""Matmul test with simplified test case parsing""" """Matmul operator test with simplified implementation"""
def __init__(self): def __init__(self):
super().__init__("Matmul") super().__init__("Matmul")
def get_test_cases(self): def get_test_cases(self):
return _TEST_CASES return parse_test_cases()
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self):
return _TOLERANCE_MAP
def torch_operator(self, a, b, out=None, **kwargs): def torch_operator(self, *args, **kwargs):
return torch.matmul(a, b, out=out) """PyTorch matmul implementation"""
return torch.matmul(*args, **kwargs)
def infinicore_operator(self, a, b, out=None, **kwargs): def infinicore_operator(self, *args, **kwargs):
return infinicore.matmul(a, b, out=out) """InfiniCore matmul implementation"""
return infinicore.matmul(*args, **kwargs)
def main(): def main():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment