Unverified Commit 12cde8eb authored by thatPepe, committed by GitHub

Merge pull request #788 from InfiniTensor/issue/787

issue/787 - Split run ops test logic and fix kwargs name in report
parents 62fe6999 7aece930
from .base import TestConfig, TestRunner, BaseOperatorTest
from .test_case import TestCase, TestResult
from .entities import TestCase
from .benchmark import BenchmarkUtils, BenchmarkResult
from .config import (
add_common_test_args,
......@@ -9,35 +9,44 @@ from .config import (
)
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .results import TestTiming, OperatorResult, CaseResult, TestSummary
from .runner import GenericTestRunner
from .tensor import TensorSpec, TensorInitializer
from .utils import (
from .executor import TestExecutor
from .utils.compare_utils import (
compare_results,
create_test_comparator,
debug,
get_tolerance,
)
from .utils.json_utils import save_json_report
from .utils.tensor_utils import (
infinicore_tensor_from_torch,
rearrange_tensor,
convert_infinicore_to_torch,
rearrange_tensor,
is_broadcast,
is_integer_dtype,
is_complex_dtype,
is_floating_dtype,
is_integer_dtype,
)
__all__ = [
# Core types and classes
"BaseOperatorTest",
"CaseResult",
"GenericTestRunner",
"InfiniDeviceEnum",
"InfiniDeviceNames",
"OperatorResult",
"TensorInitializer",
"TensorSpec",
"TestCase",
"TestConfig",
"TestResult",
"TestExecutor",
"TestSummary",
"TestRunner",
"TestReporter",
"TestTiming",
# Core functions
"add_common_test_args",
"compare_results",
......@@ -50,6 +59,8 @@ __all__ = [
"get_tolerance",
"infinicore_tensor_from_torch",
"rearrange_tensor",
# JSON utilities
"save_json_report",
# Utility functions
"to_infinicore_dtype",
"to_torch_dtype",
......
......@@ -8,15 +8,15 @@ import infinicore
import traceback
from abc import ABC, abstractmethod
from .test_case import TestCase, TestResult
from .results import CaseResult
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer
from .utils import (
from .utils.tensor_utils import (
clone_torch_tensor,
create_test_comparator,
infinicore_tensor_from_torch,
)
from .utils.compare_utils import create_test_comparator
from .benchmark import BenchmarkUtils
......@@ -84,7 +84,7 @@ class TestRunner:
try:
print(f"{test_case}")
# Execute test and get TestResult object
# Execute test and get CaseResult object
test_result = test_func(device, test_case, self.config)
self.test_results.append(test_result)
......@@ -118,8 +118,8 @@ class TestRunner:
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
# Create a failed TestResult
failed_result = TestResult(
# Create a failed CaseResult
failed_result = CaseResult(
success=False,
return_code=-1,
error_message=str(e),
......@@ -400,12 +400,12 @@ class BaseOperatorTest(ABC):
config: Test configuration
Returns:
TestResult: Test result object containing status and timing information
CaseResult: Test case result object containing status and timing information
"""
device_str = torch_device_map[device]
# Initialize test result
test_result = TestResult(
# Initialize test case result
test_result = CaseResult(
success=False,
return_code=-1, # Default to failure
test_case=test_case,
......
......@@ -5,7 +5,7 @@ Benchmarking utilities for the InfiniCore testing framework
import time
import torch
import infinicore
from .utils import synchronize_device
from .utils.tensor_utils import synchronize_device
class BenchmarkUtils:
......
......@@ -7,21 +7,6 @@ from typing import List, Dict, Any, Optional, Tuple
from .tensor import TensorSpec
@dataclass
class TestResult:
"""Test result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
class TestCase:
"""Test case with all configuration included"""
......
import sys
import importlib.util
from io import StringIO
from contextlib import contextmanager
from .results import OperatorResult, TestSummary
@contextmanager
def capture_output():
"""Context manager: captures stdout and stderr."""
new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_out, new_err
yield new_out, new_err
finally:
sys.stdout, sys.stderr = old_out, old_err
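For review context, a minimal usage sketch of the helper above (illustrative only, not part of this change):

with capture_output() as (out, err):
    print("hello")
assert out.getvalue() == "hello\n"   # stdout was captured and is restored afterwards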
class TestExecutor:
def execute(self, file_path) -> OperatorResult:
result = OperatorResult(name=file_path.stem)
try:
# 1. Dynamically import the module
module = self._import_module(file_path)
# 2. Look for TestRunner
if not hasattr(module, "GenericTestRunner"):
raise ImportError("No GenericTestRunner found in module")
# 3. Look for TestClass (subclass of BaseOperatorTest)
test_class = self._find_test_class(module)
if not test_class:
raise ImportError("No BaseOperatorTest subclass found")
test_instance = test_class()
runner_class = module.GenericTestRunner
runner = runner_class(test_instance.__class__)
# 4. Execute and capture output
with capture_output() as (out, err):
success, internal_runner = runner.run()
# 5. Populate results
result.success = success
result.stdout = out.getvalue()
result.stderr = err.getvalue()
# Extract detailed results from internal_runner
test_results = internal_runner.get_test_results() if internal_runner else []
test_summary = TestSummary()
test_summary.process_operator_result(result, test_results)
except Exception as e:
result.success = False
result.error_message = str(e)
result.stderr += f"\nExecutor Error: {str(e)}"
result.return_code = -1
return result
def _import_module(self, path):
module_name = f"op_test_{path.stem}"
spec = importlib.util.spec_from_file_location(module_name, path)
if not spec or not spec.loader:
raise ImportError(f"Could not load spec from {path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def _find_test_class(self, module):
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, type) and hasattr(attr, "__bases__"):
# Simple check for base class name
if any("BaseOperatorTest" in str(b) for b in attr.__bases__):
return attr
return None
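A minimal driver sketch for the new executor (illustrative only; the directory path and loop are assumptions, the real driver lives in run.py via TestDiscoverer):

from pathlib import Path

executor = TestExecutor()
for test_file in sorted(Path("ops").glob("*.py")):   # hypothetical ops directory
    op_result = executor.execute(test_file)
    print(op_result.status_icon, op_result.name, op_result.status_text)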
import json
import os
from datetime import datetime
from typing import List, Dict, Any, Union
from dataclasses import is_dataclass
from .base import TensorSpec
from .devices import InfiniDeviceEnum
class TestReporter:
"""
Handles report generation and file saving for test results.
"""
@staticmethod
def prepare_report_entry(
op_name: str,
test_cases: List[Any],
args: Any,
op_paths: Dict[str, str],
results_list: List[Any]
) -> List[Dict[str, Any]]:
"""
Combines static test case info with dynamic execution results.
"""
# 1. Normalize results
results_map = {}
if isinstance(results_list, list):
results_map = {i: res for i, res in enumerate(results_list)}
elif isinstance(results_list, dict):
results_map = results_list
else:
results_map = {0: results_list} if results_list else {}
# 2. Global Args
global_args = {
k: getattr(args, k)
for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
if hasattr(args, k)
}
grouped_entries: Dict[int, Dict[str, Any]] = {}
# 3. Iterate Test Cases
for idx, tc in enumerate(test_cases):
res = results_map.get(idx)
dev_id = getattr(res, "device", 0) if res else 0
# --- A. Initialize Group ---
if dev_id not in grouped_entries:
device_id_map = {v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")}
dev_str = device_id_map.get(dev_id, str(dev_id))
grouped_entries[dev_id] = {
"operator": op_name,
"device": dev_str,
"torch_op": op_paths.get("torch") or "unknown",
"infinicore_op": op_paths.get("infinicore") or "unknown",
"args": global_args,
"testcases": []
}
# --- B. Build Kwargs ---
display_kwargs = {}
# B1. Process existing kwargs
for k, v in tc.kwargs.items():
# Handle Inplace: "out": index -> "out": "input_name"
if k == "out" and isinstance(v, int):
if 0 <= v < len(tc.inputs):
display_kwargs[k] = tc.inputs[v].name
else:
display_kwargs[k] = f"Invalid_Index_{v}"
else:
display_kwargs[k] = (TestReporter._spec_to_dict(v) if isinstance(v, TensorSpec) else v)
# B2. Inject Outputs into Kwargs
if hasattr(tc, "output_specs") and tc.output_specs:
for i, spec in enumerate(tc.output_specs):
display_kwargs[f"out_{i}"] = TestReporter._spec_to_dict(spec)
elif tc.output_spec:
if "out" not in display_kwargs:
display_kwargs["out"] = TestReporter._spec_to_dict(tc.output_spec)
# --- C. Build Test Case Dictionary ---
case_data = {
"description": tc.description,
"inputs": [TestReporter._spec_to_dict(i) for i in tc.inputs],
"kwargs": display_kwargs,
"comparison_target": tc.comparison_target,
"tolerance": tc.tolerance,
}
# --- D. Inject Result ---
if res:
case_data["result"] = TestReporter._fmt_result(res)
else:
case_data["result"] = {"status": {"success": False, "error": "No result"}}
grouped_entries[dev_id]["testcases"].append(case_data)
return list(grouped_entries.values())
@staticmethod
def save_all_results(save_path: str, total_results: List[Dict[str, Any]]):
"""
Saves the report list to a JSON file with specific custom formatting
"""
directory, filename = os.path.split(save_path)
name, ext = os.path.splitext(filename)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
# Define indentation levels for cleaner code
indent_4 = ' ' * 4
indent_8 = ' ' * 8
indent_12 = ' ' * 12
indent_16 = ' ' * 16
indent_20 = ' ' * 20
print(f"💾 Saving to: {final_path}")
try:
with open(final_path, "w", encoding="utf-8") as f:
f.write("[\n")
for i, entry in enumerate(total_results):
f.write(f"{indent_4}{{\n")
keys = list(entry.keys())
for j, key in enumerate(keys):
val = entry[key]
comma = "," if j < len(keys) - 1 else ""
# -------------------------------------------------
# Special Handling for 'testcases' list formatting
# -------------------------------------------------
if key == "testcases" and isinstance(val, list):
f.write(f'{indent_8}"{key}": [\n')
for c_idx, case_item in enumerate(val):
f.write(f"{indent_12}{{\n")
case_keys = list(case_item.keys())
for k_idx, c_key in enumerate(case_keys):
c_val = case_item[c_key]
# [Logic A] Skip fields we merged manually after 'kwargs'
if c_key in ["comparison_target", "tolerance"]:
continue
# Check comma for standard logic (might be overridden below)
c_comma = "," if k_idx < len(case_keys) - 1 else ""
# [Logic B] Handle 'kwargs' + Grouped Fields
if c_key == "kwargs":
# 1. Use Helper for kwargs (Fill/Flow logic)
TestReporter._write_smart_field(
f, c_key, c_val, indent_16, indent_20, close_comma=","
)
# 2. Write subsequent comparison_target and tolerance (on a new line)
cmp_v = json.dumps(case_item.get("comparison_target"), ensure_ascii=False)
tol_v = json.dumps(case_item.get("tolerance"), ensure_ascii=False)
remaining_keys = [k for k in case_keys[k_idx+1:] if k not in ("comparison_target", "tolerance")]
line_comma = "," if remaining_keys else ""
f.write(f'{indent_16}"comparison_target": {cmp_v}, "tolerance": {tol_v}{line_comma}\n')
continue
# [Logic C] Handle 'inputs' (Smart Wrap)
if c_key == "inputs" and isinstance(c_val, list):
TestReporter._write_smart_field(
f, c_key, c_val, indent_16, indent_20, close_comma=c_comma
)
continue
# [Logic D] Standard fields (description, result, output_spec, etc.)
else:
c_val_str = json.dumps(c_val, ensure_ascii=False)
f.write(f'{indent_16}"{c_key}": {c_val_str}{c_comma}\n')
close_comma = "," if c_idx < len(val) - 1 else ""
f.write(f"{indent_12}}}{close_comma}\n")
f.write(f"{indent_8}]{comma}\n")
# -------------------------------------------------
# Standard top-level fields (operator, args, etc.)
# -------------------------------------------------
else:
k_str = json.dumps(key, ensure_ascii=False)
v_str = json.dumps(val, ensure_ascii=False)
f.write(f"{indent_8}{k_str}: {v_str}{comma}\n")
if i < len(total_results) - 1:
f.write(f"{indent_4}}},\n")
else:
f.write(f"{indent_4}}}\n")
f.write("]\n")
print(f" ✅ Saved (Structure Matched).")
except Exception as e:
import traceback; traceback.print_exc()
print(f" ❌ Save failed: {e}")
# --- Internal Helpers ---
@staticmethod
def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
"""
Helper to write a JSON field (List or Dict) with smart wrapping.
- If compact length <= 180: Write on one line.
- If > 180: Use 'Fill/Flow' mode (multiple items per line, wrap when line is full).
"""
# 1. Try Compact Mode
compact_json = json.dumps(value, ensure_ascii=False)
if len(compact_json) <= 180:
f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
return
# 2. Fill/Flow Mode
is_dict = isinstance(value, dict)
open_char = '{' if is_dict else '['
close_char = '}' if is_dict else ']'
f.write(f'{indent}"{key}": {open_char}')
# Normalize items for iteration
if is_dict:
items = list(value.items())
else:
items = value # List
# Initialize current line length tracking
# Length includes indent + "key": [
current_len = len(indent) + len(f'"{key}": {open_char}')
for i, item in enumerate(items):
# Format individual item string
if is_dict:
k, v = item
val_str = json.dumps(v, ensure_ascii=False)
item_str = f'"{k}": {val_str}'
else:
item_str = json.dumps(item, ensure_ascii=False)
is_last = (i == len(items) - 1)
item_comma = "" if is_last else ", "
# Predict new length: current + item + comma
if current_len + len(item_str) + len(item_comma) > 180:
# Wrap to new line
f.write(f'\n{sub_indent}')
current_len = len(sub_indent)
f.write(f'{item_str}{item_comma}')
current_len += len(item_str) + len(item_comma)
f.write(f'{close_char}{close_comma}\n')
@staticmethod
def _spec_to_dict(s):
return {
"name": getattr(s, "name", "unknown"),
"shape": list(s.shape) if s.shape else None,
"dtype": str(s.dtype).split(".")[-1],
"strides": list(s.strides) if s.strides else None,
}
@staticmethod
def _fmt_result(res):
if not (is_dataclass(res) or hasattr(res, "success")):
return str(res)
get_time = lambda k: round(getattr(res, k, 0.0), 4)
return {
"status": {
"success": getattr(res, "success", False),
"error": getattr(res, "error_message", ""),
},
"perf_ms": {
"torch": {
"host": get_time("torch_host_time"),
"device": get_time("torch_device_time"),
},
"infinicore": {
"host": get_time("infini_host_time"),
"device": get_time("infini_device_time"),
},
},
}
from typing import List, Dict, Any
from dataclasses import dataclass, is_dataclass, field
from .devices import InfiniDeviceEnum
from .tensor import TensorSpec
from .utils.json_utils import save_json_report
@dataclass
class CaseResult:
"""Test case result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
@dataclass
class TestTiming:
"""Stores performance timing metrics."""
torch_host: float = 0.0
torch_device: float = 0.0
infini_host: float = 0.0
infini_device: float = 0.0
# Added field to support the case count used in print_summary
operators_tested: int = 0
@dataclass
class OperatorResult:
"""Stores the execution results of a single operator."""
name: str
success: bool = False
return_code: int = -1
error_message: str = ""
stdout: str = ""
stderr: str = ""
timing: TestTiming = field(default_factory=TestTiming)
@property
def status_icon(self):
if self.return_code == 0:
return "✅"
if self.return_code == -2:
return "⏭️"
if self.return_code == -3:
return "⚠️"
return "❌"
@property
def status_text(self):
if self.return_code == 0:
return "PASSED"
if self.return_code == -2:
return "SKIPPED"
if self.return_code == -3:
return "PARTIAL"
return "FAILED"
class TestSummary:
"""
Test Summary class:
1. Aggregates results (Timing & Status calculation).
2. Handles Console Output (Live & Summary).
3. Handles File Reporting (Data Preparation).
"""
def __init__(self, verbose=False, bench_mode=None):
self.verbose = verbose
self.bench_mode = bench_mode
self.report_entries = [] # Cache for JSON report
# =========================================================
# Part 1: Result Aggregation
# =========================================================
def process_operator_result(self, op_result, sub_results: List):
"""
Updates the OperatorResult object in-place.
"""
if not sub_results:
op_result.return_code = -1
return
# 1. Analyze Return Code (Status)
if op_result.success:
op_result.return_code = 0
else:
has_failures = any(r.return_code == -1 for r in sub_results)
has_partial = any(r.return_code == -3 for r in sub_results)
has_skipped = any(r.return_code == -2 for r in sub_results)
if has_failures:
op_result.return_code = -1
elif has_partial:
op_result.return_code = -3
elif has_skipped:
op_result.return_code = -2
else:
op_result.return_code = -1
# 2. Extract Timing (Aggregation)
t = op_result.timing
t.torch_host = sum(r.torch_host_time for r in sub_results)
t.torch_device = sum(r.torch_device_time for r in sub_results)
t.infini_host = sum(r.infini_host_time for r in sub_results)
t.infini_device = sum(r.infini_device_time for r in sub_results)
t.operators_tested = len(sub_results)
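A minimal sketch (illustrative, not part of this diff) of how the aggregation above resolves an operator's status and timing from its case results:

op = OperatorResult(name="add")
cases = [
    CaseResult(success=True, return_code=0, torch_host_time=1.2),
    CaseResult(success=False, return_code=-2),   # one skipped case
]
TestSummary().process_operator_result(op, cases)
# op.return_code == -2 (no hard failures, so "skipped" wins)
# op.timing.torch_host == 1.2, op.timing.operators_tested == 2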
# =========================================================
# Part 2: Console Output (View)
# =========================================================
def list_tests(self, discoverer):
ops_dir = discoverer.ops_dir
operators = discoverer.get_available_operators()
if operators:
print(f"Available operator test files in {ops_dir}:")
for operator in operators:
print(f" - {operator}")
print(f"\nTotal: {len(operators)} operators")
else:
print(f"No valid operator tests found in {ops_dir}")
raw_files = discoverer.get_raw_python_files()
if raw_files:
print(
f"\n💡 Debug Hint: Found Python files but they are not valid tests:"
)
print(f" {raw_files}")
def print_header(self, ops_dir, count):
print(f"InfiniCore Operator Test Runner")
print(f"Directory: {ops_dir}")
print(f"Tests found: {count}\n")
def print_live_result(self, result):
print(
f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})"
)
if result.stdout:
print(result.stdout.rstrip())
if result.stderr:
print("\nSTDERR:", result.stderr.rstrip())
if result.error_message:
print(f"💥 Error: {result.error_message}")
if result.stdout or result.stderr or self.verbose:
print("-" * 40)
def print_summary(self, results, cumulative_timing, ops_dir, total_expected=0):
print(f"\n{'='*80}\nCUMULATIVE TEST SUMMARY\n{'='*80}")
passed = [r for r in results if r.return_code == 0]
failed = [r for r in results if r.return_code == -1]
skipped = [r for r in results if r.return_code == -2]
partial = [r for r in results if r.return_code == -3]
total = len(results)
print(f"Total tests run: {total}")
if total_expected > 0 and total < total_expected:
print(f"Total tests expected: {total_expected}")
print(f"Tests not executed: {total_expected - total}")
print(f"Passed: {len(passed)}")
print(f"Failed: {len(failed)}")
if skipped:
print(f"Skipped: {len(skipped)}")
if partial:
print(f"Partial: {len(partial)}")
# 1. Benchmark
if cumulative_timing:
self._print_timing(cumulative_timing)
# 2. Lists
if passed:
self._print_op_list("✅ PASSED OPERATORS", passed)
else:
print(f"\n✅ PASSED OPERATORS: None")
if failed:
self._print_op_list("❌ FAILED OPERATORS", failed)
if skipped:
self._print_op_list("⏭️ SKIPPED OPERATORS", skipped)
if partial:
self._print_op_list("⚠️ PARTIAL IMPLEMENTATIONS", partial)
# 3. Verdict
if total > 0:
executed_tests = total - len(skipped)
if executed_tests > 0:
success_rate = len(passed) / executed_tests * 100
print(f"\nSuccess rate: {success_rate:.1f}%")
if not failed:
if skipped or partial:
print(f"\n⚠️ Tests completed with some operators not fully implemented")
else:
print(f"\n🎉 All tests passed!")
else:
print(f"\n{len(failed)} tests failed")
if not failed and (skipped or partial):
print(f"\n⚠️ Note: Some operators are not fully implemented")
print(f" Run individual tests for details on missing implementations")
if self.verbose and failed:
print(
f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:"
)
for r in failed[:3]:
file_path = ops_dir / (r.name + ".py")
print(f" python {file_path} --verbose")
if len(failed) > 3:
print(f" ... (and {len(failed) - 3} others)")
return len(failed) == 0
def _print_timing(self, t):
print(f"{'-'*40}")
if hasattr(t, "operators_tested") and t.operators_tested > 0:
print(f"BENCHMARK SUMMARY ({t.operators_tested} cases):")
if self.bench_mode in ["host", "both"]:
print(f" [Host] PyTorch: {t.torch_host:10.3f} ms")
print(f" [Host] InfiniCore: {t.infini_host:10.3f} ms")
if self.bench_mode in ["device", "both"]:
print(f" [Device] PyTorch: {t.torch_device:10.3f} ms")
print(f" [Device] InfiniCore: {t.infini_device:10.3f} ms")
print(f"{'-'*40}")
def _print_op_list(self, title, result_list):
print(f"\n{title} ({len(result_list)}):")
names = [r.name for r in result_list]
for i in range(0, len(names), 10):
print(" " + ", ".join(names[i : i + 10]))
# =========================================================
# Part 3: Report Generation
# =========================================================
def collect_report_entry(self, op_name, test_cases, args, op_paths, results_list):
"""
Prepares the data and adds it to the internal list.
"""
entry = self._prepare_entry_logic(
op_name, test_cases, args, op_paths, results_list
)
self.report_entries.extend(entry)
def save_report(self, save_path):
"""
Delegates the actual writing to save_json_report.
"""
if not self.report_entries:
return
# Call the external utility
save_json_report(save_path, self.report_entries)
def _prepare_entry_logic(self, op_name, test_cases, args, op_paths, results_list):
"""
Combines static test case info with dynamic execution results.
Refactored to reduce duplication.
"""
# 1. Normalize results
results_map = (
results_list
if isinstance(results_list, dict)
else {i: res for i, res in enumerate(results_list or [])}
)
# 2. Global Args
global_args = {
k: getattr(args, k)
for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
if hasattr(args, k)
}
grouped_entries = {}
# Cache device enum map
device_id_map = {
v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")
}
for idx, tc in enumerate(test_cases):
res = results_map.get(idx)
dev_id = getattr(res, "device", 0) if res else 0
# --- A. Initialize Group ---
if dev_id not in grouped_entries:
grouped_entries[dev_id] = {
"operator": op_name,
"device": device_id_map.get(dev_id, str(dev_id)),
"torch_op": op_paths.get("torch", "unknown"),
"infinicore_op": op_paths.get("infinicore", "unknown"),
"args": global_args,
"testcases": [],
}
# --- B. Helpers for Spec Processing ---
def process_spec(spec, default_name):
final_name = self._resolve_name(spec, default_name)
# Call internal method (no need for external converters file)
return self._spec_to_dict(spec, name=final_name)
# --- C. Build Inputs ---
processed_inputs = [
process_spec(inp, f"in_{i}") for i, inp in enumerate(tc.inputs)
]
# --- D. Build Kwargs ---
display_kwargs = {}
for k, v in tc.kwargs.items():
if k == "out" and isinstance(v, int):
# Handle Inplace Index
if 0 <= v < len(tc.inputs):
display_kwargs[k] = self._resolve_name(tc.inputs[v], f"in_{v}")
else:
display_kwargs[k] = f"Invalid_Index_{v}"
elif isinstance(v, TensorSpec):
display_kwargs[k] = process_spec(v, v.name)
else:
display_kwargs[k] = v
# --- E. Inject Outputs ---
if getattr(tc, "output_specs", None):
for i, spec in enumerate(tc.output_specs):
display_kwargs[f"out_{i}"] = process_spec(spec, f"out_{i}")
elif tc.output_spec and "out" not in display_kwargs:
display_kwargs["out"] = process_spec(tc.output_spec, "out")
# --- F. Assemble Case Data ---
case_data = {
"description": tc.description,
"inputs": processed_inputs,
"kwargs": display_kwargs,
"comparison_target": tc.comparison_target,
"tolerance": tc.tolerance,
"result": (
self._fmt_result(res)
if res
else {"status": {"success": False, "error": "No result"}}
),
}
grouped_entries[dev_id]["testcases"].append(case_data)
return list(grouped_entries.values())
# --- Internal Helpers ---
def _resolve_name(self, obj, default_name):
return getattr(obj, "name", None) or default_name
def _spec_to_dict(self, s, name=None):
return {
"name": name if name else getattr(s, "name", "unknown"),
"shape": list(s.shape) if s.shape else None,
"dtype": str(s.dtype).split(".")[-1],
"strides": list(s.strides) if s.strides else None,
}
def _fmt_result(self, res):
if not (is_dataclass(res) or hasattr(res, "success")):
return str(res)
get_time = lambda k: round(getattr(res, k, 0.0), 4)
return {
"status": {
"success": getattr(res, "success", False),
"error": getattr(res, "error_message", ""),
},
"perf_ms": {
"torch": {
"host": get_time("torch_host_time"),
"device": get_time("torch_device_time"),
},
"infinicore": {
"host": get_time("infini_host_time"),
"device": get_time("infini_device_time"),
},
},
}
......@@ -7,7 +7,7 @@ import os
import inspect
import re
from . import TestConfig, TestRunner, get_args, get_test_devices
from .reporter import TestReporter
from .results import TestSummary
class GenericTestRunner:
......@@ -89,7 +89,8 @@ class GenericTestRunner:
op_paths = {"torch": t_path, "infinicore": i_path}
# 2. Generate Report Entries
entries = TestReporter.prepare_report_entry(
test_summary = TestSummary()
entries = test_summary.collect_report_entry(
op_name=self.operator_test.operator_name,
test_cases=self.operator_test.test_cases,
args=self.args,
......@@ -98,7 +99,7 @@ class GenericTestRunner:
)
# 4. Save to File
TestReporter.save_all_results(self.args.save, entries)
test_summary.save_report(self.args.save)
except Exception as e:
import traceback
......
......@@ -3,7 +3,7 @@ import math
from pathlib import Path
from .datatypes import to_torch_dtype
from .devices import torch_device_map
from .utils import is_integer_dtype, is_complex_dtype
from .utils.tensor_utils import is_integer_dtype, is_complex_dtype
class TensorInitializer:
......
import torch
import time
import infinicore
import numpy as np
from .datatypes import to_infinicore_dtype, to_torch_dtype
def synchronize_device(torch_device):
"""Device synchronization"""
if torch_device == "cuda":
torch.cuda.synchronize()
elif torch_device == "npu":
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
elif torch_device == "musa":
torch.musa.synchronize()
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
Debug function to compare two tensors and print differences
"""
# Handle complex types by converting to real representation for comparison
if actual.is_complex() or desired.is_complex():
actual = torch.view_as_real(actual)
desired = torch.view_as_real(desired)
elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32)
desired = desired.to(torch.float32)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
import numpy as np
np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)
def print_discrepancy(
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
):
"""Print detailed tensor differences"""
if actual.shape != expected.shape:
raise ValueError("Tensors must have the same shape to compare.")
import torch
import sys
is_terminal = sys.stdout.isatty()
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)
# Calculate difference mask
nan_mismatch = (
actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected
# Display formatting
col_width = [18, 20, 20, 20]
decimal_places = [0, 12, 12, 12]
total_width = sum(col_width) + sum(decimal_places)
def add_color(text, color_code):
if is_terminal:
return f"\033[{color_code}m{text}\033[0m"
else:
return text
if verbose:
for idx in diff_indices:
index_tuple = tuple(idx.tolist())
actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
expected_str = (
f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
)
delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
print(
f" > Index: {str(index_tuple):<{col_width[0]}}"
f"actual: {add_color(actual_str, 31)}"
f"expect: {add_color(expected_str, 32)}"
f"delta: {add_color(delta_str, 33)}"
)
print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(f" - Equal NaN: {equal_nan}")
print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
)
print(
f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
)
print(
f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
)
print(
f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
)
print("-" * total_width)
return diff_indices
def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
"""
Get tolerance settings based on data type
"""
tolerance = tolerance_map.get(
tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
)
return tolerance["atol"], tolerance["rtol"]
def clone_torch_tensor(torch_tensor):
cloned = torch_tensor.clone().detach()
if not torch_tensor.is_contiguous():
cloned = rearrange_tensor(cloned, torch_tensor.stride())
return cloned
def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
return infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
else:
return infinicore.strided_from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
list(torch_tensor.stride()),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
def convert_infinicore_to_torch(infini_result):
"""
Convert infinicore tensor to PyTorch tensor for comparison
Args:
infini_result: infinicore tensor result
Returns:
torch.Tensor: PyTorch tensor with infinicore data
"""
torch_result_from_infini = torch.zeros(
infini_result.shape,
dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type,
)
if not infini_result.is_contiguous():
torch_result_from_infini = rearrange_tensor(
torch_result_from_infini, infini_result.stride()
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
return torch_result_from_infini
import sys
from ..datatypes import to_torch_dtype
from .tensor_utils import (
convert_infinicore_to_torch,
is_integer_dtype,
is_complex_dtype,
)
def compare_results(
......@@ -351,89 +189,104 @@ def create_test_comparator(config, atol, rtol, mode_name="", equal_nan=False):
return compare_test_results
def rearrange_tensor(tensor, new_strides):
def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
"""
Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides.
Get tolerance settings based on data type
"""
import torch
tolerance = tolerance_map.get(
tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
)
return tolerance["atol"], tolerance["rtol"]
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
Debug function to compare two tensors and print differences
"""
# Handle complex types by converting to real representation for comparison
if actual.is_complex() or desired.is_complex():
actual = torch.view_as_real(actual)
desired = torch.view_as_real(desired)
elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32)
desired = desired.to(torch.float32)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
shape = tensor.shape
new_size = [0] * len(shape)
left = 0
right = 0
for i in range(len(shape)):
if new_strides[i] >= 0:
new_size[i] = (shape[i] - 1) * new_strides[i] + 1
right += new_strides[i] * (shape[i] - 1)
else: # TODO: Support negative strides in the future
# new_size[i] = (shape[i] - 1) * (-new_strides[i]) + 1
# left += new_strides[i] * (shape[i] - 1)
raise ValueError("Negative strides are not supported yet")
# Create a new tensor with zeros
new_tensor = torch.zeros(
(right - left + 1,), dtype=tensor.dtype, device=tensor.device
import numpy as np
np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)
# Generate indices for original tensor based on original strides
indices = [torch.arange(s) for s in shape]
mesh = torch.meshgrid(*indices, indexing="ij")
# Flatten indices for linear indexing
linear_indices = [m.flatten() for m in mesh]
def print_discrepancy(
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
):
"""Print detailed tensor differences"""
if actual.shape != expected.shape:
raise ValueError("Tensors must have the same shape to compare.")
# Calculate new positions based on new strides
new_positions = sum(
linear_indices[i] * new_strides[i] for i in range(len(shape))
).to(tensor.device)
offset = -left
new_positions += offset
import torch
import sys
is_terminal = sys.stdout.isatty()
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)
# Copy the original data to the new tensor
new_tensor.reshape(-1).index_add_(0, new_positions, tensor.reshape(-1))
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
# Calculate difference mask
nan_mismatch = (
actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected
return new_tensor
# Display formatting
col_width = [18, 20, 20, 20]
decimal_places = [0, 12, 12, 12]
total_width = sum(col_width) + sum(decimal_places)
def add_color(text, color_code):
if is_terminal:
return f"\033[{color_code}m{text}\033[0m"
else:
return text
def is_broadcast(strides):
"""
Check if strides indicate a broadcasted tensor
if verbose:
for idx in diff_indices:
index_tuple = tuple(idx.tolist())
actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
expected_str = (
f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
)
delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
print(
f" > Index: {str(index_tuple):<{col_width[0]}}"
f"actual: {add_color(actual_str, 31)}"
f"expect: {add_color(expected_str, 32)}"
f"delta: {add_color(delta_str, 33)}"
)
Args:
strides: Tensor strides or None
print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(f" - Equal NaN: {equal_nan}")
print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
)
print(
f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
)
print(
f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
)
print(
f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
)
print("-" * total_width)
Returns:
bool: True if the tensor is broadcasted (has zero strides)
"""
if strides is None:
return False
return any(s == 0 for s in strides)
def is_integer_dtype(dtype):
"""Check if dtype is integer type"""
return dtype in [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.bool,
]
def is_complex_dtype(dtype):
"""Check if dtype is complex type"""
return dtype in [torch.complex64, torch.complex128]
def is_floating_dtype(dtype):
"""Check if dtype is floating-point type"""
return dtype in [
torch.float16,
torch.float32,
torch.float64,
torch.bfloat16,
]
return diff_indices
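A small usage sketch for get_tolerance above (illustrative only; the dtype keys are whatever the operator test's tolerance_map defines):

import torch

tol_map = {torch.float16: {"atol": 0, "rtol": 1e-2}}
atol, rtol = get_tolerance(tol_map, torch.float16)   # -> (0, 0.01)
atol, rtol = get_tolerance(tol_map, torch.float32)   # dtype missing -> defaults (0, 0.001)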
import json
import os
from datetime import datetime
def save_json_report(save_path, total_results):
"""
Saves the report list to a JSON file with specific custom formatting
(Compact for short lines, Expanded for long lines).
"""
directory, filename = os.path.split(save_path)
name, ext = os.path.splitext(filename)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
# Define Indentation
I4, I8, I12, I16, I20 = " " * 4, " " * 8, " " * 12, " " * 16, " " * 20
print(f"💾 Saving to: {final_path}")
# Helper for JSON stringify to avoid repetition
def _to_json(obj):
return json.dumps(obj, ensure_ascii=False)
try:
with open(final_path, "w", encoding="utf-8") as f:
f.write("[\n")
for i, entry in enumerate(total_results):
f.write(f"{I4}{{\n")
keys = list(entry.keys())
for j, key in enumerate(keys):
val = entry[key]
comma = "," if j < len(keys) - 1 else ""
# Special handling for 'testcases' list
if key == "testcases" and isinstance(val, list):
f.write(f'{I8}"{key}": [\n')
for c_idx, case_item in enumerate(val):
f.write(f"{I12}{{\n")
case_keys = list(case_item.keys())
# Filter out keys that we handle specially at the end
standard_keys = [
k
for k in case_keys
if k not in ["comparison_target", "tolerance"]
]
for k_idx, c_key in enumerate(standard_keys):
c_val = case_item[c_key]
# Determine comma logic
c_comma = (
","
if k_idx < len(standard_keys) - 1
or "comparison_target" in case_item
else ""
)
if c_key in ["kwargs", "inputs"]:
_write_field(
f, c_key, c_val, I16, I20, close_comma=c_comma
)
else:
f.write(
f'{I16}"{c_key}": {_to_json(c_val)}{c_comma}\n'
)
# Handle trailing comparison/tolerance fields uniformly
if "comparison_target" in case_item:
cmp = _to_json(case_item.get("comparison_target"))
tol = _to_json(case_item.get("tolerance"))
f.write(
f'{I16}"comparison_target": {cmp}, "tolerance": {tol}\n'
)
close_case = "," if c_idx < len(val) - 1 else ""
f.write(f"{I12}}}{close_case}\n")
f.write(f"{I8}]{comma}\n")
else:
# Standard top-level fields
f.write(f"{I8}{_to_json(key)}: {_to_json(val)}{comma}\n")
close_entry = "}," if i < len(total_results) - 1 else "}"
f.write(f"{I4}{close_entry}\n")
f.write("]\n")
print(f" ✅ Saved.")
except Exception as e:
import traceback
traceback.print_exc()
print(f" ❌ Save failed: {e}")
def _write_field(f, key, value, indent, sub_indent, close_comma=""):
"""
Internal Helper: Write a JSON field with wrapping.
"""
# 1. Try Compact Mode
compact_json = json.dumps(value, ensure_ascii=False)
if len(compact_json) <= 180:
f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
return
# 2. Fill/Flow Mode
is_dict = isinstance(value, dict)
open_char = "{" if is_dict else "["
close_char = "}" if is_dict else "]"
f.write(f'{indent}"{key}": {open_char}')
if is_dict:
items = list(value.items())
else:
items = value
current_len = len(indent) + len(f'"{key}": {open_char}')
for i, item in enumerate(items):
if is_dict:
k, v = item
val_str = json.dumps(v, ensure_ascii=False)
item_str = f'"{k}": {val_str}'
else:
item_str = json.dumps(item, ensure_ascii=False)
is_last = i == len(items) - 1
item_comma = "" if is_last else ", "
if current_len + len(item_str) + len(item_comma) > 180:
f.write(f"\n{sub_indent}")
current_len = len(sub_indent)
f.write(f"{item_str}{item_comma}")
current_len += len(item_str) + len(item_comma)
f.write(f"{close_char}{close_comma}\n")
import torch
import infinicore
from ..datatypes import to_infinicore_dtype, to_torch_dtype
# =================================================================
# Device & Synchronization
# =================================================================
def synchronize_device(torch_device):
"""Device synchronization"""
if torch_device == "cuda":
torch.cuda.synchronize()
elif torch_device == "npu":
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
elif torch_device == "musa":
torch.musa.synchronize()
# =================================================================
# Tensor Operations & Conversions
# =================================================================
def clone_torch_tensor(torch_tensor):
cloned = torch_tensor.clone().detach()
if not torch_tensor.is_contiguous():
cloned = rearrange_tensor(cloned, torch_tensor.stride())
return cloned
def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
return infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
else:
return infinicore.strided_from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
list(torch_tensor.stride()),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
def convert_infinicore_to_torch(infini_result):
"""
Convert infinicore tensor to PyTorch tensor for comparison
Args:
infini_result: infinicore tensor result
Returns:
torch.Tensor: PyTorch tensor with infinicore data
"""
torch_result_from_infini = torch.zeros(
infini_result.shape,
dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type,
)
if not infini_result.is_contiguous():
torch_result_from_infini = rearrange_tensor(
torch_result_from_infini, infini_result.stride()
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
return torch_result_from_infini
def rearrange_tensor(tensor, new_strides):
"""
Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides.
"""
import torch
shape = tensor.shape
new_size = [0] * len(shape)
left = 0
right = 0
for i in range(len(shape)):
if new_strides[i] >= 0:
new_size[i] = (shape[i] - 1) * new_strides[i] + 1
right += new_strides[i] * (shape[i] - 1)
else: # TODO: Support negative strides in the future
# new_size[i] = (shape[i] - 1) * (-new_strides[i]) + 1
# left += new_strides[i] * (shape[i] - 1)
raise ValueError("Negative strides are not supported yet")
# Create a new tensor with zeros
new_tensor = torch.zeros(
(right - left + 1,), dtype=tensor.dtype, device=tensor.device
)
# Generate indices for original tensor based on original strides
indices = [torch.arange(s) for s in shape]
mesh = torch.meshgrid(*indices, indexing="ij")
# Flatten indices for linear indexing
linear_indices = [m.flatten() for m in mesh]
# Calculate new positions based on new strides
new_positions = sum(
linear_indices[i] * new_strides[i] for i in range(len(shape))
).to(tensor.device)
offset = -left
new_positions += offset
# Copy the original data to the new tensor
new_tensor.reshape(-1).index_add_(0, new_positions, tensor.reshape(-1))
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
return new_tensor
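A short usage sketch for rearrange_tensor above (illustrative only): padding a contiguous (2, 3) tensor out to a row stride of 4.

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)   # contiguous, strides (3, 1)
y = rearrange_tensor(x, [4, 1])                          # same values, row stride padded to 4
assert y.shape == x.shape and y.stride() == (4, 1)
assert torch.equal(y, x)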
def is_broadcast(strides):
"""
Check if strides indicate a broadcasted tensor
Args:
strides: Tensor strides or None
Returns:
bool: True if the tensor is broadcasted (has zero strides)
"""
if strides is None:
return False
return any(s == 0 for s in strides)
# =================================================================
# Type Checks (Moved here to avoid circular imports in check.py)
# =================================================================
def is_integer_dtype(dtype):
"""Check if dtype is integer type"""
return dtype in [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.bool,
]
def is_complex_dtype(dtype):
"""Check if dtype is complex type"""
return dtype in [torch.complex64, torch.complex128]
def is_floating_dtype(dtype):
"""Check if dtype is floating-point type"""
return dtype in [
torch.float16,
torch.float32,
torch.float64,
torch.bfloat16,
]
......@@ -7,6 +7,7 @@ import torch
import infinicore
from framework import (
BaseOperatorTest,
CaseResult,
TensorSpec,
TestCase,
GenericTestRunner,
......@@ -76,7 +77,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.inputs[0], TensorSpec)
and test_case.inputs[0].strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......
......@@ -6,7 +6,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
from framework.tensor import TensorInitializer
from framework.utils import (
from framework.utils.tensor_utils import (
convert_infinicore_to_torch,
infinicore_tensor_from_torch,
to_torch_dtype,
......
......@@ -222,8 +222,8 @@ class OpTest(BaseOperatorTest):
# Re-run operations with the same logits to get results for comparison
# prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
from framework.base import TestResult
from framework.utils import (
from framework.results import CaseResult
from framework.utils.tensor_utils import (
convert_infinicore_to_torch,
infinicore_tensor_from_torch,
)
......@@ -268,8 +268,8 @@ class OpTest(BaseOperatorTest):
# Check if indices are equal (standard case)
if ic_idx == ref_idx:
# Return a successful TestResult object
return TestResult(
# Return a successful CaseResult object
return CaseResult(
success=True,
return_code=0,
test_case=test_case,
......@@ -283,8 +283,8 @@ class OpTest(BaseOperatorTest):
logits_ic = logits_tensor[ic_idx].item()
if logits_ic == logits_ref:
# Valid: different indices but same logits value
# Return a successful TestResult object
return TestResult(
# Return a successful CaseResult object
return CaseResult(
success=True,
return_code=0,
test_case=test_case,
......
......@@ -7,6 +7,7 @@ import torch
import infinicore
from framework import (
BaseOperatorTest,
CaseResult,
TensorSpec,
TestCase,
GenericTestRunner,
......@@ -180,7 +181,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.inputs[0], TensorSpec)
and test_case.inputs[0].strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......@@ -193,7 +194,7 @@ class OpTest(BaseOperatorTest):
)
for spec in output_specs:
if isinstance(spec, TensorSpec) and spec.strides is not None:
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......
......@@ -7,6 +7,7 @@ import torch
import infinicore
from framework import (
BaseOperatorTest,
CaseResult,
TensorSpec,
TestCase,
GenericTestRunner,
......@@ -122,7 +123,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.inputs[0], TensorSpec)
and test_case.inputs[0].strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......@@ -135,7 +136,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.output_spec, TensorSpec)
and test_case.output_spec.strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......
import os
import sys
import argparse
import traceback
from pathlib import Path
import importlib.util
# Import components from the unified framework package
from framework.executor import TestExecutor
from framework.results import TestSummary, TestTiming
from framework import get_hardware_args_group, add_common_test_args
def find_ops_directory(location=None):
"""
Find the ops directory by searching from location upwards.
Args:
location: Starting directory for search (default: current file's parent)
Returns:
Path: Path to ops directory or None if not found
"""
if location is None:
location = Path(__file__).parent / "ops"
ops_dir = location.resolve()
if ops_dir.exists() and any(ops_dir.glob("*.py")):
return ops_dir
return None
def get_available_operators(ops_dir):
"""
Get list of available operators from ops directory.
Args:
ops_dir: Path to ops directory
Returns:
List of operator names
"""
if not ops_dir or not ops_dir.exists():
return []
test_files = list(ops_dir.glob("*.py"))
current_script = Path(__file__).name
test_files = [f for f in test_files if f.name != current_script]
operators = []
for test_file in test_files:
try:
with open(test_file, "r", encoding="utf-8") as f:
content = f.read()
if "infinicore" in content and (
"BaseOperatorTest" in content or "GenericTestRunner" in content
):
operators.append(test_file.stem)
except:
continue
return sorted(operators)
def import_operator_test(test_file_path):
"""
Import an operator test module and return the test class instance.
Args:
test_file_path: Path to the test file
class TestDiscoverer:
def __init__(self, ops_dir_path=None):
self.ops_dir = self._resolve_dir(ops_dir_path)
def _resolve_dir(self, path):
if path:
p = Path(path)
if p.exists():
return p
# Default fallback: the 'ops' directory alongside this file.
# Since this file lives in 'infinicore/', that resolves to 'infinicore/ops'.
# Passing an explicit path from run.py is recommended.
fallback = Path(__file__).parent / "ops"
return fallback if fallback.exists() else None
def get_available_operators(self):
"""Returns a list of names of all available operators."""
if not self.ops_dir:
return []
files = self.scan()
return sorted([f.stem for f in files])
def get_raw_python_files(self):
"""
Get all .py files in the directory (excluding run.py) without content validation.
Used for debugging: helps identify files that exist but failed validation.
"""
if not self.ops_dir or not self.ops_dir.exists():
return []
files = list(self.ops_dir.glob("*.py"))
# Exclude run.py itself and __init__.py
return [
f.name for f in files if f.name != "run.py" and not f.name.startswith("__")
]
Returns:
tuple: (success, test_instance_or_error)
"""
try:
# Create a unique module name
module_name = f"op_test_{test_file_path.stem}"
# Load the module from file
spec = importlib.util.spec_from_file_location(module_name, test_file_path)
if spec is None or spec.loader is None:
return False, f"Could not load module from {test_file_path}"
module = importlib.util.module_from_spec(spec)
# Add the module to sys.modules
sys.modules[module_name] = module
# Execute the module
spec.loader.exec_module(module)
# Find the test class (usually named OpTest)
test_class = None
for attr_name in dir(module):
attr = getattr(module, attr_name)
if (
isinstance(attr, type)
and hasattr(attr, "__bases__")
and any("BaseOperatorTest" in str(base) for base in attr.__bases__)
):
test_class = attr
break
if test_class is None:
return False, f"No test class found in {test_file_path}"
# Create an instance
test_instance = test_class()
return True, test_instance
except Exception as e:
return False, f"Error importing {test_file_path}: {str(e)}"
def run_all_op_tests(
ops_dir=None,
specific_ops=None,
bench=False,
bench_mode="both",
verbose=False,
debug=False,
):
"""
Run all operator test scripts in the ops directory using direct import.
def scan(self, specific_ops=None):
"""Scans and returns a list of Path objects that meet the criteria."""
if not self.ops_dir or not self.ops_dir.exists():
return []
Args:
ops_dir (str, optional): Path to the ops directory. If None, uses auto-detection.
specific_ops (list, optional): List of specific operator names to test.
bench (bool): Whether benchmarking is enabled
bench_mode (str): Benchmark mode - "host", "device", or "both"
verbose (bool): Whether verbose mode is enabled
# 1. Find all .py files
files = list(self.ops_dir.glob("*.py"))
Returns:
dict: Results dictionary with test names as keys and (success, test_runner, stdout, stderr) as values.
"""
if ops_dir is None:
ops_dir = find_ops_directory()
else:
ops_dir = Path(ops_dir)
target_ops_set = set(specific_ops) if specific_ops else None
if not ops_dir or not ops_dir.exists():
print(f"Error: Ops directory '{ops_dir}' does not exist.")
return {}
# 2. Filter out non-test files (via content check)
valid_files = []
for f in files:
# A. Basic Name Filtering
if f.name.startswith("_") or f.name == "run.py":
continue
print(f"Looking for test files in: {ops_dir}")
# B. Specific Ops Filtering
if target_ops_set and f.stem not in target_ops_set:
continue
# Find all Python test files
test_files = list(ops_dir.glob("*.py"))
# C. Content Check (Expensive I/O)
# Only perform this check if the file passed the name filters above.
if self._is_operator_test(f):
valid_files.append(f)
# Filter out this script itself and non-operator test files
current_script = Path(__file__).name
test_files = [f for f in test_files if f.name != current_script]
return valid_files
# Filter to include only files that look like operator tests
operator_test_files = []
for test_file in test_files:
def _is_operator_test(self, file_path):
"""Checks if the file content contains operator test characteristics."""
try:
with open(test_file, "r", encoding="utf-8") as f:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# Look for characteristic patterns of operator tests
if "infinicore" in content and (
return "infinicore" in content and (
"BaseOperatorTest" in content or "GenericTestRunner" in content
):
operator_test_files.append(test_file)
except Exception as e:
continue
# Filter for specific operators if requested
if specific_ops:
filtered_files = []
for test_file in operator_test_files:
test_name = test_file.stem.lower()
if any(op.lower() == test_name for op in specific_ops):
filtered_files.append(test_file)
operator_test_files = filtered_files
if not operator_test_files:
print(f"No operator test files found in {ops_dir}")
print(f"Available Python files: {[f.name for f in test_files]}")
return {}
print(f"Found {len(operator_test_files)} operator test files:")
for test_file in operator_test_files:
print(f" - {test_file.name}")
results = {}
cumulative_timing = {
"total_torch_host_time": 0.0,
"total_torch_device_time": 0.0,
"total_infinicore_host_time": 0.0,
"total_infinicore_device_time": 0.0,
"operators_tested": 0,
}
for test_file in operator_test_files:
test_name = test_file.stem
try:
# Import and run the test directly
success, test_instance_or_error = import_operator_test(test_file)
if not success:
print(f"💥 {test_name}: ERROR - {test_instance_or_error}")
results[test_name] = {
"success": False,
"return_code": -1,
"torch_host_time": 0.0,
"torch_device_time": 0.0,
"infini_host_time": 0.0,
"infini_device_time": 0.0,
"error_message": test_instance_or_error,
"test_runner": None,
"stdout": "",
"stderr": test_instance_or_error,
}
continue
# Get the test runner class from the module
test_module = sys.modules[f"op_test_{test_file.stem}"]
if not hasattr(test_module, "GenericTestRunner"):
print(f"💥 {test_name}: ERROR - No GenericTestRunner found")
results[test_name] = {
"success": False,
"return_code": -1,
"torch_host_time": 0.0,
"torch_device_time": 0.0,
"infini_host_time": 0.0,
"infini_device_time": 0.0,
"error_message": "No GenericTestRunner found",
"test_runner": None,
"stdout": "",
"stderr": "No GenericTestRunner found",
}
continue
# Create and run the test runner
test_runner_class = test_module.GenericTestRunner
runner_instance = test_runner_class(test_instance_or_error.__class__)
# Temporarily redirect stdout to capture output
from io import StringIO
stdout_capture = StringIO()
stderr_capture = StringIO()
old_stdout = sys.stdout
old_stderr = sys.stderr
sys.stdout = stdout_capture
sys.stderr = stderr_capture
try:
# Run the test
test_success, test_runner = runner_instance.run()
# Get captured output
stdout_output = stdout_capture.getvalue()
stderr_output = stderr_capture.getvalue()
# Restore stdout/stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
# Print the captured output
if stdout_output:
print(stdout_output.rstrip())
if stderr_output:
print("\nSTDERR:")
print(stderr_output.rstrip())
# Analyze test results
test_results = test_runner.get_test_results() if test_runner else []
# Determine overall test status
if test_success:
return_code = 0
status_icon = "✅"
status_text = "PASSED"
else:
# Check if there are any failed tests
has_failures = any(
result.return_code == -1 for result in test_results
)
has_partial = any(
result.return_code == -3 for result in test_results
)
has_skipped = any(
result.return_code == -2 for result in test_results
)
if has_failures:
return_code = -1
status_icon = "❌"
status_text = "FAILED"
elif has_partial:
return_code = -3
status_icon = "⚠️"
status_text = "PARTIAL"
elif has_skipped:
return_code = -2
status_icon = "⏭️"
status_text = "SKIPPED"
else:
return_code = -1
status_icon = "❌"
status_text = "FAILED"
# Calculate timing for all four metrics
torch_host_time = sum(result.torch_host_time for result in test_results)
torch_device_time = sum(
result.torch_device_time for result in test_results
)
infini_host_time = sum(
result.infini_host_time for result in test_results
)
infini_device_time = sum(
result.infini_device_time for result in test_results
)
results[test_name] = {
"success": test_success,
"return_code": return_code,
"torch_host_time": torch_host_time,
"torch_device_time": torch_device_time,
"infini_host_time": infini_host_time,
"infini_device_time": infini_device_time,
"error_message": "",
"test_runner": test_runner,
"stdout": stdout_output,
"stderr": stderr_output,
}
print(
f"{status_icon} {test_name}: {status_text} (return code: {return_code})"
)
# Extract benchmark timing if in bench mode
if bench and test_success and return_code == 0:
cumulative_timing["total_torch_host_time"] += torch_host_time
cumulative_timing["total_torch_device_time"] += torch_device_time
cumulative_timing["total_infinicore_host_time"] += infini_host_time
cumulative_timing[
"total_infinicore_device_time"
] += infini_device_time
cumulative_timing["operators_tested"] += 1
except Exception as e:
# Restore stdout/stderr in case of exception
sys.stdout = old_stdout
sys.stderr = old_stderr
raise e
# In verbose mode, stop execution on first failure
if verbose and not test_success and return_code != 0:
break
except Exception as e:
print(f"💥 {test_name}: ERROR - {str(e)}")
results[test_name] = {
"success": False,
"return_code": -1,
"torch_host_time": 0.0,
"torch_device_time": 0.0,
"infini_host_time": 0.0,
"infini_device_time": 0.0,
"error_message": str(e),
"test_runner": None,
"stdout": "",
"stderr": str(e),
}
# In verbose mode, stop execution on any exception
if verbose:
print(f"\n{'!'*60}")
print(
f"VERBOSE MODE: Stopping execution due to exception in {test_name}"
)
print(f"{'!'*60}")
break
if debug:
traceback.print_exc()
break
return results, cumulative_timing
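# Illustrative sketch (not part of the original flow): driving the collector
# above directly. The operator names are hypothetical; the keyword arguments
# mirror the call made from the command-line entry point.
#
#   results, timing = run_all_op_tests(
#       ops_dir=find_ops_directory(),
#       specific_ops=["add", "matmul"],  # hypothetical operator names
#       bench=True,
#       bench_mode="both",
#       verbose=False,
#       debug=False,
#   )
#   print_summary(results, cumulative_timing=timing, bench_mode="both")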
def print_summary(
results,
verbose=False,
total_expected_tests=0,
cumulative_timing=None,
bench_mode="both",
):
"""Print a comprehensive summary of test results including benchmark data."""
print(f"\n{'='*80}")
print("CUMULATIVE TEST SUMMARY")
print(f"{'='*80}")
if not results:
print("No tests were run.")
return False
# Count different types of results
passed = 0
failed = 0
skipped = 0
partial = 0
passed_operators = [] # Store passed operator names
failed_operators = [] # Store failed operator names
skipped_operators = [] # Store skipped operator names
partial_operators = [] # Store partial operator names
for test_name, result_data in results.items():
return_code = result_data["return_code"]
if return_code == 0:
passed += 1
passed_operators.append(test_name)
elif return_code == -2: # Special code for skipped tests
skipped += 1
skipped_operators.append(test_name)
elif return_code == -3: # Special code for partial tests
partial += 1
partial_operators.append(test_name)
else:
failed += 1
failed_operators.append(test_name)
total = len(results)
print(f"Total tests run: {total}")
if total_expected_tests > 0 and total < total_expected_tests:
print(f"Total tests expected: {total_expected_tests}")
print(f"Tests not executed: {total_expected_tests - total}")
print(f"Passed: {passed}")
print(f"Failed: {failed}")
if skipped > 0:
print(f"Skipped: {skipped}")
if partial > 0:
print(f"Partial: {partial}")
# Print benchmark summary if cumulative_timing data is available
if cumulative_timing and cumulative_timing["operators_tested"] > 0:
print(f"{'-'*40}")
print("BENCHMARK SUMMARY:")
print(f" Operators Tested: {cumulative_timing['operators_tested']}")
# Display timing based on bench_mode
if bench_mode in ["host", "both"]:
print(
f" PyTorch Host Total Time: {cumulative_timing['total_torch_host_time']:12.3f} ms"
)
print(
f" InfiniCore Host Total Time: {cumulative_timing['total_infinicore_host_time']:12.3f} ms"
)
if bench_mode in ["device", "both"]:
print(
f" PyTorch Device Total Time: {cumulative_timing['total_torch_device_time']:12.3f} ms"
)
print(
f" InfiniCore Device Total Time: {cumulative_timing['total_infinicore_device_time']:12.3f} ms"
)
print(f"{'-'*40}")
# Display passed operators
if passed_operators:
print(f"\n✅ PASSED OPERATORS ({len(passed_operators)}):")
# Display operators in groups of 10 per line
for i in range(0, len(passed_operators), 10):
line_ops = passed_operators[i : i + 10]
print(" " + ", ".join(line_ops))
else:
print(f"\n✅ PASSED OPERATORS: None")
# Display failed operators (if any)
if failed_operators:
print(f"\n❌ FAILED OPERATORS ({len(failed_operators)}):")
for i in range(0, len(failed_operators), 10):
line_ops = failed_operators[i : i + 10]
print(" " + ", ".join(line_ops))
# Display skipped operators (if any)
if skipped_operators:
print(f"\n⏭️ SKIPPED OPERATORS ({len(skipped_operators)}):")
for i in range(0, len(skipped_operators), 10):
line_ops = skipped_operators[i : i + 10]
print(" " + ", ".join(line_ops))
# Display partial operators (if any)
if partial_operators:
print(f"\n⚠️ PARTIAL OPERATORS ({len(partial_operators)}):")
for i in range(0, len(partial_operators), 10):
line_ops = partial_operators[i : i + 10]
print(" " + ", ".join(line_ops))
if total > 0:
# Calculate success rate based on actual executed tests
executed_tests = passed + failed + partial
if executed_tests > 0:
success_rate = passed / executed_tests * 100
print(f"\nSuccess rate: {success_rate:.1f}%")
if verbose and total < total_expected_tests:
print(f"\n💡 Verbose mode: Execution stopped after first failure")
print(f" {total_expected_tests - total} tests were not executed")
if failed == 0:
if skipped > 0 or partial > 0:
print(f"\n⚠️ Tests completed with some operators not implemented")
print(f" - {skipped} tests skipped (both implementations missing)")
print(f" - {partial} tests partial (one implementation missing)")
else:
print(f"\n🎉 All tests passed!")
return True
else:
print(f"\n{failed} tests failed")
return False
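# Minimal sketch of the per-operator record consumed by print_summary (keys as
# built in run_all_op_tests above; the operator name "add" is illustrative only):
#
#   example_results = {
#       "add": {
#           "success": True,
#           "return_code": 0,
#           "torch_host_time": 1.2,
#           "torch_device_time": 0.8,
#           "infini_host_time": 1.1,
#           "infini_device_time": 0.7,
#           "error_message": "",
#           "test_runner": None,
#           "stdout": "",
#           "stderr": "",
#       },
#   }
#   print_summary(example_results, total_expected_tests=1)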
def list_available_tests(ops_dir=None):
"""List all available operator test files."""
if ops_dir is None:
ops_dir = find_ops_directory()
else:
ops_dir = Path(ops_dir)
if not ops_dir or not ops_dir.exists():
print(f"Error: Ops directory '{ops_dir}' does not exist.")
return
operators = get_available_operators(ops_dir)
if operators:
print(f"Available operator test files in {ops_dir}:")
for operator in operators:
print(f" - {operator}")
print(f"\nTotal: {len(operators)} operators")
else:
print(f"No operator test files found in {ops_dir}")
# Show available Python files for debugging
test_files = list(ops_dir.glob("*.py"))
current_script = Path(__file__).name
test_files = [f for f in test_files if f.name != current_script]
if test_files:
print(f"Available Python files: {[f.name for f in test_files]}")
except:
return False
def generate_help_epilog(ops_dir):
def generate_help_epilog(ops_dir=None):
"""
Generate dynamic help epilog with available operators and hardware platforms.
Args:
ops_dir: Path to ops directory
Returns:
str: Formatted help text
Generate dynamic help epilog containing available operators and hardware platforms.
Maintains the original output format for backward compatibility.
"""
# Get available operators
operators = get_available_operators(ops_dir)
# === Adapter: Use TestDiscoverer to get operator list ===
# Temporarily instantiate a Discoverer just to fetch the list
discoverer = TestDiscoverer(ops_dir)
operators = discoverer.get_available_operators()
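# For reference, the TestDiscoverer surface relied on in this file (sketch,
# inferred from its call sites here):
#   discoverer.get_available_operators()  -> list of operator test names
#   discoverer.scan(target_ops)           -> test files to execute
#   discoverer.ops_dir                    -> resolved ops directory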
# Build epilog text
# Build epilog text (fully replicating original logic)
epilog_parts = []
# Examples section
......@@ -628,17 +162,12 @@ def generate_help_epilog(ops_dir):
def main():
"""Main entry point with comprehensive command line argument parsing."""
# First, find ops directory for dynamic help generation
ops_dir = find_ops_directory()
"""Main entry point for the InfiniCore Operator Test Runner."""
parser = argparse.ArgumentParser(
description="Run InfiniCore operator tests across multiple hardware platforms",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=generate_help_epilog(ops_dir),
epilog=generate_help_epilog(),
)
# Core options
parser.add_argument(
"--ops-dir", type=str, help="Path to the ops directory (default: auto-detect)"
)
......@@ -650,118 +179,97 @@ def main():
action="store_true",
help="List all available test files without running them",
)
# Call common method to add shared arguments (bench, debug, verbose, save...)
add_common_test_args(parser)
# Add common test arguments (including --save, --bench, etc.)
add_common_test_args(parser)
get_hardware_args_group(parser)
# Parse known args first, leave the rest for the test scripts
args, unknown_args = parser.parse_known_args()
# Handle list command
if args.list:
list_available_tests(args.ops_dir)
return
# Auto-detect ops directory if not provided
if args.ops_dir is None:
ops_dir = find_ops_directory()
if not ops_dir:
print(
"Error: Could not auto-detect ops directory. Please specify with --ops-dir"
)
sys.exit(1)
else:
ops_dir = Path(args.ops_dir)
if not ops_dir.exists():
print(f"Error: Ops directory '{ops_dir}' does not exist.")
sys.exit(1)
# Show what extra arguments will be passed
if unknown_args:
print(f"Passing extra arguments to test scripts: {unknown_args}")
# Get available operators for display
available_operators = get_available_operators(ops_dir)
print(f"InfiniCore Operator Test Runner")
print(f"Operating directory: {ops_dir}")
print(f"Available operators: {len(available_operators)}")
# 1. Discovery
discoverer = TestDiscoverer(args.ops_dir)
if args.list:
print("Available operators:", discoverer.get_available_operators())
return
if args.verbose:
print(f"Verbose mode: ENABLED (will stop on first error with full traceback)")
if args.bench:
bench_mode = args.bench if args.bench != "both" else "both"
print(f"Benchmark mode: {bench_mode.upper()} timing")
print(f"Benchmark mode: {args.bench.upper()} timing")
target_ops = None
if args.ops:
# Validate requested operators
valid_ops = []
invalid_ops = []
for op in args.ops:
if op in available_operators:
valid_ops.append(op)
else:
invalid_ops.append(op)
# Get all available operator names
available_ops = set(discoverer.get_available_operators())
requested_ops = set(args.ops)
# Classify using set operations
valid_ops = list(requested_ops & available_ops) # Intersection: Valid ops
invalid_ops = list(requested_ops - available_ops) # Difference: Invalid ops
# Warn if there are invalid operators
if invalid_ops:
print(f"Warning: Unknown operators: {', '.join(invalid_ops)}")
print(f"Available operators: {', '.join(available_operators)}")
print(f"⚠️ Warning: The following requested operators were not found:")
print(f" {', '.join(invalid_ops)}")
print(f" (Use --list to see available operators)")
if valid_ops:
print(f"Testing operators: {', '.join(valid_ops)}")
total_expected_tests = len(valid_ops)
else:
print("No valid operators specified. Running all available tests.")
total_expected_tests = len(available_operators)
else:
print("Testing all available operators")
total_expected_tests = len(available_operators)
print()
# Run all tests
results, cumulative_timing = run_all_op_tests(
ops_dir=ops_dir,
specific_ops=args.ops,
bench=bool(args.bench),
bench_mode=args.bench if args.bench else "both",
verbose=args.verbose,
debug=args.debug,
)
if not valid_ops:
# Case A: User input provided, but ALL were invalid.
print(f"⚠️ No valid operators remained from your list.")
print(f"🔄 Fallback: Proceeding to run ALL available tests...")
# Print summary and exit with appropriate code
all_passed = print_summary(
else:
# Case B: At least some valid operators found.
print(f"🎯 Targeted operators: {', '.join(valid_ops)}")
target_ops = valid_ops
test_files = discoverer.scan(target_ops)
if not test_files:
print("No tests found.")
sys.exit(0)
# 2. Preparation
executor = TestExecutor()
cumulative_timing = TestTiming()
test_summary = TestSummary(args.verbose, args.bench)
results = []
test_summary.print_header(discoverer.ops_dir, len(test_files))
# 3. Execution Loop
for f in test_files:
result = executor.execute(f)
results.append(result)
# Real-time reporting and printing of stdout
test_summary.print_live_result(result)
# Accumulate timing
if result.success:
cumulative_timing.torch_host += result.timing.torch_host
cumulative_timing.infini_host += result.timing.infini_host
cumulative_timing.torch_device += result.timing.torch_device
cumulative_timing.infini_device += result.timing.infini_device
cumulative_timing.operators_tested += 1
# Fail fast in verbose mode
if args.verbose and not result.success:
print("\nStopping due to failure in verbose mode.")
break
# 4. Final Report & Save
all_passed = test_summary.print_summary(
results,
args.verbose,
total_expected_tests,
cumulative_timing,
bench_mode=args.bench if args.bench else "both",
)
# Check if there were any tests with missing implementations
has_missing_implementations = any(
result_data["return_code"] in [-2, -3] for result_data in results.values()
cumulative_timing if args.bench else None,
ops_dir=discoverer.ops_dir,
total_expected=len(test_files),
)
if all_passed and has_missing_implementations:
print(f"\n⚠️ Note: Some operators are not fully implemented")
print(f" Run individual tests for details on missing implementations")
if args.verbose and not all_passed:
print(
f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:"
)
failed_ops = [
name
for name, result_data in results.items()
if result_data["return_code"] == -1
]
for op in failed_ops[:3]: # Show first 3 failed operators
print(f" python {ops_dir / (op + '.py')} --verbose")
sys.exit(0 if all_passed else 1)
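# Example invocations (illustrative; the script filename is a placeholder, and
# the flags correspond to the arguments registered above and by
# add_common_test_args):
#
#   python <runner_script>.py --list
#   python <runner_script>.py --ops add matmul --bench both --verbose
#   python <runner_script>.py --ops-dir /path/to/ops --debug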
......