Unverified commit 12cde8eb authored by thatPepe, committed by GitHub

Merge pull request #788 from InfiniTensor/issue/787

issue/787 - Split run ops test logic and fix kwargs name in report
parents 62fe6999 7aece930
from .base import TestConfig, TestRunner, BaseOperatorTest
from .test_case import TestCase, TestResult
from .entities import TestCase
from .benchmark import BenchmarkUtils, BenchmarkResult
from .config import (
add_common_test_args,
@@ -9,35 +9,44 @@ from .config import (
)
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .results import TestTiming, OperatorResult, CaseResult, TestSummary
from .runner import GenericTestRunner
from .tensor import TensorSpec, TensorInitializer
from .utils import (
from .executor import TestExecutor
from .utils.compare_utils import (
compare_results,
create_test_comparator,
debug,
get_tolerance,
)
from .utils.json_utils import save_json_report
from .utils.tensor_utils import (
    infinicore_tensor_from_torch,
    rearrange_tensor,
    convert_infinicore_to_torch,
    is_broadcast,
    is_integer_dtype,
    is_complex_dtype,
    is_floating_dtype,
)
__all__ = [
# Core types and classes
"BaseOperatorTest",
"CaseResult",
"GenericTestRunner",
"InfiniDeviceEnum",
"InfiniDeviceNames",
"OperatorResult",
"TensorInitializer",
"TensorSpec",
"TestCase",
"TestConfig",
"TestResult",
"TestExecutor",
"TestSummary",
"TestRunner",
"TestReporter",
"TestTiming",
# Core functions
"add_common_test_args",
"compare_results",
@@ -50,6 +59,8 @@ __all__ = [
"get_tolerance",
"infinicore_tensor_from_torch",
"rearrange_tensor",
    # JSON utilities
"save_json_report",
# Utility functions
"to_infinicore_dtype",
"to_torch_dtype",
......
@@ -8,15 +8,15 @@ import infinicore
import traceback
from abc import ABC, abstractmethod
from .test_case import TestCase, TestResult
from .results import CaseResult
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer
from .utils import (
from .utils.tensor_utils import (
clone_torch_tensor,
create_test_comparator,
infinicore_tensor_from_torch,
)
from .utils.compare_utils import create_test_comparator
from .benchmark import BenchmarkUtils
@@ -84,7 +84,7 @@ class TestRunner:
try:
print(f"{test_case}")
# Execute test and get TestResult object
# Execute test and get CaseResult object
test_result = test_func(device, test_case, self.config)
self.test_results.append(test_result)
@@ -118,8 +118,8 @@ class TestRunner:
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
# Create a failed TestResult
failed_result = TestResult(
# Create a failed CaseResult
failed_result = CaseResult(
success=False,
return_code=-1,
error_message=str(e),
@@ -400,12 +400,12 @@ class BaseOperatorTest(ABC):
config: Test configuration
Returns:
TestResult: Test result object containing status and timing information
CaseResult: Test case result object containing status and timing information
"""
device_str = torch_device_map[device]
# Initialize test result
test_result = TestResult(
# Initialize test case result
test_result = CaseResult(
success=False,
return_code=-1, # Default to failure
test_case=test_case,
......
@@ -5,7 +5,7 @@ Benchmarking utilities for the InfiniCore testing framework
import time
import torch
import infinicore
from .utils import synchronize_device
from .utils.tensor_utils import synchronize_device
class BenchmarkUtils:
......
@@ -7,21 +7,6 @@ from typing import List, Dict, Any, Optional, Tuple
from .tensor import TensorSpec
@dataclass
class TestResult:
"""Test result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
class TestCase:
"""Test case with all configuration included"""
......
import sys
import importlib.util
from io import StringIO
from contextlib import contextmanager
from .results import OperatorResult, TestSummary
@contextmanager
def capture_output():
"""Context manager: captures stdout and stderr."""
new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_out, new_err
yield new_out, new_err
finally:
sys.stdout, sys.stderr = old_out, old_err
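A minimal usage sketch of this context manager: output written inside the block lands in the StringIO buffers, and the real streams are restored even if the body raises.

with capture_output() as (out, err):
    print("hello")
assert out.getvalue() == "hello\n"
assert err.getvalue() == ""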
class TestExecutor:
def execute(self, file_path) -> OperatorResult:
result = OperatorResult(name=file_path.stem)
try:
# 1. Dynamically import the module
module = self._import_module(file_path)
# 2. Look for TestRunner
if not hasattr(module, "GenericTestRunner"):
raise ImportError("No GenericTestRunner found in module")
# 3. Look for TestClass (subclass of BaseOperatorTest)
test_class = self._find_test_class(module)
if not test_class:
raise ImportError("No BaseOperatorTest subclass found")
test_instance = test_class()
runner_class = module.GenericTestRunner
runner = runner_class(test_instance.__class__)
# 4. Execute and capture output
with capture_output() as (out, err):
success, internal_runner = runner.run()
# 5. Populate results
result.success = success
result.stdout = out.getvalue()
result.stderr = err.getvalue()
# Extract detailed results from internal_runner
test_results = internal_runner.get_test_results() if internal_runner else []
test_summary = TestSummary()
test_summary.process_operator_result(result, test_results)
except Exception as e:
result.success = False
result.error_message = str(e)
result.stderr += f"\nExecutor Error: {str(e)}"
result.return_code = -1
return result
def _import_module(self, path):
module_name = f"op_test_{path.stem}"
spec = importlib.util.spec_from_file_location(module_name, path)
if not spec or not spec.loader:
raise ImportError(f"Could not load spec from {path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def _find_test_class(self, module):
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, type) and hasattr(attr, "__bases__"):
# Simple check for base class name
if any("BaseOperatorTest" in str(b) for b in attr.__bases__):
return attr
return None
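A hedged sketch of how the executor might be driven from a discovery loop; the tests/ops directory is a hypothetical example, not a path defined by this PR:

from pathlib import Path

executor = TestExecutor()
for file_path in sorted(Path("tests/ops").glob("*.py")):
    op_result = executor.execute(file_path)       # one OperatorResult per test file
    print(op_result.name, op_result.success)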
import json
import os
from datetime import datetime
from typing import List, Dict, Any, Union
from dataclasses import is_dataclass
from .base import TensorSpec
from .devices import InfiniDeviceEnum
class TestReporter:
"""
Handles report generation and file saving for test results.
"""
@staticmethod
def prepare_report_entry(
op_name: str,
test_cases: List[Any],
args: Any,
op_paths: Dict[str, str],
results_list: List[Any]
) -> List[Dict[str, Any]]:
"""
Combines static test case info with dynamic execution results.
"""
# 1. Normalize results
results_map = {}
if isinstance(results_list, list):
results_map = {i: res for i, res in enumerate(results_list)}
elif isinstance(results_list, dict):
results_map = results_list
else:
results_map = {0: results_list} if results_list else {}
# 2. Global Args
global_args = {
k: getattr(args, k)
for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
if hasattr(args, k)
}
grouped_entries: Dict[int, Dict[str, Any]] = {}
# 3. Iterate Test Cases
for idx, tc in enumerate(test_cases):
res = results_map.get(idx)
dev_id = getattr(res, "device", 0) if res else 0
# --- A. Initialize Group ---
if dev_id not in grouped_entries:
device_id_map = {v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")}
dev_str = device_id_map.get(dev_id, str(dev_id))
grouped_entries[dev_id] = {
"operator": op_name,
"device": dev_str,
"torch_op": op_paths.get("torch") or "unknown",
"infinicore_op": op_paths.get("infinicore") or "unknown",
"args": global_args,
"testcases": []
}
# --- B. Build Kwargs ---
display_kwargs = {}
# B1. Process existing kwargs
for k, v in tc.kwargs.items():
# Handle Inplace: "out": index -> "out": "input_name"
if k == "out" and isinstance(v, int):
if 0 <= v < len(tc.inputs):
display_kwargs[k] = tc.inputs[v].name
else:
display_kwargs[k] = f"Invalid_Index_{v}"
else:
display_kwargs[k] = (TestReporter._spec_to_dict(v) if isinstance(v, TensorSpec) else v)
# B2. Inject Outputs into Kwargs
if hasattr(tc, "output_specs") and tc.output_specs:
for i, spec in enumerate(tc.output_specs):
display_kwargs[f"out_{i}"] = TestReporter._spec_to_dict(spec)
elif tc.output_spec:
if "out" not in display_kwargs:
display_kwargs["out"] = TestReporter._spec_to_dict(tc.output_spec)
# --- C. Build Test Case Dictionary ---
case_data = {
"description": tc.description,
"inputs": [TestReporter._spec_to_dict(i) for i in tc.inputs],
"kwargs": display_kwargs,
"comparison_target": tc.comparison_target,
"tolerance": tc.tolerance,
}
# --- D. Inject Result ---
if res:
case_data["result"] = TestReporter._fmt_result(res)
else:
case_data["result"] = {"status": {"success": False, "error": "No result"}}
grouped_entries[dev_id]["testcases"].append(case_data)
return list(grouped_entries.values())
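For orientation, a single returned group is shaped roughly like the following (illustrative values only, not output from a real run):

example_entry = {
    "operator": "add",
    "device": "CPU",
    "torch_op": "torch.add",
    "infinicore_op": "infinicore.add",
    "args": {"bench": "both", "num_iterations": 100},
    "testcases": [
        {"description": "...", "inputs": [], "kwargs": {}, "result": {}},
    ],
}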
@staticmethod
def save_all_results(save_path: str, total_results: List[Dict[str, Any]]):
"""
Saves the report list to a JSON file with specific custom formatting
"""
directory, filename = os.path.split(save_path)
name, ext = os.path.splitext(filename)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
# Define indentation levels for cleaner code
indent_4 = ' ' * 4
indent_8 = ' ' * 8
indent_12 = ' ' * 12
indent_16 = ' ' * 16
indent_20 = ' ' * 20
print(f"💾 Saving to: {final_path}")
try:
with open(final_path, "w", encoding="utf-8") as f:
f.write("[\n")
for i, entry in enumerate(total_results):
f.write(f"{indent_4}{{\n")
keys = list(entry.keys())
for j, key in enumerate(keys):
val = entry[key]
comma = "," if j < len(keys) - 1 else ""
# -------------------------------------------------
# Special Handling for 'testcases' list formatting
# -------------------------------------------------
if key == "testcases" and isinstance(val, list):
f.write(f'{indent_8}"{key}": [\n')
for c_idx, case_item in enumerate(val):
f.write(f"{indent_12}{{\n")
case_keys = list(case_item.keys())
for k_idx, c_key in enumerate(case_keys):
c_val = case_item[c_key]
# [Logic A] Skip fields we merged manually after 'kwargs'
if c_key in ["comparison_target", "tolerance"]:
continue
# Check comma for standard logic (might be overridden below)
c_comma = "," if k_idx < len(case_keys) - 1 else ""
# [Logic B] Handle 'kwargs' + Grouped Fields
if c_key == "kwargs":
# 1. Use Helper for kwargs (Fill/Flow logic)
TestReporter._write_smart_field(
f, c_key, c_val, indent_16, indent_20, close_comma=","
)
# 2. Write subsequent comparison_target and tolerance (on a new line)
cmp_v = json.dumps(case_item.get("comparison_target"), ensure_ascii=False)
tol_v = json.dumps(case_item.get("tolerance"), ensure_ascii=False)
remaining_keys = [k for k in case_keys[k_idx+1:] if k not in ("comparison_target", "tolerance")]
line_comma = "," if remaining_keys else ""
f.write(f'{indent_16}"comparison_target": {cmp_v}, "tolerance": {tol_v}{line_comma}\n')
continue
# [Logic C] Handle 'inputs' (Smart Wrap)
if c_key == "inputs" and isinstance(c_val, list):
TestReporter._write_smart_field(
f, c_key, c_val, indent_16, indent_20, close_comma=c_comma
)
continue
# [Logic D] Standard fields (description, result, output_spec, etc.)
else:
c_val_str = json.dumps(c_val, ensure_ascii=False)
f.write(f'{indent_16}"{c_key}": {c_val_str}{c_comma}\n')
close_comma = "," if c_idx < len(val) - 1 else ""
f.write(f"{indent_12}}}{close_comma}\n")
f.write(f"{indent_8}]{comma}\n")
# -------------------------------------------------
# Standard top-level fields (operator, args, etc.)
# -------------------------------------------------
else:
k_str = json.dumps(key, ensure_ascii=False)
v_str = json.dumps(val, ensure_ascii=False)
f.write(f"{indent_8}{k_str}: {v_str}{comma}\n")
if i < len(total_results) - 1:
f.write(f"{indent_4}}},\n")
else:
f.write(f"{indent_4}}}\n")
f.write("]\n")
print(f" ✅ Saved (Structure Matched).")
except Exception as e:
import traceback; traceback.print_exc()
print(f" ❌ Save failed: {e}")
# --- Internal Helpers ---
@staticmethod
def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
"""
Helper to write a JSON field (List or Dict) with smart wrapping.
- If compact length <= 180: Write on one line.
- If > 180: Use 'Fill/Flow' mode (multiple items per line, wrap when line is full).
"""
# 1. Try Compact Mode
compact_json = json.dumps(value, ensure_ascii=False)
if len(compact_json) <= 180:
f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
return
# 2. Fill/Flow Mode
is_dict = isinstance(value, dict)
open_char = '{' if is_dict else '['
close_char = '}' if is_dict else ']'
f.write(f'{indent}"{key}": {open_char}')
# Normalize items for iteration
if is_dict:
items = list(value.items())
else:
items = value # List
# Initialize current line length tracking
# Length includes indent + "key": [
current_len = len(indent) + len(f'"{key}": {open_char}')
for i, item in enumerate(items):
# Format individual item string
if is_dict:
k, v = item
val_str = json.dumps(v, ensure_ascii=False)
item_str = f'"{k}": {val_str}'
else:
item_str = json.dumps(item, ensure_ascii=False)
is_last = (i == len(items) - 1)
item_comma = "" if is_last else ", "
# Predict new length: current + item + comma
if current_len + len(item_str) + len(item_comma) > 180:
# Wrap to new line
f.write(f'\n{sub_indent}')
current_len = len(sub_indent)
f.write(f'{item_str}{item_comma}')
current_len += len(item_str) + len(item_comma)
f.write(f'{close_char}{close_comma}\n')
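To illustrate the wrap rule: a value whose compact JSON fits within 180 columns stays on one line; a longer one flows across lines, breaking only when the running width would exceed the limit. A hypothetical call against an in-memory buffer:

import io

buf = io.StringIO()
TestReporter._write_smart_field(buf, "inputs", list(range(200)), " " * 16, " " * 20)
print(buf.getvalue())  # items packed up to ~180 columns per line, wrapped at the sub-indent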
@staticmethod
def _spec_to_dict(s):
return {
"name": getattr(s, "name", "unknown"),
"shape": list(s.shape) if s.shape else None,
"dtype": str(s.dtype).split(".")[-1],
"strides": list(s.strides) if s.strides else None,
}
@staticmethod
def _fmt_result(res):
if not (is_dataclass(res) or hasattr(res, "success")):
return str(res)
get_time = lambda k: round(getattr(res, k, 0.0), 4)
return {
"status": {
"success": getattr(res, "success", False),
"error": getattr(res, "error_message", ""),
},
"perf_ms": {
"torch": {
"host": get_time("torch_host_time"),
"device": get_time("torch_device_time"),
},
"infinicore": {
"host": get_time("infini_host_time"),
"device": get_time("infini_device_time"),
},
},
}
from typing import List, Dict, Any
from dataclasses import dataclass, is_dataclass, field
from .devices import InfiniDeviceEnum
from .tensor import TensorSpec
from .utils.json_utils import save_json_report
@dataclass
class CaseResult:
"""Test case result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
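A minimal construction sketch: only success and return_code are required; the timing fields default to zero.

skipped = CaseResult(success=False, return_code=-2, error_message="strided inputs not supported")
passed = CaseResult(success=True, return_code=0, torch_device_time=0.12, infini_device_time=0.09)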
@dataclass
class TestTiming:
"""Stores performance timing metrics."""
torch_host: float = 0.0
torch_device: float = 0.0
infini_host: float = 0.0
infini_device: float = 0.0
    # Added field to support the logic in print_summary
operators_tested: int = 0
@dataclass
class OperatorResult:
"""Stores the execution results of a single operator."""
name: str
success: bool = False
return_code: int = -1
error_message: str = ""
stdout: str = ""
stderr: str = ""
timing: TestTiming = field(default_factory=TestTiming)
@property
def status_icon(self):
if self.return_code == 0:
return "✅"
if self.return_code == -2:
return "⏭️"
if self.return_code == -3:
return "⚠️"
return "❌"
@property
def status_text(self):
if self.return_code == 0:
return "PASSED"
if self.return_code == -2:
return "SKIPPED"
if self.return_code == -3:
return "PARTIAL"
return "FAILED"
class TestSummary:
"""
Test Summary class:
1. Aggregates results (Timing & Status calculation).
2. Handles Console Output (Live & Summary).
3. Handles File Reporting (Data Preparation).
"""
def __init__(self, verbose=False, bench_mode=None):
self.verbose = verbose
self.bench_mode = bench_mode
self.report_entries = [] # Cache for JSON report
# =========================================================
# Part 1: Result Aggregation
# =========================================================
def process_operator_result(self, op_result, sub_results: List):
"""
Updates the OperatorResult object in-place.
"""
if not sub_results:
op_result.return_code = -1
return
# 1. Analyze Return Code (Status)
if op_result.success:
op_result.return_code = 0
else:
has_failures = any(r.return_code == -1 for r in sub_results)
has_partial = any(r.return_code == -3 for r in sub_results)
has_skipped = any(r.return_code == -2 for r in sub_results)
if has_failures:
op_result.return_code = -1
elif has_partial:
op_result.return_code = -3
elif has_skipped:
op_result.return_code = -2
else:
op_result.return_code = -1
# 2. Extract Timing (Aggregation)
t = op_result.timing
t.torch_host = sum(r.torch_host_time for r in sub_results)
t.torch_device = sum(r.torch_device_time for r in sub_results)
t.infini_host = sum(r.infini_host_time for r in sub_results)
t.infini_device = sum(r.infini_device_time for r in sub_results)
t.operators_tested = len(sub_results)
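A small sketch of the aggregation under assumed inputs: two passing cases roll up into a PASSED operator with summed timings.

op = OperatorResult(name="add", success=True)
cases = [
    CaseResult(success=True, return_code=0, infini_device_time=0.4),
    CaseResult(success=True, return_code=0, infini_device_time=0.6),
]
TestSummary().process_operator_result(op, cases)
assert op.return_code == 0
assert op.timing.infini_device == 1.0 and op.timing.operators_tested == 2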
# =========================================================
# Part 2: Console Output (View)
# =========================================================
def list_tests(self, discoverer):
ops_dir = discoverer.ops_dir
operators = discoverer.get_available_operators()
if operators:
print(f"Available operator test files in {ops_dir}:")
for operator in operators:
print(f" - {operator}")
print(f"\nTotal: {len(operators)} operators")
else:
print(f"No valid operator tests found in {ops_dir}")
raw_files = discoverer.get_raw_python_files()
if raw_files:
print(
f"\n💡 Debug Hint: Found Python files but they are not valid tests:"
)
print(f" {raw_files}")
def print_header(self, ops_dir, count):
print(f"InfiniCore Operator Test Runner")
print(f"Directory: {ops_dir}")
print(f"Tests found: {count}\n")
def print_live_result(self, result):
print(
f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})"
)
if result.stdout:
print(result.stdout.rstrip())
if result.stderr:
print("\nSTDERR:", result.stderr.rstrip())
if result.error_message:
print(f"💥 Error: {result.error_message}")
if result.stdout or result.stderr or self.verbose:
print("-" * 40)
def print_summary(self, results, cumulative_timing, ops_dir, total_expected=0):
print(f"\n{'='*80}\nCUMULATIVE TEST SUMMARY\n{'='*80}")
passed = [r for r in results if r.return_code == 0]
failed = [r for r in results if r.return_code == -1]
skipped = [r for r in results if r.return_code == -2]
partial = [r for r in results if r.return_code == -3]
total = len(results)
print(f"Total tests run: {total}")
if total_expected > 0 and total < total_expected:
print(f"Total tests expected: {total_expected}")
print(f"Tests not executed: {total_expected - total}")
print(f"Passed: {len(passed)}")
print(f"Failed: {len(failed)}")
if skipped:
print(f"Skipped: {len(skipped)}")
if partial:
print(f"Partial: {len(partial)}")
# 1. Benchmark
if cumulative_timing:
self._print_timing(cumulative_timing)
# 2. Lists
if passed:
self._print_op_list("✅ PASSED OPERATORS", passed)
else:
print(f"\n✅ PASSED OPERATORS: None")
if failed:
self._print_op_list("❌ FAILED OPERATORS", failed)
if skipped:
self._print_op_list("⏭️ SKIPPED OPERATORS", skipped)
if partial:
self._print_op_list("⚠️ PARTIAL IMPLEMENTATIONS", partial)
# 3. Verdict
if total > 0:
executed_tests = total - len(skipped)
if executed_tests > 0:
success_rate = len(passed) / executed_tests * 100
print(f"\nSuccess rate: {success_rate:.1f}%")
if not failed:
if skipped or partial:
print(f"\n⚠️ Tests completed with some operators not fully implemented")
else:
print(f"\n🎉 All tests passed!")
else:
print(f"\n{len(failed)} tests failed")
if not failed and (skipped or partial):
print(f"\n⚠️ Note: Some operators are not fully implemented")
print(f" Run individual tests for details on missing implementations")
if self.verbose and failed:
print(
f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:"
)
for r in failed[:3]:
file_path = ops_dir / (r.name + ".py")
print(f" python {file_path} --verbose")
if len(failed) > 3:
print(f" ... (and {len(failed) - 3} others)")
return len(failed) == 0
def _print_timing(self, t):
print(f"{'-'*40}")
if hasattr(t, "operators_tested") and t.operators_tested > 0:
print(f"BENCHMARK SUMMARY ({t.operators_tested} cases):")
if self.bench_mode in ["host", "both"]:
print(f" [Host] PyTorch: {t.torch_host:10.3f} ms")
print(f" [Host] InfiniCore: {t.infini_host:10.3f} ms")
if self.bench_mode in ["device", "both"]:
print(f" [Device] PyTorch: {t.torch_device:10.3f} ms")
print(f" [Device] InfiniCore: {t.infini_device:10.3f} ms")
print(f"{'-'*40}")
def _print_op_list(self, title, result_list):
print(f"\n{title} ({len(result_list)}):")
names = [r.name for r in result_list]
for i in range(0, len(names), 10):
print(" " + ", ".join(names[i : i + 10]))
# =========================================================
# Part 3: Report Generation
# =========================================================
def collect_report_entry(self, op_name, test_cases, args, op_paths, results_list):
"""
Prepares the data and adds it to the internal list.
"""
entry = self._prepare_entry_logic(
op_name, test_cases, args, op_paths, results_list
)
self.report_entries.extend(entry)
def save_report(self, save_path):
"""
Delegates the actual writing to save_json_report.
"""
if not self.report_entries:
return
# Call the external utility
save_json_report(save_path, self.report_entries)
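Putting the three parts together, a hedged end-to-end sketch; the operator name, paths, and argument objects below are placeholders for whatever the runner supplies:

summary = TestSummary(verbose=True, bench_mode="both")
summary.collect_report_entry(
    op_name="add",
    test_cases=test_cases,        # assumed list of TestCase objects
    args=parsed_args,             # assumed argparse.Namespace from get_args()
    op_paths={"torch": "torch.add", "infinicore": "infinicore.add"},
    results_list=case_results,    # assumed list of CaseResult objects
)
summary.save_report("reports/add.json")  # hypothetical path; a timestamp is appended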
def _prepare_entry_logic(self, op_name, test_cases, args, op_paths, results_list):
"""
Combines static test case info with dynamic execution results.
Refactored to reduce duplication.
"""
# 1. Normalize results
results_map = (
results_list
if isinstance(results_list, dict)
else {i: res for i, res in enumerate(results_list or [])}
)
# 2. Global Args
global_args = {
k: getattr(args, k)
for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
if hasattr(args, k)
}
grouped_entries = {}
# Cache device enum map
device_id_map = {
v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")
}
for idx, tc in enumerate(test_cases):
res = results_map.get(idx)
dev_id = getattr(res, "device", 0) if res else 0
# --- A. Initialize Group ---
if dev_id not in grouped_entries:
grouped_entries[dev_id] = {
"operator": op_name,
"device": device_id_map.get(dev_id, str(dev_id)),
"torch_op": op_paths.get("torch", "unknown"),
"infinicore_op": op_paths.get("infinicore", "unknown"),
"args": global_args,
"testcases": [],
}
# --- B. Helpers for Spec Processing ---
def process_spec(spec, default_name):
final_name = self._resolve_name(spec, default_name)
# Call internal method (no need for external converters file)
return self._spec_to_dict(spec, name=final_name)
# --- C. Build Inputs ---
processed_inputs = [
process_spec(inp, f"in_{i}") for i, inp in enumerate(tc.inputs)
]
# --- D. Build Kwargs ---
display_kwargs = {}
for k, v in tc.kwargs.items():
if k == "out" and isinstance(v, int):
# Handle Inplace Index
if 0 <= v < len(tc.inputs):
display_kwargs[k] = self._resolve_name(tc.inputs[v], f"in_{v}")
else:
display_kwargs[k] = f"Invalid_Index_{v}"
elif isinstance(v, TensorSpec):
display_kwargs[k] = process_spec(v, v.name)
else:
display_kwargs[k] = v
# --- E. Inject Outputs ---
if getattr(tc, "output_specs", None):
for i, spec in enumerate(tc.output_specs):
display_kwargs[f"out_{i}"] = process_spec(spec, f"out_{i}")
elif tc.output_spec and "out" not in display_kwargs:
display_kwargs["out"] = process_spec(tc.output_spec, "out")
# --- F. Assemble Case Data ---
case_data = {
"description": tc.description,
"inputs": processed_inputs,
"kwargs": display_kwargs,
"comparison_target": tc.comparison_target,
"tolerance": tc.tolerance,
"result": (
self._fmt_result(res)
if res
else {"status": {"success": False, "error": "No result"}}
),
}
grouped_entries[dev_id]["testcases"].append(case_data)
return list(grouped_entries.values())
# --- Internal Helpers ---
def _resolve_name(self, obj, default_name):
return getattr(obj, "name", None) or default_name
def _spec_to_dict(self, s, name=None):
return {
"name": name if name else getattr(s, "name", "unknown"),
"shape": list(s.shape) if s.shape else None,
"dtype": str(s.dtype).split(".")[-1],
"strides": list(s.strides) if s.strides else None,
}
def _fmt_result(self, res):
if not (is_dataclass(res) or hasattr(res, "success")):
return str(res)
get_time = lambda k: round(getattr(res, k, 0.0), 4)
return {
"status": {
"success": getattr(res, "success", False),
"error": getattr(res, "error_message", ""),
},
"perf_ms": {
"torch": {
"host": get_time("torch_host_time"),
"device": get_time("torch_device_time"),
},
"infinicore": {
"host": get_time("infini_host_time"),
"device": get_time("infini_device_time"),
},
},
}
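For a passing case the formatted entry looks roughly like this (times illustrative):

example_formatted = {
    "status": {"success": True, "error": ""},
    "perf_ms": {
        "torch": {"host": 0.1234, "device": 0.0567},
        "infinicore": {"host": 0.0981, "device": 0.0432},
    },
}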
@@ -7,7 +7,7 @@ import os
import inspect
import re
from . import TestConfig, TestRunner, get_args, get_test_devices
from .reporter import TestReporter
from .results import TestSummary
class GenericTestRunner:
@@ -89,7 +89,8 @@ class GenericTestRunner:
op_paths = {"torch": t_path, "infinicore": i_path}
# 2. Generate Report Entries
entries = TestReporter.prepare_report_entry(
test_summary = TestSummary()
entries = test_summary.collect_report_entry(
op_name=self.operator_test.operator_name,
test_cases=self.operator_test.test_cases,
args=self.args,
@@ -98,7 +99,7 @@
)
# 4. Save to File
TestReporter.save_all_results(self.args.save, entries)
test_summary.save_report(self.args.save)
except Exception as e:
import traceback
......
@@ -3,7 +3,7 @@ import math
from pathlib import Path
from .datatypes import to_torch_dtype
from .devices import torch_device_map
from .utils import is_integer_dtype, is_complex_dtype
from .utils.tensor_utils import is_integer_dtype, is_complex_dtype
class TensorInitializer:
......
import torch
import time
import infinicore
import numpy as np
from .datatypes import to_infinicore_dtype, to_torch_dtype
def synchronize_device(torch_device):
"""Device synchronization"""
if torch_device == "cuda":
torch.cuda.synchronize()
elif torch_device == "npu":
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
elif torch_device == "musa":
torch.musa.synchronize()
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
Debug function to compare two tensors and print differences
"""
# Handle complex types by converting to real representation for comparison
if actual.is_complex() or desired.is_complex():
actual = torch.view_as_real(actual)
desired = torch.view_as_real(desired)
elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32)
desired = desired.to(torch.float32)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
import numpy as np
np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)
def print_discrepancy(
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
):
"""Print detailed tensor differences"""
if actual.shape != expected.shape:
raise ValueError("Tensors must have the same shape to compare.")
import torch
import sys
is_terminal = sys.stdout.isatty()
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)
# Calculate difference mask
nan_mismatch = (
actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected
# Display formatting
col_width = [18, 20, 20, 20]
decimal_places = [0, 12, 12, 12]
total_width = sum(col_width) + sum(decimal_places)
def add_color(text, color_code):
if is_terminal:
return f"\033[{color_code}m{text}\033[0m"
else:
return text
if verbose:
for idx in diff_indices:
index_tuple = tuple(idx.tolist())
actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
expected_str = (
f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
)
delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
print(
f" > Index: {str(index_tuple):<{col_width[0]}}"
f"actual: {add_color(actual_str, 31)}"
f"expect: {add_color(expected_str, 32)}"
f"delta: {add_color(delta_str, 33)}"
)
print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(f" - Equal NaN: {equal_nan}")
print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
)
print(
f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
)
print(
f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
)
print(
f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
)
print("-" * total_width)
return diff_indices
def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
"""
Get tolerance settings based on data type
"""
tolerance = tolerance_map.get(
tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
)
return tolerance["atol"], tolerance["rtol"]
def clone_torch_tensor(torch_tensor):
cloned = torch_tensor.clone().detach()
if not torch_tensor.is_contiguous():
cloned = rearrange_tensor(cloned, torch_tensor.stride())
return cloned
def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
return infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
else:
return infinicore.strided_from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
list(torch_tensor.stride()),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
def convert_infinicore_to_torch(infini_result):
"""
Convert infinicore tensor to PyTorch tensor for comparison
Args:
infini_result: infinicore tensor result
Returns:
torch.Tensor: PyTorch tensor with infinicore data
"""
torch_result_from_infini = torch.zeros(
infini_result.shape,
dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type,
)
if not infini_result.is_contiguous():
torch_result_from_infini = rearrange_tensor(
torch_result_from_infini, infini_result.stride()
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
return torch_result_from_infini
import sys
from ..datatypes import to_torch_dtype
from .tensor_utils import (
convert_infinicore_to_torch,
is_integer_dtype,
is_complex_dtype,
)
def compare_results(
@@ -351,89 +189,104 @@ def create_test_comparator(config, atol, rtol, mode_name="", equal_nan=False):
return compare_test_results
def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
    """
    Get tolerance settings based on data type
    """
    tolerance = tolerance_map.get(
        tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
    )
    return tolerance["atol"], tolerance["rtol"]


def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
    """
    Debug function to compare two tensors and print differences
    """
    # Handle complex types by converting to real representation for comparison
    if actual.is_complex() or desired.is_complex():
        actual = torch.view_as_real(actual)
        desired = torch.view_as_real(desired)
    elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
        actual = actual.to(torch.float32)
        desired = desired.to(torch.float32)
    print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
    import numpy as np

    np.testing.assert_allclose(
        actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
    )


def print_discrepancy(
    actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
):
    """Print detailed tensor differences"""
    if actual.shape != expected.shape:
        raise ValueError("Tensors must have the same shape to compare.")
    import torch
    import sys

    is_terminal = sys.stdout.isatty()
    actual_isnan = torch.isnan(actual)
    expected_isnan = torch.isnan(expected)
    # Calculate difference mask
    nan_mismatch = (
        actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
    )
    diff_mask = nan_mismatch | (
        torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
    )
    diff_indices = torch.nonzero(diff_mask, as_tuple=False)
    delta = actual - expected
    # Display formatting
    col_width = [18, 20, 20, 20]
    decimal_places = [0, 12, 12, 12]
    total_width = sum(col_width) + sum(decimal_places)

    def add_color(text, color_code):
        if is_terminal:
            return f"\033[{color_code}m{text}\033[0m"
        else:
            return text

    if verbose:
        for idx in diff_indices:
            index_tuple = tuple(idx.tolist())
            actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
            expected_str = (
                f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
            )
            delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
            print(
                f" > Index: {str(index_tuple):<{col_width[0]}}"
                f"actual: {add_color(actual_str, 31)}"
                f"expect: {add_color(expected_str, 32)}"
                f"delta: {add_color(delta_str, 33)}"
            )
    print(f" - Actual dtype: {actual.dtype}")
    print(f" - Desired dtype: {expected.dtype}")
    print(f" - Atol: {atol}")
    print(f" - Rtol: {rtol}")
    print(f" - Equal NaN: {equal_nan}")
    print(
        f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
    )
    print(
        f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
    )
    print(
        f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
    )
    print(
        f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
    )
    print("-" * total_width)
    return diff_indices
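A usage sketch for the helpers above; the tolerance map is a hypothetical example, and torch is imported locally because this module does not import it at the top:

import torch

tolerance_map = {
    torch.float16: {"atol": 1e-3, "rtol": 1e-2},
    torch.float32: {"atol": 1e-5, "rtol": 1e-4},
}
atol, rtol = get_tolerance(tolerance_map, torch.float32)
debug(torch.ones(4), torch.ones(4) * (1 + 5e-5), atol=atol, rtol=rtol)  # within tolerance, so no assertion error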
import json
import os
from datetime import datetime
def save_json_report(save_path, total_results):
"""
Saves the report list to a JSON file with specific custom formatting
(Compact for short lines, Expanded for long lines).
"""
directory, filename = os.path.split(save_path)
name, ext = os.path.splitext(filename)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
# Define Indentation
I4, I8, I12, I16, I20 = " " * 4, " " * 8, " " * 12, " " * 16, " " * 20
print(f"💾 Saving to: {final_path}")
# Helper for JSON stringify to avoid repetition
def _to_json(obj):
return json.dumps(obj, ensure_ascii=False)
try:
with open(final_path, "w", encoding="utf-8") as f:
f.write("[\n")
for i, entry in enumerate(total_results):
f.write(f"{I4}{{\n")
keys = list(entry.keys())
for j, key in enumerate(keys):
val = entry[key]
comma = "," if j < len(keys) - 1 else ""
# Special handling for 'testcases' list
if key == "testcases" and isinstance(val, list):
f.write(f'{I8}"{key}": [\n')
for c_idx, case_item in enumerate(val):
f.write(f"{I12}{{\n")
case_keys = list(case_item.keys())
# Filter out keys that we handle specially at the end
standard_keys = [
k
for k in case_keys
if k not in ["comparison_target", "tolerance"]
]
for k_idx, c_key in enumerate(standard_keys):
c_val = case_item[c_key]
# Determine comma logic
c_comma = (
","
if k_idx < len(standard_keys) - 1
or "comparison_target" in case_item
else ""
)
if c_key in ["kwargs", "inputs"]:
_write_field(
f, c_key, c_val, I16, I20, close_comma=c_comma
)
else:
f.write(
f'{I16}"{c_key}": {_to_json(c_val)}{c_comma}\n'
)
# Handle trailing comparison/tolerance fields uniformly
if "comparison_target" in case_item:
cmp = _to_json(case_item.get("comparison_target"))
tol = _to_json(case_item.get("tolerance"))
f.write(
f'{I16}"comparison_target": {cmp}, "tolerance": {tol}\n'
)
close_case = "," if c_idx < len(val) - 1 else ""
f.write(f"{I12}}}{close_case}\n")
f.write(f"{I8}]{comma}\n")
else:
# Standard top-level fields
f.write(f"{I8}{_to_json(key)}: {_to_json(val)}{comma}\n")
close_entry = "}," if i < len(total_results) - 1 else "}"
f.write(f"{I4}{close_entry}\n")
f.write("]\n")
print(f" ✅ Saved.")
except Exception as e:
import traceback
traceback.print_exc()
print(f" ❌ Save failed: {e}")
def _write_field(f, key, value, indent, sub_indent, close_comma=""):
"""
Internal Helper: Write a JSON field with wrapping.
"""
# 1. Try Compact Mode
compact_json = json.dumps(value, ensure_ascii=False)
if len(compact_json) <= 180:
f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
return
# 2. Fill/Flow Mode
is_dict = isinstance(value, dict)
open_char = "{" if is_dict else "["
close_char = "}" if is_dict else "]"
f.write(f'{indent}"{key}": {open_char}')
if is_dict:
items = list(value.items())
else:
items = value
current_len = len(indent) + len(f'"{key}": {open_char}')
for i, item in enumerate(items):
if is_dict:
k, v = item
val_str = json.dumps(v, ensure_ascii=False)
item_str = f'"{k}": {val_str}'
else:
item_str = json.dumps(item, ensure_ascii=False)
is_last = i == len(items) - 1
item_comma = "" if is_last else ", "
if current_len + len(item_str) + len(item_comma) > 180:
f.write(f"\n{sub_indent}")
current_len = len(sub_indent)
f.write(f"{item_str}{item_comma}")
current_len += len(item_str) + len(item_comma)
f.write(f"{close_char}{close_comma}\n")
import torch
import infinicore
from ..datatypes import to_infinicore_dtype, to_torch_dtype
# =================================================================
# Device & Synchronization
# =================================================================
def synchronize_device(torch_device):
"""Device synchronization"""
if torch_device == "cuda":
torch.cuda.synchronize()
elif torch_device == "npu":
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
elif torch_device == "musa":
torch.musa.synchronize()
# =================================================================
# Tensor Operations & Conversions
# =================================================================
def clone_torch_tensor(torch_tensor):
cloned = torch_tensor.clone().detach()
if not torch_tensor.is_contiguous():
cloned = rearrange_tensor(cloned, torch_tensor.stride())
return cloned
def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
return infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
else:
return infinicore.strided_from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
list(torch_tensor.stride()),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
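A zero-copy usage sketch: the infinicore tensor aliases the torch storage through data_ptr(), so the torch tensor must stay alive while the view is used (assumes infinicore supports the tensor's device):

t = torch.randn(2, 3)
ic = infinicore_tensor_from_torch(t)              # contiguous: from_blob path
ic_strided = infinicore_tensor_from_torch(t.t())  # non-contiguous: strided_from_blob path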
def convert_infinicore_to_torch(infini_result):
"""
Convert infinicore tensor to PyTorch tensor for comparison
Args:
infini_result: infinicore tensor result
Returns:
torch.Tensor: PyTorch tensor with infinicore data
"""
torch_result_from_infini = torch.zeros(
infini_result.shape,
dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type,
)
if not infini_result.is_contiguous():
torch_result_from_infini = rearrange_tensor(
torch_result_from_infini, infini_result.stride()
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
return torch_result_from_infini
def rearrange_tensor(tensor, new_strides):
"""
Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides.
"""
import torch
shape = tensor.shape
new_size = [0] * len(shape)
left = 0
right = 0
for i in range(len(shape)):
if new_strides[i] >= 0:
new_size[i] = (shape[i] - 1) * new_strides[i] + 1
right += new_strides[i] * (shape[i] - 1)
else: # TODO: Support negative strides in the future
# new_size[i] = (shape[i] - 1) * (-new_strides[i]) + 1
# left += new_strides[i] * (shape[i] - 1)
raise ValueError("Negative strides are not supported yet")
# Create a new tensor with zeros
new_tensor = torch.zeros(
(right - left + 1,), dtype=tensor.dtype, device=tensor.device
)
# Generate indices for original tensor based on original strides
indices = [torch.arange(s) for s in shape]
mesh = torch.meshgrid(*indices, indexing="ij")
# Flatten indices for linear indexing
linear_indices = [m.flatten() for m in mesh]
# Calculate new positions based on new strides
new_positions = sum(
linear_indices[i] * new_strides[i] for i in range(len(shape))
).to(tensor.device)
offset = -left
new_positions += offset
# Copy the original data to the new tensor
new_tensor.reshape(-1).index_add_(0, new_positions, tensor.reshape(-1))
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
return new_tensor
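A sketch of its effect: data is copied into a fresh buffer laid out with the requested (non-negative) strides, and the returned view reports exactly those strides.

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # contiguous strides (3, 1)
r = rearrange_tensor(t, [4, 1])                         # pad each row out to a stride of 4
assert r.stride() == (4, 1) and torch.equal(r, t)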
def is_broadcast(strides):
"""
Check if strides indicate a broadcasted tensor
Args:
strides: Tensor strides or None
Returns:
bool: True if the tensor is broadcasted (has zero strides)
"""
if strides is None:
return False
return any(s == 0 for s in strides)
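A quick sketch of the predicate:

x = torch.randn(3, 1).expand(3, 4)  # expand produces a zero stride
assert is_broadcast(x.stride())     # stride (1, 0) -> True
assert not is_broadcast(None)       # missing stride info counts as not broadcast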
# =================================================================
# Type Checks (Moved here to avoid circular imports in check.py)
# =================================================================
def is_integer_dtype(dtype):
"""Check if dtype is integer type"""
return dtype in [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.bool,
]
def is_complex_dtype(dtype):
"""Check if dtype is complex type"""
return dtype in [torch.complex64, torch.complex128]
def is_floating_dtype(dtype):
"""Check if dtype is floating-point type"""
return dtype in [
torch.float16,
torch.float32,
torch.float64,
torch.bfloat16,
]
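A quick sanity sketch of the three dtype predicates:

assert is_integer_dtype(torch.int32) and is_integer_dtype(torch.bool)
assert is_complex_dtype(torch.complex64)
assert is_floating_dtype(torch.bfloat16) and not is_floating_dtype(torch.int8)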
@@ -7,6 +7,7 @@ import torch
import infinicore
from framework import (
BaseOperatorTest,
CaseResult,
TensorSpec,
TestCase,
GenericTestRunner,
@@ -76,7 +77,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.inputs[0], TensorSpec)
and test_case.inputs[0].strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......
@@ -6,7 +6,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
from framework.tensor import TensorInitializer
from framework.utils import (
from framework.utils.tensor_utils import (
convert_infinicore_to_torch,
infinicore_tensor_from_torch,
to_torch_dtype,
......
@@ -222,8 +222,8 @@ class OpTest(BaseOperatorTest):
# Re-run operations with the same logits to get results for comparison
# prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
from framework.base import TestResult
from framework.utils import (
from framework.results import CaseResult
from framework.utils.tensor_utils import (
convert_infinicore_to_torch,
infinicore_tensor_from_torch,
)
@@ -268,8 +268,8 @@ class OpTest(BaseOperatorTest):
# Check if indices are equal (standard case)
if ic_idx == ref_idx:
# Return a successful TestResult object
return TestResult(
# Return a successful CaseResult object
return CaseResult(
success=True,
return_code=0,
test_case=test_case,
@@ -283,8 +283,8 @@ class OpTest(BaseOperatorTest):
logits_ic = logits_tensor[ic_idx].item()
if logits_ic == logits_ref:
# Valid: different indices but same logits value
# Return a successful TestResult object
return TestResult(
# Return a successful CaseResult object
return CaseResult(
success=True,
return_code=0,
test_case=test_case,
......
@@ -7,6 +7,7 @@ import torch
import infinicore
from framework import (
BaseOperatorTest,
CaseResult,
TensorSpec,
TestCase,
GenericTestRunner,
@@ -180,7 +181,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.inputs[0], TensorSpec)
and test_case.inputs[0].strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
@@ -193,7 +194,7 @@ class OpTest(BaseOperatorTest):
)
for spec in output_specs:
if isinstance(spec, TensorSpec) and spec.strides is not None:
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......
@@ -7,6 +7,7 @@ import torch
import infinicore
from framework import (
BaseOperatorTest,
CaseResult,
TensorSpec,
TestCase,
GenericTestRunner,
@@ -122,7 +123,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.inputs[0], TensorSpec)
and test_case.inputs[0].strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
@@ -135,7 +136,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.output_spec, TensorSpec)
and test_case.output_spec.strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......