Commit 36081e53 authored by wooway777's avatar wooway777 Committed by MaYuhang
Browse files

issue/573 - support arguments that are tuples of tensors

parent a999ed68
...@@ -34,7 +34,7 @@ class TestCase: ...@@ -34,7 +34,7 @@ class TestCase:
Initialize a test case with complete configuration Initialize a test case with complete configuration
Args: Args:
inputs: List of TensorSpec objects or scalars inputs: List of TensorSpec objects, scalars, or tuples containing multiple TensorSpecs
kwargs: Additional keyword arguments for the operator kwargs: Additional keyword arguments for the operator
output_spec: TensorSpec for output tensor (for single output operations) output_spec: TensorSpec for output tensor (for single output operations)
output_specs: List of TensorSpec for multiple output tensors output_specs: List of TensorSpec for multiple output tensors
...@@ -45,10 +45,26 @@ class TestCase: ...@@ -45,10 +45,26 @@ class TestCase:
""" """
self.inputs = [] self.inputs = []
# Process inputs # Process inputs - support both single TensorSpecs and tuples of TensorSpecs
for inp in inputs: for inp in inputs:
if isinstance(inp, (list, tuple)): if isinstance(inp, (list, tuple)):
self.inputs.append(TensorSpec.from_tensor(inp)) # Handle tuple/list of multiple TensorSpecs (e.g., for torch.cat)
processed_tuple = []
for item in inp:
if isinstance(item, (list, tuple)):
# Nested tuple - recursively process
nested_processed = []
for nested_item in item:
if isinstance(nested_item, TensorSpec):
nested_processed.append(nested_item)
else:
nested_processed.append(nested_item)
processed_tuple.append(tuple(nested_processed))
elif isinstance(item, TensorSpec):
processed_tuple.append(item)
else:
processed_tuple.append(item)
self.inputs.append(tuple(processed_tuple))
elif isinstance(inp, TensorSpec): elif isinstance(inp, TensorSpec):
self.inputs.append(inp) self.inputs.append(inp)
else: else:
...@@ -83,12 +99,43 @@ class TestCase: ...@@ -83,12 +99,43 @@ class TestCase:
for inp in self.inputs: for inp in self.inputs:
if isinstance(inp, TensorSpec) and not inp.is_scalar: if isinstance(inp, TensorSpec) and not inp.is_scalar:
count += 1 count += 1
elif isinstance(inp, (list, tuple)):
# Count all TensorSpecs within the tuple
for item in inp:
if isinstance(item, TensorSpec) and not item.is_scalar:
count += 1
return count return count
def __str__(self): def __str__(self):
input_strs = [] input_strs = []
for inp in self.inputs: for inp in self.inputs:
if hasattr(inp, "is_scalar") and inp.is_scalar: if isinstance(inp, (list, tuple)):
# Handle tuple inputs (e.g., for torch.cat)
tuple_strs = []
for item in inp:
if hasattr(item, "is_scalar") and item.is_scalar:
dtype_str = f", dtype={item.dtype}" if item.dtype else ""
tuple_strs.append(f"scalar({item.value}{dtype_str})")
elif hasattr(item, "shape"):
dtype_str = f", {item.dtype}" if item.dtype else ""
init_str = (
f", init={item.init_mode}"
if item.init_mode != TensorInitializer.RANDOM
else ""
)
if hasattr(item, "strides") and item.strides:
strides_str = f", strides={item.strides}"
tuple_strs.append(
f"tensor{item.shape}{strides_str}{dtype_str}{init_str}"
)
else:
tuple_strs.append(
f"tensor{item.shape}{dtype_str}{init_str}"
)
else:
tuple_strs.append(str(item))
input_strs.append(f"tuple({'; '.join(tuple_strs)})")
elif hasattr(inp, "is_scalar") and inp.is_scalar:
dtype_str = f", dtype={inp.dtype}" if inp.dtype else "" dtype_str = f", dtype={inp.dtype}" if inp.dtype else ""
input_strs.append(f"scalar({inp.value}{dtype_str})") input_strs.append(f"scalar({inp.value}{dtype_str})")
elif hasattr(inp, "shape"): elif hasattr(inp, "shape"):
...@@ -111,7 +158,7 @@ class TestCase: ...@@ -111,7 +158,7 @@ class TestCase:
base_str = f"TestCase(" base_str = f"TestCase("
if self.description: if self.description:
base_str += f"{self.description}" base_str += f"{self.description}"
base_str += f" - inputs=[{', '.join(input_strs)}]" base_str += f" - inputs=[{'; '.join(input_strs)}]"
if self.kwargs or self.output_spec or self.output_specs: if self.kwargs or self.output_spec or self.output_specs:
kwargs_strs = [] kwargs_strs = []
...@@ -160,9 +207,9 @@ class TestCase: ...@@ -160,9 +207,9 @@ class TestCase:
) )
kwargs_strs.extend(output_strs) kwargs_strs.extend(output_strs)
base_str += f", kwargs={{{', '.join(kwargs_strs)}}}" base_str += f", kwargs={{{'; '.join(kwargs_strs)}}}"
base_str += f", outputs={self.output_count})" base_str += ")"
return base_str return base_str
...@@ -331,21 +378,43 @@ class BaseOperatorTest(ABC): ...@@ -331,21 +378,43 @@ class BaseOperatorTest(ABC):
"""InfiniCore operator function""" """InfiniCore operator function"""
raise NotImplementedError("infinicore_operator not implemented") raise NotImplementedError("infinicore_operator not implemented")
def _create_tensor_from_spec(self, spec, device):
"""Helper method to create tensor from TensorSpec"""
if isinstance(spec, TensorSpec):
if spec.is_scalar:
return spec.value
else:
return spec.create_torch_tensor(device)
return spec
def prepare_inputs_and_kwargs(self, test_case, device): def prepare_inputs_and_kwargs(self, test_case, device):
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors""" """Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors
Supports tuple inputs for operators like torch.cat
"""
inputs = [] inputs = []
kwargs = test_case.kwargs.copy() kwargs = test_case.kwargs.copy()
# Prepare input tensors # Prepare input tensors - support both single TensorSpecs and tuples of TensorSpecs
for i, input_spec in enumerate(test_case.inputs): for input_spec in test_case.inputs:
if isinstance(input_spec, TensorSpec): if isinstance(input_spec, (list, tuple)):
if input_spec.is_scalar: # Handle tuple of multiple TensorSpecs (e.g., for torch.cat)
inputs.append(input_spec.value) tuple_tensors = []
else: for item in input_spec:
tensor = input_spec.create_torch_tensor(device) if isinstance(item, (list, tuple)):
inputs.append(tensor) # Handle nested tuples
nested_tensors = []
for nested_item in item:
nested_tensors.append(
self._create_tensor_from_spec(nested_item, device)
)
tuple_tensors.append(tuple(nested_tensors))
else:
tuple_tensors.append(
self._create_tensor_from_spec(item, device)
)
inputs.append(tuple(tuple_tensors))
else: else:
inputs.append(input_spec) inputs.append(self._create_tensor_from_spec(input_spec, device))
# Prepare output tensors based on output_count # Prepare output tensors based on output_count
if test_case.output_count == 1: if test_case.output_count == 1:
......
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (tensor_shapes, dim, input_strides_list, output_strides)
#   tensor_shapes: list of input tensor shapes; all must match except along `dim`
#   dim: concatenation axis passed to the operator as kwargs["dim"]
#   input_strides_list: per-input strides, or None for contiguous defaults
#   output_strides: strides for the preallocated output, or None for contiguous
_TEST_CASES_DATA = [
    # Basic concatenation
    ([(2, 3), (2, 3)], 0, None, None),
    ([(2, 3), (2, 3)], 1, None, None),
    ([(1, 4), (3, 4)], 0, None, None),
    # Multiple tensors
    ([(1, 5), (2, 5), (3, 5)], 0, None, None),
    ([(3, 2), (3, 3), (3, 1)], 1, None, None),
    # 3D tensors
    ([(2, 3, 4), (2, 3, 4)], 0, None, None),
    ([(2, 3, 4), (2, 3, 4)], 1, None, None),
    ([(2, 3, 4), (2, 3, 4)], 2, None, None),
    # Strided tensors (row stride larger than the row length -> padded layout)
    ([(3, 4), (3, 4)], 0, [(8, 1), (8, 1)], None),
    ([(2, 5), (2, 5)], 1, [(10, 1), (10, 1)], None),
    # Large tensors
    ([(16, 256), (16, 256)], 0, None, None),
    ([(8, 512), (8, 512)], 1, None, None),
]
# Tolerance configuration: per-dtype absolute/relative tolerances for comparison.
# Looser rtol for the low-precision float16/bfloat16 types.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 1e-2},
    infinicore.float32: {"atol": 0, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}
# Data types to test: every case above is instantiated once per dtype.
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """Parse cat test case data and return a list of TestCase objects.

    For every entry in ``_TEST_CASES_DATA`` and every dtype in
    ``_TENSOR_DTYPES`` this builds:
      * one out-of-place case (framework allocates the output), and
      * one in-place case writing into a preallocated ``out`` tensor,
        skipped when the output strides describe a broadcast layout.

    Returns:
        list[TestCase]: fully-configured test cases for the runner.
    """
    test_cases = []
    for data in _TEST_CASES_DATA:
        tensor_shapes = data[0]
        dim = data[1]
        # Trailing tuple fields are optional; default to contiguous layouts.
        input_strides_list = data[2] if len(data) > 2 else None
        output_strides = data[3] if len(data) > 3 else None

        # Output shape: same as the inputs except summed along the cat dim.
        output_shape = list(tensor_shapes[0])
        for shape in tensor_shapes[1:]:
            output_shape[dim] += shape[dim]

        # A broadcast output layout cannot be written in place.
        output_supports_inplace = not is_broadcast(output_strides)

        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            # Inputs for cat are passed as a single tuple-of-tensors argument.
            input_specs = []
            for i, shape in enumerate(tensor_shapes):
                strides = (
                    input_strides_list[i]
                    if input_strides_list and i < len(input_strides_list)
                    else None
                )
                input_specs.append(TensorSpec.from_tensor(shape, strides, dtype))

            # Out-of-place test case: no output spec, framework allocates.
            test_cases.append(
                TestCase(
                    inputs=[tuple(input_specs)],
                    kwargs={"dim": dim},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    description="Cat - OUT_OF_PLACE",
                )
            )

            # In-place test case: result written into a preallocated `out`.
            if output_supports_inplace:
                output_spec = TensorSpec.from_tensor(
                    output_shape, output_strides, dtype
                )
                test_cases.append(
                    TestCase(
                        inputs=[tuple(input_specs)],
                        kwargs={"dim": dim},
                        output_spec=output_spec,
                        comparison_target="out",
                        tolerance=tolerance,
                        description="Cat - INPLACE(out)",
                    )
                )
    return test_cases
class OpTest(BaseOperatorTest):
    """Cat operator test implementation.

    Thin subclass of the generic operator-test harness: supplies the
    operator name, the parsed test cases, and the PyTorch reference
    implementation used as ground truth.
    """

    def __init__(self):
        # Register this suite under the operator name "Cat".
        super().__init__("Cat")

    def get_test_cases(self):
        """Return the full list of Cat test cases built from the module config."""
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch cat implementation (reference ground truth).

        args[0] is expected to be the tuple of input tensors prepared by the
        framework, forwarded directly to torch.cat along kwargs["dim"].
        """
        return torch.cat(*args, **kwargs)

    # Placeholder: enable once infinicore exposes a cat operator.
    # def infinicore_operator(self, *args, **kwargs):
    #     """InfiniCore cat implementation"""
    #     return infinicore.cat(*args, **kwargs)
def main():
    """Run the Cat operator test suite and exit with the runner's status."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment