Commit 2a343a3a authored by wooway777's avatar wooway777
Browse files

issue/630 - slightly improved unimplemented messages

parent f69f6909
from .base import TestConfig, TestRunner, TestCase, BaseOperatorTest
from .base import TestConfig, TestRunner, BaseOperatorTest
from .test_case import TestCase, TestResult
from .benchmark import BenchmarkUtils, BenchmarkResult
from .config import (
get_args,
......@@ -32,6 +33,7 @@ __all__ = [
"TensorSpec",
"TestCase",
"TestConfig",
"TestResult",
"TestRunner",
# Core functions
"compare_results",
......
"""
Core base classes for operator testing framework.
Contains TestConfig, TestRunner, and BaseOperatorTest classes.
"""
import torch
import infinicore
import traceback
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from .test_case import TestCase, TestResult
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer
......@@ -15,163 +19,6 @@ from .utils import (
from .benchmark import BenchmarkUtils
@dataclass
class TestResult:
"""Test result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
class TestCase:
"""Test case with all configuration included"""
def __init__(
self,
inputs,
kwargs=None,
output_spec=None,
output_specs=None,
comparison_target=None,
description="",
tolerance=None,
output_count=1,
):
"""
Initialize a test case with complete configuration
Args:
inputs: List of TensorSpec objects, scalars, or tuples containing multiple TensorSpecs
kwargs: Additional keyword arguments for the operator
output_spec: TensorSpec for output tensor (for single output operations)
output_specs: List of TensorSpec for multiple output tensors
comparison_target: Target for comparison ('out', index, or None for return value)
description: Test case description
tolerance: Tolerance settings for this test case {'atol': float, 'rtol': float}
output_count: Number of outputs (default: 1)
"""
self.inputs = []
# Process inputs - support both single TensorSpecs and tuples of TensorSpecs
for i, inp in enumerate(inputs):
if isinstance(inp, (list, tuple)):
# Handle tuple/list of multiple TensorSpecs (e.g., for torch.cat)
processed_tuple = []
for j, item in enumerate(inp):
if isinstance(item, (list, tuple)):
# Nested tuple - recursively process
nested_processed = []
for k, nested_item in enumerate(item):
if isinstance(nested_item, TensorSpec):
nested_item.fill_name(f"in_{i}_{j}_{k}")
nested_processed.append(nested_item)
else:
nested_processed.append(nested_item)
processed_tuple.append(tuple(nested_processed))
elif isinstance(item, TensorSpec):
item.fill_name(f"in_{i}_{j}")
processed_tuple.append(item)
else:
processed_tuple.append(item)
self.inputs.append(tuple(processed_tuple))
elif isinstance(inp, TensorSpec):
inp.fill_name(f"in_{i}")
self.inputs.append(inp)
else:
self.inputs.append(inp)
self.kwargs = kwargs or {}
self.output_spec = output_spec
self.output_specs = output_specs
self.comparison_target = comparison_target
self.description = description
self.tolerance = tolerance or {"atol": 1e-5, "rtol": 1e-3}
self.output_count = output_count
if self.output_count > 1 and self.output_specs is not None:
for idx, spec in enumerate(self.output_specs):
spec.fill_name(f"out_{idx}")
# Validate output configuration
if self.output_count == 1:
if self.output_specs is not None:
raise ValueError("output_specs cannot be used when output_count=1")
else:
if self.output_spec is not None:
raise ValueError("output_spec cannot be used when output_count>1")
if (
self.output_specs is not None
and len(self.output_specs) != self.output_count
):
raise ValueError(
f"output_specs count ({len(self.output_specs)}) must match output_count ({self.output_count})"
)
def get_tensor_input_count(self):
"""Count the number of tensor inputs (excluding scalars)"""
count = 0
for inp in self.inputs:
if isinstance(inp, TensorSpec) and not inp.is_scalar:
count += 1
elif isinstance(inp, (list, tuple)):
# Count all TensorSpecs within the tuple
for item in inp:
if isinstance(item, TensorSpec) and not item.is_scalar:
count += 1
return count
def __str__(self):
input_strs = []
for inp in self.inputs:
if isinstance(inp, (list, tuple)):
# Handle tuple inputs (e.g., for torch.cat)
tuple_strs = []
for item in inp:
if isinstance(item, (list, tuple)):
# Handle nested tuples
nested_strs = []
for nested_item in item:
nested_strs.append(str(nested_item))
tuple_strs.append(f"tuple({', '.join(nested_strs)})")
else:
tuple_strs.append(str(item))
input_strs.append(f"tuple({'; '.join(tuple_strs)})")
else:
input_strs.append(str(inp))
base_str = f"TestCase("
if self.description:
base_str += f"{self.description}"
base_str += f" - inputs=[{'; '.join(input_strs)}]"
if self.kwargs or self.output_spec or self.output_specs:
kwargs_strs = []
for key, value in self.kwargs.items():
if key == "out" and isinstance(value, int):
kwargs_strs.append(f"{key}={self.inputs[value].name}")
else:
kwargs_strs.append(f"{key}={value}")
# Handle output specifications using TensorSpec's __str__
if self.output_count == 1 and self.output_spec:
kwargs_strs.append(f"out={self.output_spec}")
elif self.output_count > 1 and self.output_specs:
for i, spec in enumerate(self.output_specs):
kwargs_strs.append(f"out_{i}={spec}")
base_str += f", kwargs={{{'; '.join(kwargs_strs)}}}"
base_str += ")"
return base_str
class TestConfig:
"""Test configuration"""
......@@ -245,20 +92,20 @@ class TestRunner:
)
print(f"\033[92m✓\033[0m Passed")
elif test_result.return_code == -1:
fail_msg = f"{test_case} - {InfiniDeviceNames[device]} - Test terminated in verbose mode."
# Test failed - use the actual error message from test_result
fail_msg = f"{test_case} - {InfiniDeviceNames[device]} - {test_result.error_message}"
self.failed_tests.append(fail_msg)
print(f"\033[91m✗\033[0m {test_result.error_message}")
elif test_result.return_code == -2: # Skipped
skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented"
# Both operators not implemented - use actual error message
skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - {test_result.error_message}"
self.skipped_tests.append(skip_msg)
print(
f"\033[93m⚠\033[0m Both operators not implemented - test skipped"
)
print(f"\033[93m⚠\033[0m {test_result.error_message}")
elif test_result.return_code == -3: # Partial
partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented"
# One operator not implemented - use actual error message
partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - {test_result.error_message}"
self.partial_tests.append(partial_msg)
print(
f"\033[93m⚠\033[0m One operator not implemented - running single operator without comparison"
)
print(f"\033[93m⚠\033[0m {test_result.error_message}")
if self.config.verbose and test_result.return_code != 0:
return False
......@@ -569,43 +416,57 @@ class BaseOperatorTest(ABC):
# Check operator implementations
torch_implemented = True
infini_implemented = True
torch_error_msg = ""
infini_error_msg = ""
try:
torch_result = self.torch_operator(*inputs, **kwargs)
if torch_result is None:
torch_implemented = False
except NotImplementedError:
except NotImplementedError as e:
if config.verbose:
traceback.print_exc()
# Return test result immediately in verbose mode
test_result.return_code = -1
test_result.error_message = "torch_operator not implemented"
return test_result
torch_implemented = False
torch_result = None
torch_error_msg = str(e)
try:
infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
if infini_result is None:
infini_implemented = False
except NotImplementedError:
except NotImplementedError as e:
if config.verbose:
traceback.print_exc()
# Return test result immediately in verbose mode
test_result.return_code = -1
test_result.error_message = "infinicore_operator not implemented"
return test_result
infini_implemented = False
infini_result = None
infini_error_msg = str(e)
if not torch_error_msg:
torch_error_msg = "unimplemented test function"
if not infini_error_msg:
infini_error_msg = "unimplemented test function"
# Skip if neither operator is implemented
if not torch_implemented and not infini_implemented:
test_result.return_code = -2 # Skipped
# Combine both error messages
test_result.error_message = f"Both operators failed: PyTorch - {torch_error_msg}; InfiniCore - {infini_error_msg}"
return test_result
# Single operator execution without comparison
if not torch_implemented or not infini_implemented:
test_result.return_code = -3 # Partial
# Determine which operator is missing and create appropriate message with actual error
if not torch_implemented:
test_result.error_message = (
f"PyTorch operator failed: {torch_error_msg}"
)
else:
test_result.error_message = (
f"InfiniCore operator failed: {infini_error_msg}"
)
# Run benchmarking for partial tests if enabled
if config.bench:
torch_host, torch_device, infini_host, infini_device = (
......
"""
Test case definitions and related functionality for the InfiniCore testing framework.
"""
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
from .tensor import TensorSpec
@dataclass
class TestResult:
"""Test result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
class TestCase:
"""Test case with all configuration included"""
def __init__(
self,
inputs,
kwargs=None,
output_spec=None,
output_specs=None,
comparison_target=None,
description="",
tolerance=None,
output_count=1,
):
"""
Initialize a test case with complete configuration
Args:
inputs: List of TensorSpec objects, scalars, or tuples containing multiple TensorSpecs
kwargs: Additional keyword arguments for the operator
output_spec: TensorSpec for output tensor (for single output operations)
output_specs: List of TensorSpec for multiple output tensors
comparison_target: Target for comparison ('out', index, or None for return value)
description: Test case description
tolerance: Tolerance settings for this test case {'atol': float, 'rtol': float}
output_count: Number of outputs (default: 1)
"""
self.inputs = []
# Process inputs - support both single TensorSpecs and tuples of TensorSpecs
for i, inp in enumerate(inputs):
if isinstance(inp, (list, tuple)):
# Handle tuple/list of multiple TensorSpecs (e.g., for torch.cat)
processed_tuple = []
for j, item in enumerate(inp):
if isinstance(item, (list, tuple)):
# Nested tuple - recursively process
nested_processed = []
for k, nested_item in enumerate(item):
if isinstance(nested_item, TensorSpec):
nested_item.fill_name(f"in_{i}_{j}_{k}")
nested_processed.append(nested_item)
else:
nested_processed.append(nested_item)
processed_tuple.append(tuple(nested_processed))
elif isinstance(item, TensorSpec):
item.fill_name(f"in_{i}_{j}")
processed_tuple.append(item)
else:
processed_tuple.append(item)
self.inputs.append(tuple(processed_tuple))
elif isinstance(inp, TensorSpec):
inp.fill_name(f"in_{i}")
self.inputs.append(inp)
else:
self.inputs.append(inp)
self.kwargs = kwargs or {}
self.output_spec = output_spec
self.output_specs = output_specs
self.comparison_target = comparison_target
self.description = description
self.tolerance = tolerance or {"atol": 1e-5, "rtol": 1e-3}
self.output_count = output_count
if self.output_count > 1 and self.output_specs is not None:
for idx, spec in enumerate(self.output_specs):
spec.fill_name(f"out_{idx}")
# Validate output configuration
if self.output_count == 1:
if self.output_specs is not None:
raise ValueError("output_specs cannot be used when output_count=1")
else:
if self.output_spec is not None:
raise ValueError("output_spec cannot be used when output_count>1")
if (
self.output_specs is not None
and len(self.output_specs) != self.output_count
):
raise ValueError(
f"output_specs count ({len(self.output_specs)}) must match output_count ({self.output_count})"
)
def get_tensor_input_count(self):
"""Count the number of tensor inputs (excluding scalars)"""
count = 0
for inp in self.inputs:
if isinstance(inp, TensorSpec) and not inp.is_scalar:
count += 1
elif isinstance(inp, (list, tuple)):
# Count all TensorSpecs within the tuple
for item in inp:
if isinstance(item, TensorSpec) and not item.is_scalar:
count += 1
return count
def __str__(self):
input_strs = []
for inp in self.inputs:
if isinstance(inp, (list, tuple)):
# Handle tuple inputs (e.g., for torch.cat)
tuple_strs = []
for item in inp:
if isinstance(item, (list, tuple)):
# Handle nested tuples
nested_strs = []
for nested_item in item:
nested_strs.append(str(nested_item))
tuple_strs.append(f"tuple({', '.join(nested_strs)})")
else:
tuple_strs.append(str(item))
input_strs.append(f"tuple({'; '.join(tuple_strs)})")
else:
input_strs.append(str(inp))
base_str = f"TestCase("
if self.description:
base_str += f"{self.description}"
base_str += f" - inputs=[{'; '.join(input_strs)}]"
if self.kwargs or self.output_spec or self.output_specs:
kwargs_strs = []
for key, value in self.kwargs.items():
if key == "out" and isinstance(value, int):
kwargs_strs.append(f"{key}={self.inputs[value].name}")
else:
kwargs_strs.append(f"{key}={value}")
# Handle output specifications using TensorSpec's __str__
if self.output_count == 1 and self.output_spec:
kwargs_strs.append(f"out={self.output_spec}")
elif self.output_count > 1 and self.output_specs:
for i, spec in enumerate(self.output_specs):
kwargs_strs.append(f"out_{i}={spec}")
base_str += f", kwargs={{{'; '.join(kwargs_strs)}}}"
base_str += ")"
return base_str
......@@ -116,14 +116,14 @@ class OpTest(BaseOperatorTest):
"""PyTorch multi_margin_loss implementation with device handling"""
return F.multi_margin_loss(*args, **kwargs)
# def infinicore_operator(self, *args, **kwargs):
# """InfiniCore multi_margin_loss implementation"""
# return None
def infinicore_operator(self, *args, **kwargs):
"""InfiniCore multi_margin_loss implementation"""
return None
def main():
"""Main entry point"""
runner = GenericTestRunner(MultiMarginLossOpTest)
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment