Commit 147a4ac7 authored by baominghelly

issue/787 - Reconstruct utils && Rename class && Setup summary class

parent 5aa850af
......@@ -11,40 +11,45 @@ from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .runner import GenericTestRunner
from .tensor import TensorSpec, TensorInitializer
from .types import TestTiming, OperatorTestResult, TestResult
from .structs import TestTiming, OperatorResult, CaseResult
from .summary import TestSummary
from .driver import TestDriver
from .printer import ConsolePrinter
from .utils import (
from .utils.compare_utils import (
compare_results,
create_test_comparator,
debug,
get_tolerance,
)
from .utils.json_utils import save_json_report
from .utils.tensor_utils import (
infinicore_tensor_from_torch,
rearrange_tensor,
convert_infinicore_to_torch,
rearrange_tensor,
is_broadcast,
is_integer_dtype,
is_complex_dtype,
is_floating_dtype,
is_integer_dtype,
)
__all__ = [
# Core types and classes
"BaseOperatorTest",
"CaseResult",
"ConsolePrinter",
"GenericTestRunner",
"InfiniDeviceEnum",
"InfiniDeviceNames",
"OperatorResult",
"TensorInitializer",
"TensorSpec",
"TestCase",
"TestConfig",
"TestResult",
"TestRunner",
"TestDriver",
"TestReporter",
"TestRunner",
"TestTiming",
"OperatorTestResult",
"TestDriver",
"ConsolePrinter",
# Core functions
"add_common_test_args",
"compare_results",
......@@ -57,6 +62,8 @@ __all__ = [
"get_tolerance",
"infinicore_tensor_from_torch",
"rearrange_tensor",
# JSON utilities
"save_json_report",
# Utility functions
"to_infinicore_dtype",
"to_torch_dtype",
......
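For downstream test files, the flat `utils` module is gone; imports now target the split submodules. A minimal sketch, assuming the package root is importable as `framework` (the actual root path is an assumption):

    # Old (pre-commit) style, now removed:
    #   from framework.utils import compare_results, get_tolerance
    # New style after the split:
    from framework.utils.compare_utils import compare_results, get_tolerance
    from framework.utils.json_utils import save_json_report
    from framework.utils.tensor_utils import infinicore_tensor_from_torch, is_broadcast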
......@@ -8,16 +8,15 @@ import infinicore
import traceback
from abc import ABC, abstractmethod
from .entities import TestCase
from .types import TestResult
from .structs import CaseResult
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer
from .utils import (
from .utils.tensor_utils import (
clone_torch_tensor,
create_test_comparator,
infinicore_tensor_from_torch,
)
from .utils.compare_utils import create_test_comparator
from .benchmark import BenchmarkUtils
......@@ -85,7 +84,7 @@ class TestRunner:
try:
print(f"{test_case}")
# Execute test and get TestResult object
# Execute test and get CaseResult object
test_result = test_func(device, test_case, self.config)
self.test_results.append(test_result)
......@@ -119,8 +118,8 @@ class TestRunner:
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
# Create a failed TestResult
failed_result = TestResult(
# Create a failed CaseResult
failed_result = CaseResult(
success=False,
return_code=-1,
error_message=str(e),
......@@ -401,12 +400,12 @@ class BaseOperatorTest(ABC):
config: Test configuration
Returns:
TestResult: Test result object containing status and timing information
CaseResult: Test case result object containing status and timing information
"""
device_str = torch_device_map[device]
# Initialize test result
test_result = TestResult(
# Initialize test case result
test_result = CaseResult(
success=False,
return_code=-1, # Default to failure
test_case=test_case,
......
......@@ -5,7 +5,7 @@ Benchmarking utilities for the InfiniCore testing framework
import time
import torch
import infinicore
from .utils import synchronize_device
from .utils.tensor_utils import synchronize_device
class BenchmarkUtils:
......
import torch
import infinicore
from dataclasses import dataclass, field
def to_torch_dtype(infini_dtype):
"""Convert infinicore data type to PyTorch data type"""
......
......@@ -2,7 +2,9 @@ import sys
import importlib.util
from io import StringIO
from contextlib import contextmanager
from .types import OperatorTestResult, TestTiming
from .structs import OperatorResult
from .summary import TestSummary
@contextmanager
def capture_output():
......@@ -15,18 +17,19 @@ def capture_output():
finally:
sys.stdout, sys.stderr = old_out, old_err
class TestDriver:
def drive(self, file_path) -> OperatorTestResult:
result = OperatorTestResult(name=file_path.stem)
def drive(self, file_path) -> OperatorResult:
result = OperatorResult(name=file_path.stem)
try:
# 1. Dynamically import the module
module = self._import_module(file_path)
# 2. Look for TestRunner
if not hasattr(module, "GenericTestRunner"):
raise ImportError("No GenericTestRunner found in module")
# 3. Look for TestClass (subclass of BaseOperatorTest)
test_class = self._find_test_class(module)
if not test_class:
......@@ -44,11 +47,13 @@ class TestDriver:
result.success = success
result.stdout = out.getvalue()
result.stderr = err.getvalue()
# Extract detailed results from internal_runner
test_results = internal_runner.get_test_results() if internal_runner else []
self._analyze_return_code(result, test_results)
self._extract_timing(result, test_results)
test_summary = TestSummary()
test_summary.process_operator_result(result, test_results)
# test_summary._extract_timing(result, test_results)
except Exception as e:
result.success = False
......@@ -76,30 +81,3 @@ class TestDriver:
if any("BaseOperatorTest" in str(b) for b in attr.__bases__):
return attr
return None
def _analyze_return_code(self, result, test_results):
# Logic consistent with original code: determine if all passed, partially passed, or skipped
if result.success:
result.return_code = 0
return
has_failures = any(r.return_code == -1 for r in test_results)
has_partial = any(r.return_code == -3 for r in test_results)
has_skipped = any(r.return_code == -2 for r in test_results)
if has_failures:
result.return_code = -1
elif has_partial:
result.return_code = -3
elif has_skipped:
result.return_code = -2
else:
result.return_code = -1
def _extract_timing(self, result, test_results):
# Accumulate timing
t = result.timing
t.torch_host = sum(r.torch_host_time for r in test_results)
t.torch_device = sum(r.torch_device_time for r in test_results)
t.infini_host = sum(r.infini_host_time for r in test_results)
t.infini_device = sum(r.infini_device_time for r in test_results)
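The precedence encoded above, and preserved in TestSummary.process_operator_result, is failure over partial over skipped. A small self-contained check with stand-in results (the SimpleNamespace objects are hypothetical substitutes for CaseResult):

    from types import SimpleNamespace

    # One partial (-3) and one skipped (-2) case, no hard failures.
    subs = [SimpleNamespace(return_code=-3), SimpleNamespace(return_code=-2)]
    if any(r.return_code == -1 for r in subs):
        code = -1
    elif any(r.return_code == -3 for r in subs):
        code = -3
    elif any(r.return_code == -2 for r in subs):
        code = -2
    else:
        code = -1
    assert code == -3  # partial outranks skipped at the operator level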
from pathlib import Path
class TestDiscoverer:
def __init__(self, ops_dir_path=None):
self.ops_dir = self._resolve_dir(ops_dir_path)
def _resolve_dir(self, path):
if path:
p = Path(path)
if p.exists(): return p
# Default fallback logic: 'ops' directory under the parent of the current file's parent.
# Note: Since this file is in 'framework/', we look at parent.parent.
# It is recommended to pass an explicit path in run.py.
fallback = Path(__file__).parent.parent / "ops"
return fallback if fallback.exists() else None
def get_available_operators(self):
"""Returns a list of names of all available operators."""
if not self.ops_dir: return []
files = self.scan()
return sorted([f.stem for f in files])
def get_raw_python_files(self):
"""
Get all .py files in the directory (excluding run.py) without content validation.
Used for debugging: helps identify files that exist but failed validation.
"""
if not self.ops_dir or not self.ops_dir.exists():
return []
files = list(self.ops_dir.glob("*.py"))
# Exclude run.py itself and __init__.py
return [f.name for f in files if f.name != "run.py" and not f.name.startswith("__")]
def scan(self, specific_ops=None):
"""Scans and returns a list of Path objects that meet the criteria."""
if not self.ops_dir or not self.ops_dir.exists():
return []
# 1. Find all .py files
files = list(self.ops_dir.glob("*.py"))
target_ops_set = set(specific_ops) if specific_ops else None
# 2. Filter out non-test files (via content check)
valid_files = []
for f in files:
# A. Basic Name Filtering
if f.name.startswith("_") or f.name == "run.py":
continue
# B. Specific Ops Filtering
if target_ops_set and f.stem not in target_ops_set:
continue
# C. Content Check (Expensive I/O)
# Only perform this check if the file passed the name filters above.
if self._is_operator_test(f):
valid_files.append(f)
return valid_files
def _is_operator_test(self, file_path):
"""Checks if the file content contains operator test characteristics."""
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return "infinicore" in content and (
"BaseOperatorTest" in content or "GenericTestRunner" in content
)
except Exception:
return False
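A usage sketch for the discoverer; the "ops" directory and operator names are illustrative, and passing an explicit path is recommended, as the fallback comment above notes:

    discoverer = TestDiscoverer(ops_dir_path="ops")
    print(discoverer.get_available_operators())   # e.g. ['add', 'matmul']
    for path in discoverer.scan(specific_ops=["add"]):
        print(path)   # only ops/add.py, and only if it passes the content check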
# lib/printer.py
import sys
from .types import OperatorTestResult, TestTiming
class ConsolePrinter:
"""
Handles all console output logic.
Acts as the 'View' in the application structure.
"""
def list_tests(self, discoverer):
"""
Intelligently list available tests.
If no valid operators are found, it falls back to listing raw Python files
to assist with debugging (e.g., typos in class inheritance).
"""
ops_dir = discoverer.ops_dir
operators = discoverer.get_available_operators()
if operators:
print(f"Available operator test files in {ops_dir}:")
for operator in operators:
print(f" - {operator}")
print(f"\nTotal: {len(operators)} operators")
else:
print(f"No valid operator tests found in {ops_dir}")
# === Fallback Debug Logic ===
raw_files = discoverer.get_raw_python_files()
if raw_files:
print(f"\n💡 Debug Hint: Found the following Python files (but they are not valid tests):")
print(f" {raw_files}")
print(" (Ensure they inherit from 'BaseOperatorTest' and contain 'infinicore')")
def print_header(self, ops_dir, count):
print(f"InfiniCore Operator Test Runner")
print(f"Directory: {ops_dir}")
print(f"Tests found: {count}\n")
def print_live_result(self, result, verbose=False):
"""Print single-line result in real-time."""
print(f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})")
# Only print details if verbose or if the test failed/had output
if result.stdout:
print(result.stdout.rstrip())
if result.stderr:
print("\nSTDERR:", result.stderr.rstrip())
if result.error_message:
print(f"💥 Error: {result.error_message}")
if result.stdout or result.stderr or verbose:
print("-" * 40)
def print_summary(
self,
results,
cumulative_timing,
ops_dir,
total_expected=0,
verbose=False,
bench_mode="both"
):
"""Prints the final comprehensive test summary and statistics, ensuring consistency with original output."""
print(f"\n{'='*80}\nCUMULATIVE TEST SUMMARY\n{'='*80}")
passed = [r for r in results if r.return_code == 0]
failed = [r for r in results if r.return_code == -1]
skipped = [r for r in results if r.return_code == -2]
partial = [r for r in results if r.return_code == -3]
total = len(results)
print(f"Total tests run: {total}")
if total_expected > 0 and total < total_expected:
print(f"Total tests expected: {total_expected}")
print(f"Tests not executed: {total_expected - total}")
print(f"Passed: {len(passed)}")
print(f"Failed: {len(failed)}")
if skipped: print(f"Skipped: {len(skipped)}")
if partial: print(f"Partial: {len(partial)}")
# 1. Print Benchmark data
if cumulative_timing:
# Call the internal helper method
self._print_timing(cumulative_timing, bench_mode=bench_mode)
# 2. Print Detailed Lists
# PASSED
if passed:
self._print_op_list("✅ PASSED OPERATORS", passed)
else:
print(f"\n✅ PASSED OPERATORS: None")
# FAILED
if failed:
self._print_op_list("❌ FAILED OPERATORS", failed)
# SKIPPED
if skipped:
self._print_op_list("⏭️ SKIPPED OPERATORS", skipped)
# PARTIAL
if partial:
self._print_op_list("⚠️ PARTIAL IMPLEMENTATIONS", partial)
# 3. Restore Success Rate
if total > 0:
# Calculate success rate based on actually executed tests (excluding skipped)
executed_tests = total - len(skipped)
if executed_tests > 0:
success_rate = len(passed) / executed_tests * 100
print(f"\nSuccess rate: {success_rate:.1f}%")
if not failed:
if skipped or partial:
print(f"\n⚠️ Tests completed with some operators not fully implemented")
else:
print(f"\n🎉 All tests passed!")
else:
print(f"\n{len(failed)} tests failed")
if not failed and (skipped or partial):
print(f"\n⚠️ Note: Some operators are not fully implemented")
print(f" Run individual tests for details on missing implementations")
if verbose and failed:
print(f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:")
# Show first 3 failed operators to avoid spamming
for r in failed[:3]:
# Construct file path: ops_dir / filename.py
file_path = ops_dir / (r.name + ".py")
print(f" python {file_path} --verbose")
if len(failed) > 3:
print(f" ... (and {len(failed) - 3} others)")
return len(failed) == 0
# --- Internal Helpers ---
def _print_op_list(self, title, result_list):
"""Helper to print a formatted list of operator names."""
print(f"\n{title} ({len(result_list)}):")
names = [r.name for r in result_list]
# Group by 10 per line
for i in range(0, len(names), 10):
print(" " + ", ".join(names[i : i + 10]))
def _print_timing(self, t, bench_mode="both"):
"""Prints detailed timing breakdown for host and device, based on bench_mode."""
print(f"{'-'*40}")
# Restore Operators Tested field using the dataclass field
if hasattr(t, 'operators_tested') and t.operators_tested > 0:
print(f"BENCHMARK SUMMARY:")
print(f" Operators Tested: {t.operators_tested}")
# Restore detailed Host/Device distinction
if bench_mode in ["host", "both"]:
print(
f" PyTorch Host Total Time: {t.torch_host:12.3f} ms"
)
print(
f" InfiniCore Host Total Time: {t.infini_host:12.3f} ms"
)
if bench_mode in ["device", "both"]:
print(
f" PyTorch Device Total Time: {t.torch_device:12.3f} ms"
)
print(
f" InfiniCore Device Total Time: {t.infini_device:12.3f} ms"
)
print(f"{'-'*40}")
import json
import os
from datetime import datetime
from typing import List, Dict, Any, Union
from dataclasses import is_dataclass
from .base import TensorSpec
from .devices import InfiniDeviceEnum
class TestReporter:
"""
Handles report generation and file saving for test results.
"""
@staticmethod
def prepare_report_entry(
op_name: str,
test_cases: List[Any],
args: Any,
op_paths: Dict[str, str],
results_list: List[Any]
) -> List[Dict[str, Any]]:
"""
Combines static test case info with dynamic execution results.
"""
# 1. Normalize results
results_map = {}
if isinstance(results_list, list):
results_map = {i: res for i, res in enumerate(results_list)}
elif isinstance(results_list, dict):
results_map = results_list
else:
results_map = {0: results_list} if results_list else {}
# 2. Global Args
global_args = {
k: getattr(args, k)
for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
if hasattr(args, k)
}
grouped_entries: Dict[int, Dict[str, Any]] = {}
# 3. Iterate Test Cases
for idx, tc in enumerate(test_cases):
res = results_map.get(idx)
dev_id = getattr(res, "device", 0) if res else 0
# --- A. Initialize Group ---
if dev_id not in grouped_entries:
device_id_map = {v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")}
dev_str = device_id_map.get(dev_id, str(dev_id))
grouped_entries[dev_id] = {
"operator": op_name,
"device": dev_str,
"torch_op": op_paths.get("torch") or "unknown",
"infinicore_op": op_paths.get("infinicore") or "unknown",
"args": global_args,
"testcases": []
}
# --- B. Build Kwargs ---
display_kwargs = {}
for k, v in tc.kwargs.items():
# 1. Handle Inplace output index: "out": 0 -> "out": "in_0" / "a_spec"
if k == "out" and isinstance(v, int):
if 0 <= v < len(tc.inputs):
# Prioritize the input's name; otherwise, default to index-based name
display_kwargs[k] = getattr(tc.inputs[v], "name", None) or f"in_{v}"
else:
display_kwargs[k] = f"Invalid_Index_{v}"
# 2. Handle TensorSpec objects
elif isinstance(v, TensorSpec):
spec_dict = TestReporter._spec_to_dict(v)
# If the object has a name, explicitly overwrite it; otherwise, keep original
if getattr(v, "name", None):
spec_dict["name"] = v.name
display_kwargs[k] = spec_dict
# 3. Direct assignment for other types
else:
display_kwargs[k] = v
# --- B2. Inject Outputs ---
# Handle output list (output_specs)
if getattr(tc, "output_specs", None):
for i, spec in enumerate(tc.output_specs):
out_dict = TestReporter._spec_to_dict(spec)
# Prioritize intrinsic name; otherwise, default to "out_i"
out_dict["name"] = getattr(spec, "name", None) or f"out_{i}"
display_kwargs[f"out_{i}"] = out_dict
# Handle single output (output_spec), preventing overwrite of existing "out"
elif tc.output_spec and "out" not in display_kwargs:
out_dict = TestReporter._spec_to_dict(tc.output_spec)
# Prioritize intrinsic name; otherwise, default to "out" (fixes null issue)
out_dict["name"] = getattr(tc.output_spec, "name", "out")
display_kwargs["out"] = out_dict
# --- C. Build Inputs ---
# Iterate inputs: prioritize original name, fallback to "in_i"
processed_inputs = []
for i, inp in enumerate(tc.inputs):
inp_dict = TestReporter._spec_to_dict(inp)
# Simplified logic: Use "name" attribute if present and non-empty, else use f"in_{i}"
inp_dict["name"] = getattr(inp, "name", None) or f"in_{i}"
processed_inputs.append(inp_dict)
case_data = {
"description": tc.description,
"inputs": processed_inputs,
"kwargs": display_kwargs,
"comparison_target": tc.comparison_target,
"tolerance": tc.tolerance,
}
# --- D. Inject Result ---
if res:
case_data["result"] = TestReporter._fmt_result(res)
else:
case_data["result"] = {"status": {"success": False, "error": "No result"}}
grouped_entries[dev_id]["testcases"].append(case_data)
return list(grouped_entries.values())
@staticmethod
def save_all_results(save_path: str, total_results: List[Dict[str, Any]]):
"""
Saves the report list to a JSON file with specific custom formatting
"""
directory, filename = os.path.split(save_path)
name, ext = os.path.splitext(filename)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
# Define indentation levels for cleaner code
indent_4 = ' ' * 4
indent_8 = ' ' * 8
indent_12 = ' ' * 12
indent_16 = ' ' * 16
indent_20 = ' ' * 20
print(f"💾 Saving to: {final_path}")
try:
with open(final_path, "w", encoding="utf-8") as f:
f.write("[\n")
for i, entry in enumerate(total_results):
f.write(f"{indent_4}{{\n")
keys = list(entry.keys())
for j, key in enumerate(keys):
val = entry[key]
comma = "," if j < len(keys) - 1 else ""
# -------------------------------------------------
# Special Handling for 'testcases' list formatting
# -------------------------------------------------
if key == "testcases" and isinstance(val, list):
f.write(f'{indent_8}"{key}": [\n')
for c_idx, case_item in enumerate(val):
f.write(f"{indent_12}{{\n")
case_keys = list(case_item.keys())
for k_idx, c_key in enumerate(case_keys):
c_val = case_item[c_key]
# [Logic A] Skip fields we merged manually after 'kwargs'
if c_key in ["comparison_target", "tolerance"]:
continue
# Check comma for standard logic (might be overridden below)
c_comma = "," if k_idx < len(case_keys) - 1 else ""
# [Logic B] Handle 'kwargs' + Grouped Fields
if c_key == "kwargs":
# 1. Use Helper for kwargs (Fill/Flow logic)
TestReporter._write_smart_field(
f, c_key, c_val, indent_16, indent_20, close_comma=","
)
# 2. Write subsequent comparison_target and tolerance (on a new line)
cmp_v = json.dumps(case_item.get("comparison_target"), ensure_ascii=False)
tol_v = json.dumps(case_item.get("tolerance"), ensure_ascii=False)
remaining_keys = [k for k in case_keys[k_idx+1:] if k not in ("comparison_target", "tolerance")]
line_comma = "," if remaining_keys else ""
f.write(f'{indent_16}"comparison_target": {cmp_v}, "tolerance": {tol_v}{line_comma}\n')
continue
# [Logic C] Handle 'inputs' (Smart Wrap)
if c_key == "inputs" and isinstance(c_val, list):
TestReporter._write_smart_field(
f, c_key, c_val, indent_16, indent_20, close_comma=c_comma
)
continue
# [Logic D] Standard fields (description, result, output_spec, etc.)
else:
c_val_str = json.dumps(c_val, ensure_ascii=False)
f.write(f'{indent_16}"{c_key}": {c_val_str}{c_comma}\n')
close_comma = "," if c_idx < len(val) - 1 else ""
f.write(f"{indent_12}}}{close_comma}\n")
f.write(f"{indent_8}]{comma}\n")
# -------------------------------------------------
# Standard top-level fields (operator, args, etc.)
# -------------------------------------------------
else:
k_str = json.dumps(key, ensure_ascii=False)
v_str = json.dumps(val, ensure_ascii=False)
f.write(f"{indent_8}{k_str}: {v_str}{comma}\n")
if i < len(total_results) - 1:
f.write(f"{indent_4}}},\n")
else:
f.write(f"{indent_4}}}\n")
f.write("]\n")
print(f" ✅ Saved (Structure Matched).")
except Exception as e:
import traceback; traceback.print_exc()
print(f" ❌ Save failed: {e}")
# --- Internal Helpers ---
@staticmethod
def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
"""
Helper to write a JSON field (List or Dict) with smart wrapping.
- If compact length <= 180: Write on one line.
- If > 180: Use 'Fill/Flow' mode (multiple items per line, wrap when line is full).
"""
# 1. Try Compact Mode
compact_json = json.dumps(value, ensure_ascii=False)
if len(compact_json) <= 180:
f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
return
# 2. Fill/Flow Mode
is_dict = isinstance(value, dict)
open_char = '{' if is_dict else '['
close_char = '}' if is_dict else ']'
f.write(f'{indent}"{key}": {open_char}')
# Normalize items for iteration
if is_dict:
items = list(value.items())
else:
items = value # List
# Initialize current line length tracking
# Length includes indent + "key": [
current_len = len(indent) + len(f'"{key}": {open_char}')
for i, item in enumerate(items):
# Format individual item string
if is_dict:
k, v = item
val_str = json.dumps(v, ensure_ascii=False)
item_str = f'"{k}": {val_str}'
else:
item_str = json.dumps(item, ensure_ascii=False)
is_last = (i == len(items) - 1)
item_comma = "" if is_last else ", "
# Predict new length: current + item + comma
if current_len + len(item_str) + len(item_comma) > 180:
# Wrap to new line
f.write(f'\n{sub_indent}')
current_len = len(sub_indent)
f.write(f'{item_str}{item_comma}')
current_len += len(item_str) + len(item_comma)
f.write(f'{close_char}{close_comma}\n')
@staticmethod
def _spec_to_dict(s):
return {
"name": getattr(s, "name", "unknown"),
"shape": list(s.shape) if s.shape else None,
"dtype": str(s.dtype).split(".")[-1],
"strides": list(s.strides) if s.strides else None,
}
@staticmethod
def _fmt_result(res):
if not (is_dataclass(res) or hasattr(res, "success")):
return str(res)
get_time = lambda k: round(getattr(res, k, 0.0), 4)
return {
"status": {
"success": getattr(res, "success", False),
"error": getattr(res, "error_message", ""),
},
"perf_ms": {
"torch": {
"host": get_time("torch_host_time"),
"device": get_time("torch_device_time"),
},
"infinicore": {
"host": get_time("infini_host_time"),
"device": get_time("infini_device_time"),
},
},
}
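Entries returned by prepare_report_entry are grouped per device; each looks roughly like the following (all field values here are illustrative):

    example_entry = {
        "operator": "add",
        "device": "CPU",
        "torch_op": "torch.add",
        "infinicore_op": "infinicore.add",
        "args": {"bench": "both", "verbose": False},
        "testcases": [
            {
                "description": "basic case",
                "inputs": [{"name": "in_0", "shape": [2, 3], "dtype": "float32", "strides": None}],
                "kwargs": {"out": "in_0"},
                "comparison_target": None,
                "tolerance": {"atol": 0, "rtol": 1e-3},
                "result": {"status": {"success": True, "error": ""}},
            }
        ],
    }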
# import json
# import os
# from datetime import datetime
# from typing import List, Dict, Any, Union
# from dataclasses import is_dataclass
# from .base import TensorSpec
# from .devices import InfiniDeviceEnum
# class TestReporter:
# """
# Handles report generation and file saving for test results.
# """
# @staticmethod
# def prepare_report_entry(
# op_name: str,
# test_cases: List[Any],
# args: Any,
# op_paths: Dict[str, str],
# results_list: List[Any],
# ) -> List[Dict[str, Any]]:
# """
# Combines static test case info with dynamic execution results.
# """
# # 1. Normalize results
# results_map = {}
# if isinstance(results_list, list):
# results_map = {i: res for i, res in enumerate(results_list)}
# elif isinstance(results_list, dict):
# results_map = results_list
# else:
# results_map = {0: results_list} if results_list else {}
# # 2. Global Args
# global_args = {
# k: getattr(args, k)
# for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
# if hasattr(args, k)
# }
# grouped_entries: Dict[int, Dict[str, Any]] = {}
# # 3. Iterate Test Cases
# for idx, tc in enumerate(test_cases):
# res = results_map.get(idx)
# dev_id = getattr(res, "device", 0) if res else 0
# # --- A. Initialize Group ---
# if dev_id not in grouped_entries:
# device_id_map = {
# v: k
# for k, v in vars(InfiniDeviceEnum).items()
# if not k.startswith("_")
# }
# dev_str = device_id_map.get(dev_id, str(dev_id))
# grouped_entries[dev_id] = {
# "operator": op_name,
# "device": dev_str,
# "torch_op": op_paths.get("torch") or "unknown",
# "infinicore_op": op_paths.get("infinicore") or "unknown",
# "args": global_args,
# "testcases": [],
# }
# # --- B. Build Kwargs ---
# display_kwargs = {}
# for k, v in tc.kwargs.items():
# # 1. Handle Inplace output index: "out": 0 -> "out": "in_0" / "a_spec"
# if k == "out" and isinstance(v, int):
# if 0 <= v < len(tc.inputs):
# # Prioritize the input's name; otherwise, default to index-based name
# display_kwargs[k] = (
# getattr(tc.inputs[v], "name", None) or f"in_{v}"
# )
# else:
# display_kwargs[k] = f"Invalid_Index_{v}"
# # 2. Handle TensorSpec objects
# elif isinstance(v, TensorSpec):
# spec_dict = TestReporter._spec_to_dict(v)
# # If the object has a name, explicitly overwrite it; otherwise, keep original
# if v.name:
# spec_dict["name"] = v.name
# display_kwargs[k] = spec_dict
# # 3. Direct assignment for other types
# else:
# display_kwargs[k] = v
# # --- B2. Inject Outputs ---
# # Handle output list (output_specs)
# if getattr(tc, "output_specs", None):
# for i, spec in enumerate(tc.output_specs):
# out_dict = TestReporter._spec_to_dict(spec)
# # Prioritize intrinsic name; otherwise, default to "out_i"
# out_dict["name"] = getattr(spec, "name", None) or f"out_{i}"
# display_kwargs[f"out_{i}"] = out_dict
# # Handle single output (output_spec), preventing overwrite of existing "out"
# elif tc.output_spec and "out" not in display_kwargs:
# out_dict = TestReporter._spec_to_dict(tc.output_spec)
# # Prioritize intrinsic name; otherwise, default to "out" (fixes null issue)
# out_dict["name"] = getattr(tc.output_spec, "name", "out")
# display_kwargs["out"] = out_dict
# # --- C. Build Inputs ---
# # Iterate inputs: prioritize original name, fallback to "in_i"
# processed_inputs = []
# for i, inp in enumerate(tc.inputs):
# inp_dict = TestReporter._spec_to_dict(inp)
# # Simplified logic: Use "name" attribute if present and non-empty, else use f"in_{i}"
# inp_dict["name"] = getattr(inp, "name", None) or f"in_{i}"
# processed_inputs.append(inp_dict)
# case_data = {
# "description": tc.description,
# "inputs": processed_inputs,
# "kwargs": display_kwargs,
# "comparison_target": tc.comparison_target,
# "tolerance": tc.tolerance,
# }
# # --- D. Inject Result ---
# if res:
# case_data["result"] = TestReporter._fmt_result(res)
# else:
# case_data["result"] = {
# "status": {"success": False, "error": "No result"}
# }
# grouped_entries[dev_id]["testcases"].append(case_data)
# return list(grouped_entries.values())
# @staticmethod
# def save_all_results(save_path: str, total_results: List[Dict[str, Any]]):
# """
# Saves the report list to a JSON file with specific custom formatting
# """
# directory, filename = os.path.split(save_path)
# name, ext = os.path.splitext(filename)
# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
# final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
# # Define indentation levels for cleaner code
# indent_4 = " " * 4
# indent_8 = " " * 8
# indent_12 = " " * 12
# indent_16 = " " * 16
# indent_20 = " " * 20
# print(f"💾 Saving to: {final_path}")
# try:
# with open(final_path, "w", encoding="utf-8") as f:
# f.write("[\n")
# for i, entry in enumerate(total_results):
# f.write(f"{indent_4}{{\n")
# keys = list(entry.keys())
# for j, key in enumerate(keys):
# val = entry[key]
# comma = "," if j < len(keys) - 1 else ""
# # -------------------------------------------------
# # Special Handling for 'testcases' list formatting
# # -------------------------------------------------
# if key == "testcases" and isinstance(val, list):
# f.write(f'{indent_8}"{key}": [\n')
# for c_idx, case_item in enumerate(val):
# f.write(f"{indent_12}{{\n")
# case_keys = list(case_item.keys())
# for k_idx, c_key in enumerate(case_keys):
# c_val = case_item[c_key]
# # [Logic A] Skip fields we merged manually after 'kwargs'
# if c_key in ["comparison_target", "tolerance"]:
# continue
# # Check comma for standard logic (might be overridden below)
# c_comma = "," if k_idx < len(case_keys) - 1 else ""
# # [Logic B] Handle 'kwargs' + Grouped Fields
# if c_key == "kwargs":
# # 1. Use Helper for kwargs (Fill/Flow logic)
# TestReporter._write_smart_field(
# f,
# c_key,
# c_val,
# indent_16,
# indent_20,
# close_comma=",",
# )
# # 2. Write subsequent comparison_target and tolerance (on a new line)
# cmp_v = json.dumps(
# case_item.get("comparison_target"),
# ensure_ascii=False,
# )
# tol_v = json.dumps(
# case_item.get("tolerance"),
# ensure_ascii=False,
# )
# remaining_keys = [
# k
# for k in case_keys[k_idx + 1 :]
# if k
# not in ("comparison_target", "tolerance")
# ]
# line_comma = "," if remaining_keys else ""
# f.write(
# f'{indent_16}"comparison_target": {cmp_v}, "tolerance": {tol_v}{line_comma}\n'
# )
# continue
# # [Logic C] Handle 'inputs' (Smart Wrap)
# if c_key == "inputs" and isinstance(c_val, list):
# TestReporter._write_smart_field(
# f,
# c_key,
# c_val,
# indent_16,
# indent_20,
# close_comma=c_comma,
# )
# continue
# # [Logic D] Standard fields (description, result, output_spec, etc.)
# else:
# c_val_str = json.dumps(
# c_val, ensure_ascii=False
# )
# f.write(
# f'{indent_16}"{c_key}": {c_val_str}{c_comma}\n'
# )
# close_comma = "," if c_idx < len(val) - 1 else ""
# f.write(f"{indent_12}}}{close_comma}\n")
# f.write(f"{indent_8}]{comma}\n")
# # -------------------------------------------------
# # Standard top-level fields (operator, args, etc.)
# # -------------------------------------------------
# else:
# k_str = json.dumps(key, ensure_ascii=False)
# v_str = json.dumps(val, ensure_ascii=False)
# f.write(f"{indent_8}{k_str}: {v_str}{comma}\n")
# if i < len(total_results) - 1:
# f.write(f"{indent_4}}},\n")
# else:
# f.write(f"{indent_4}}}\n")
# f.write("]\n")
# print(f" ✅ Saved (Structure Matched).")
# except Exception as e:
# import traceback
# traceback.print_exc()
# print(f" ❌ Save failed: {e}")
# # --- Internal Helpers ---
# @staticmethod
# def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
# """
# Helper to write a JSON field (List or Dict) with smart wrapping.
# - If compact length <= 180: Write on one line.
# - If > 180: Use 'Fill/Flow' mode (multiple items per line, wrap when line is full).
# """
# # 1. Try Compact Mode
# compact_json = json.dumps(value, ensure_ascii=False)
# if len(compact_json) <= 180:
# f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
# return
# # 2. Fill/Flow Mode
# is_dict = isinstance(value, dict)
# open_char = "{" if is_dict else "["
# close_char = "}" if is_dict else "]"
# f.write(f'{indent}"{key}": {open_char}')
# # Normalize items for iteration
# if is_dict:
# items = list(value.items())
# else:
# items = value # List
# # Initialize current line length tracking
# # Length includes indent + "key": [
# current_len = len(indent) + len(f'"{key}": {open_char}')
# for i, item in enumerate(items):
# # Format individual item string
# if is_dict:
# k, v = item
# val_str = json.dumps(v, ensure_ascii=False)
# item_str = f'"{k}": {val_str}'
# else:
# item_str = json.dumps(item, ensure_ascii=False)
# is_last = i == len(items) - 1
# item_comma = "" if is_last else ", "
# # Predict new length: current + item + comma
# if current_len + len(item_str) + len(item_comma) > 180:
# # Wrap to new line
# f.write(f"\n{sub_indent}")
# current_len = len(sub_indent)
# f.write(f"{item_str}{item_comma}")
# current_len += len(item_str) + len(item_comma)
# f.write(f"{close_char}{close_comma}\n")
# @staticmethod
# def _spec_to_dict(s):
# return {
# "name": getattr(s, "name", "unknown"),
# "shape": list(s.shape) if s.shape else None,
# "dtype": str(s.dtype).split(".")[-1],
# "strides": list(s.strides) if s.strides else None,
# }
# @staticmethod
# def _fmt_result(res):
# if not (is_dataclass(res) or hasattr(res, "success")):
# return str(res)
# get_time = lambda k: round(getattr(res, k, 0.0), 4)
# return {
# "status": {
# "success": getattr(res, "success", False),
# "error": getattr(res, "error_message", ""),
# },
# "perf_ms": {
# "torch": {
# "host": get_time("torch_host_time"),
# "device": get_time("torch_device_time"),
# },
# "infinicore": {
# "host": get_time("infini_host_time"),
# "device": get_time("infini_device_time"),
# },
# },
# }
......@@ -7,7 +7,7 @@ import os
import inspect
import re
from . import TestConfig, TestRunner, get_args, get_test_devices
from .reporter import TestReporter
from .summary import TestSummary
class GenericTestRunner:
......@@ -89,7 +89,8 @@ class GenericTestRunner:
op_paths = {"torch": t_path, "infinicore": i_path}
# 2. Generate Report Entries
entries = TestReporter.prepare_report_entry(
test_summary = TestSummary()
entries = test_summary.collect_report_entry(
op_name=self.operator_test.operator_name,
test_cases=self.operator_test.test_cases,
args=self.args,
......@@ -98,7 +99,7 @@ class GenericTestRunner:
)
# 4. Save to File
TestReporter.save_all_results(self.args.save, entries)
test_summary.save_report(self.args.save)
except Exception as e:
import traceback
......
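The new reporting flow end to end, as a self-contained sketch (the SimpleNamespace test case and the "add" operator are hypothetical stand-ins for real TestCase objects):

    from types import SimpleNamespace

    tc = SimpleNamespace(
        description="demo case",
        inputs=[],            # no TensorSpec inputs in this sketch
        kwargs={},
        output_spec=None,
        output_specs=None,
        comparison_target=None,
        tolerance={"atol": 0, "rtol": 1e-3},
    )
    summary = TestSummary()
    summary.collect_report_entry(
        op_name="add",
        test_cases=[tc],
        args=SimpleNamespace(bench="both", verbose=False),
        op_paths={"torch": "torch.add", "infinicore": "infinicore.add"},
        results_list=[],      # no execution results -> "No result" status
    )
    summary.save_report("report.json")  # writes report_<timestamp>.json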
from dataclasses import dataclass, field
from typing import Any
# TODO: Rename it; the current class name is too abstract.
@dataclass
class TestResult:
"""Test result data structure"""
class CaseResult:
"""Test case result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
......@@ -16,19 +16,23 @@ class TestResult:
test_case: Any = None
device: Any = None
@dataclass
class TestTiming:
"""Stores performance timing metrics."""
torch_host: float = 0.0
torch_device: float = 0.0
infini_host: float = 0.0
infini_device: float = 0.0
# Added field to support the logic in print_summary
operators_tested: int = 0
operators_tested: int = 0
@dataclass
class OperatorTestResult:
class OperatorResult:
"""Stores the execution results of a single operator."""
name: str
success: bool = False
return_code: int = -1
......@@ -39,14 +43,20 @@ class OperatorTestResult:
@property
def status_icon(self):
if self.return_code == 0: return "✅"
if self.return_code == -2: return "⏭️"
if self.return_code == -3: return "⚠️"
if self.return_code == 0:
return "✅"
if self.return_code == -2:
return "⏭️"
if self.return_code == -3:
return "⚠️"
return "❌"
@property
def status_text(self):
if self.return_code == 0: return "PASSED"
if self.return_code == -2: return "SKIPPED"
if self.return_code == -3: return "PARTIAL"
if self.return_code == 0:
return "PASSED"
if self.return_code == -2:
return "SKIPPED"
if self.return_code == -3:
return "PARTIAL"
return "FAILED"
from typing import List, Dict, Any
from dataclasses import is_dataclass
from .devices import InfiniDeviceEnum
from .base import TensorSpec
from .utils.json_utils import save_json_report
class TestSummary:
"""
Test summary:
1. Aggregates results (Timing & Status calculation).
2. Handles Console Output (Live & Summary).
3. Handles File Reporting (Data Preparation).
"""
def __init__(self, verbose=False, bench_mode=None):
self.verbose = verbose
self.bench_mode = bench_mode
self.report_entries = [] # Cache for JSON report
# =========================================================
# Part 1: Result Aggregation
# =========================================================
def process_operator_result(self, op_result, sub_results: List):
"""
Updates the OperatorResult object in-place.
"""
if not sub_results:
op_result.return_code = -1
return
# 1. Analyze Return Code (Status)
if op_result.success:
op_result.return_code = 0
else:
has_failures = any(r.return_code == -1 for r in sub_results)
has_partial = any(r.return_code == -3 for r in sub_results)
has_skipped = any(r.return_code == -2 for r in sub_results)
if has_failures:
op_result.return_code = -1
elif has_partial:
op_result.return_code = -3
elif has_skipped:
op_result.return_code = -2
else:
op_result.return_code = -1
# 2. Extract Timing (Aggregation)
t = op_result.timing
t.torch_host = sum(r.torch_host_time for r in sub_results)
t.torch_device = sum(r.torch_device_time for r in sub_results)
t.infini_host = sum(r.infini_host_time for r in sub_results)
t.infini_device = sum(r.infini_device_time for r in sub_results)
t.operators_tested = len(sub_results)
# =========================================================
# Part 2: Console Output (View)
# =========================================================
def list_tests(self, discoverer):
ops_dir = discoverer.ops_dir
operators = discoverer.get_available_operators()
if operators:
print(f"Available operator test files in {ops_dir}:")
for operator in operators:
print(f" - {operator}")
print(f"\nTotal: {len(operators)} operators")
else:
print(f"No valid operator tests found in {ops_dir}")
raw_files = discoverer.get_raw_python_files()
if raw_files:
print(
    f"\n💡 Debug Hint: Found Python files, but they are not valid tests:"
)
print(f" {raw_files}")
def print_header(self, ops_dir, count):
print(f"InfiniCore Operator Test Runner")
print(f"Directory: {ops_dir}")
print(f"Tests found: {count}\n")
def print_live_result(self, result):
print(
f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})"
)
if result.stdout:
print(result.stdout.rstrip())
if result.stderr:
print("\nSTDERR:", result.stderr.rstrip())
if result.error_message:
print(f"💥 Error: {result.error_message}")
if result.stdout or result.stderr or self.verbose:
print("-" * 40)
def print_summary(self, results, cumulative_timing, ops_dir, total_expected=0):
print(f"\n{'='*80}\nCUMULATIVE TEST SUMMARY\n{'='*80}")
passed = [r for r in results if r.return_code == 0]
failed = [r for r in results if r.return_code == -1]
skipped = [r for r in results if r.return_code == -2]
partial = [r for r in results if r.return_code == -3]
total = len(results)
print(f"Total tests run: {total}")
if total_expected > 0 and total < total_expected:
print(f"Total tests expected: {total_expected}")
print(f"Tests not executed: {total_expected - total}")
print(f"Passed: {len(passed)}")
print(f"Failed: {len(failed)}")
if skipped:
print(f"Skipped: {len(skipped)}")
if partial:
print(f"Partial: {len(partial)}")
# 1. Benchmark
if cumulative_timing:
self._print_timing(cumulative_timing)
# 2. Lists
if passed:
self._print_op_list("✅ PASSED OPERATORS", passed)
else:
print(f"\n✅ PASSED OPERATORS: None")
if failed:
self._print_op_list("❌ FAILED OPERATORS", failed)
if skipped:
self._print_op_list("⏭️ SKIPPED OPERATORS", skipped)
if partial:
self._print_op_list("⚠️ PARTIAL IMPLEMENTATIONS", partial)
# 3. Verdict
if total > 0:
executed_tests = total - len(skipped)
if executed_tests > 0:
success_rate = len(passed) / executed_tests * 100
print(f"\nSuccess rate: {success_rate:.1f}%")
if not failed:
if skipped or partial:
print(f"\n⚠️ Tests completed with some operators not fully implemented")
else:
print(f"\n🎉 All tests passed!")
else:
print(f"\n{len(failed)} tests failed")
if not failed and (skipped or partial):
print(f"\n⚠️ Note: Some operators are not fully implemented")
print(f" Run individual tests for details on missing implementations")
if self.verbose and failed:
print(
f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:"
)
for r in failed[:3]:
file_path = ops_dir / (r.name + ".py")
print(f" python {file_path} --verbose")
if len(failed) > 3:
print(f" ... (and {len(failed) - 3} others)")
return len(failed) == 0
def _print_timing(self, t):
print(f"{'-'*40}")
if hasattr(t, "operators_tested") and t.operators_tested > 0:
print(f"BENCHMARK SUMMARY ({t.operators_tested} cases):")
if self.bench_mode in ["host", "both"]:
print(f" [Host] PyTorch: {t.torch_host:10.3f} ms")
print(f" [Host] InfiniCore: {t.infini_host:10.3f} ms")
if self.bench_mode in ["device", "both"]:
print(f" [Device] PyTorch: {t.torch_device:10.3f} ms")
print(f" [Device] InfiniCore: {t.infini_device:10.3f} ms")
print(f"{'-'*40}")
def _print_op_list(self, title, result_list):
print(f"\n{title} ({len(result_list)}):")
names = [r.name for r in result_list]
for i in range(0, len(names), 10):
print(" " + ", ".join(names[i : i + 10]))
# =========================================================
# Part 3: Report Generation
# =========================================================
def collect_report_entry(self, op_name, test_cases, args, op_paths, results_list):
"""
Prepares the data and adds it to the internal list.
"""
entry = self._prepare_entry_logic(
op_name, test_cases, args, op_paths, results_list
)
self.report_entries.extend(entry)
def save_report(self, save_path):
"""
Delegates the actual writing to save_json_report.
"""
if not self.report_entries:
return
# Call the external utility
save_json_report(save_path, self.report_entries)
def _prepare_entry_logic(self, op_name, test_cases, args, op_paths, results_list):
"""
Combines static test case info with dynamic execution results.
Refactored to reduce duplication.
"""
# 1. Normalize results
results_map = (
results_list
if isinstance(results_list, dict)
else {i: res for i, res in enumerate(results_list or [])}
)
# 2. Global Args
global_args = {
k: getattr(args, k)
for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
if hasattr(args, k)
}
grouped_entries = {}
# Cache device enum map
device_id_map = {
v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")
}
for idx, tc in enumerate(test_cases):
res = results_map.get(idx)
dev_id = getattr(res, "device", 0) if res else 0
# --- A. Initialize Group ---
if dev_id not in grouped_entries:
grouped_entries[dev_id] = {
"operator": op_name,
"device": device_id_map.get(dev_id, str(dev_id)),
"torch_op": op_paths.get("torch", "unknown"),
"infinicore_op": op_paths.get("infinicore", "unknown"),
"args": global_args,
"testcases": [],
}
# --- B. Helpers for Spec Processing ---
def process_spec(spec, default_name):
final_name = self._resolve_name(spec, default_name)
# Call internal method (no need for external converters file)
return self._spec_to_dict(spec, name=final_name)
# --- C. Build Inputs ---
processed_inputs = [
process_spec(inp, f"in_{i}") for i, inp in enumerate(tc.inputs)
]
# --- D. Build Kwargs ---
display_kwargs = {}
for k, v in tc.kwargs.items():
if k == "out" and isinstance(v, int):
# Handle Inplace Index
if 0 <= v < len(tc.inputs):
display_kwargs[k] = self._resolve_name(tc.inputs[v], f"in_{v}")
else:
display_kwargs[k] = f"Invalid_Index_{v}"
elif isinstance(v, TensorSpec):
display_kwargs[k] = process_spec(v, v.name)
else:
display_kwargs[k] = v
# --- E. Inject Outputs ---
if getattr(tc, "output_specs", None):
for i, spec in enumerate(tc.output_specs):
display_kwargs[f"out_{i}"] = process_spec(spec, f"out_{i}")
elif tc.output_spec and "out" not in display_kwargs:
display_kwargs["out"] = process_spec(tc.output_spec, "out")
# --- F. Assemble Case Data ---
case_data = {
"description": tc.description,
"inputs": processed_inputs,
"kwargs": display_kwargs,
"comparison_target": tc.comparison_target,
"tolerance": tc.tolerance,
"result": (
self._fmt_result(res)
if res
else {"status": {"success": False, "error": "No result"}}
),
}
grouped_entries[dev_id]["testcases"].append(case_data)
return list(grouped_entries.values())
# --- Internal Helpers ---
def _resolve_name(self, obj, default_name):
return getattr(obj, "name", None) or default_name
def _spec_to_dict(self, s, name=None):
return {
"name": name if name else getattr(s, "name", "unknown"),
"shape": list(s.shape) if s.shape else None,
"dtype": str(s.dtype).split(".")[-1],
"strides": list(s.strides) if s.strides else None,
}
def _fmt_result(self, res):
if not (is_dataclass(res) or hasattr(res, "success")):
return str(res)
get_time = lambda k: round(getattr(res, k, 0.0), 4)
return {
"status": {
"success": getattr(res, "success", False),
"error": getattr(res, "error_message", ""),
},
"perf_ms": {
"torch": {
"host": get_time("torch_host_time"),
"device": get_time("torch_device_time"),
},
"infinicore": {
"host": get_time("infini_host_time"),
"device": get_time("infini_device_time"),
},
},
}
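A worked example of the aggregation in process_operator_result; the timing attribute names follow CaseResult, and the values are made up:

    from types import SimpleNamespace

    op = OperatorResult(name="add", success=True)
    cases = [
        SimpleNamespace(return_code=0, torch_host_time=1.0, torch_device_time=0.5,
                        infini_host_time=0.8, infini_device_time=0.4),
        SimpleNamespace(return_code=0, torch_host_time=2.0, torch_device_time=1.5,
                        infini_host_time=1.2, infini_device_time=0.6),
    ]
    TestSummary().process_operator_result(op, cases)
    assert op.return_code == 0
    assert op.timing.torch_host == 3.0
    assert op.timing.operators_tested == 2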
......@@ -3,7 +3,7 @@ import math
from pathlib import Path
from .datatypes import to_torch_dtype
from .devices import torch_device_map
from .utils import is_integer_dtype, is_complex_dtype
from .utils.tensor_utils import is_integer_dtype, is_complex_dtype
class TensorInitializer:
......
import torch
import time
import infinicore
import numpy as np
from .datatypes import to_infinicore_dtype, to_torch_dtype
def synchronize_device(torch_device):
"""Device synchronization"""
if torch_device == "cuda":
torch.cuda.synchronize()
elif torch_device == "npu":
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
Debug function to compare two tensors and print differences
"""
# Handle complex types by converting to real representation for comparison
if actual.is_complex() or desired.is_complex():
actual = torch.view_as_real(actual)
desired = torch.view_as_real(desired)
elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32)
desired = desired.to(torch.float32)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
import numpy as np
np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)
def print_discrepancy(
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
):
"""Print detailed tensor differences"""
if actual.shape != expected.shape:
raise ValueError("Tensors must have the same shape to compare.")
import torch
import sys
is_terminal = sys.stdout.isatty()
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)
# Calculate difference mask
nan_mismatch = (
actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected
# Display formatting
col_width = [18, 20, 20, 20]
decimal_places = [0, 12, 12, 12]
total_width = sum(col_width) + sum(decimal_places)
def add_color(text, color_code):
if is_terminal:
return f"\033[{color_code}m{text}\033[0m"
else:
return text
if verbose:
for idx in diff_indices:
index_tuple = tuple(idx.tolist())
actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
expected_str = (
f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
)
delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
print(
f" > Index: {str(index_tuple):<{col_width[0]}}"
f"actual: {add_color(actual_str, 31)}"
f"expect: {add_color(expected_str, 32)}"
f"delta: {add_color(delta_str, 33)}"
)
print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(f" - Equal NaN: {equal_nan}")
print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
)
print(
f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
)
print(
f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
)
print(
f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
)
print("-" * total_width)
return diff_indices
def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
"""
Get tolerance settings based on data type
"""
tolerance = tolerance_map.get(
tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
)
return tolerance["atol"], tolerance["rtol"]
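For example, a per-dtype tolerance map is consulted like this (the map contents are illustrative):

    tolerance_map = {
        torch.float16: {"atol": 1e-3, "rtol": 1e-2},
        torch.float32: {"atol": 0, "rtol": 1e-5},
    }
    atol, rtol = get_tolerance(tolerance_map, torch.float16)   # -> (1e-3, 1e-2)
    atol, rtol = get_tolerance(tolerance_map, torch.bfloat16)  # falls back to (0, 1e-3)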
def clone_torch_tensor(torch_tensor):
cloned = torch_tensor.clone().detach()
if not torch_tensor.is_contiguous():
cloned = rearrange_tensor(cloned, torch_tensor.stride())
return cloned
def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
return infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
else:
return infinicore.strided_from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
list(torch_tensor.stride()),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
def convert_infinicore_to_torch(infini_result):
"""
Convert infinicore tensor to PyTorch tensor for comparison
Args:
infini_result: infinicore tensor result
Returns:
torch.Tensor: PyTorch tensor with infinicore data
"""
torch_result_from_infini = torch.zeros(
infini_result.shape,
dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type,
)
if not infini_result.is_contiguous():
torch_result_from_infini = rearrange_tensor(
torch_result_from_infini, infini_result.stride()
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
return torch_result_from_infini
import sys
from ..datatypes import to_torch_dtype
from .tensor_utils import (
convert_infinicore_to_torch,
is_integer_dtype,
is_complex_dtype,
)
def compare_results(
......@@ -349,89 +189,105 @@ def create_test_comparator(config, atol, rtol, mode_name="", equal_nan=False):
return compare_test_results
def rearrange_tensor(tensor, new_strides):
    """
    Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides.
    """
    import torch
    shape = tensor.shape
    new_size = [0] * len(shape)
    left = 0
    right = 0
    for i in range(len(shape)):
        if new_strides[i] > 0:
            new_size[i] = (shape[i] - 1) * new_strides[i] + 1
            right += new_strides[i] * (shape[i] - 1)
        else:  # TODO: Support negative strides in the future
            # new_size[i] = (shape[i] - 1) * (-new_strides[i]) + 1
            # left += new_strides[i] * (shape[i] - 1)
            raise ValueError("Negative strides are not supported yet")
    # Create a new tensor with zeros
    new_tensor = torch.zeros(
        (right - left + 1,), dtype=tensor.dtype, device=tensor.device
    )
    # Generate indices for original tensor based on original strides
    indices = [torch.arange(s) for s in shape]
    mesh = torch.meshgrid(*indices, indexing="ij")
    # Flatten indices for linear indexing
    linear_indices = [m.flatten() for m in mesh]
    # Calculate new positions based on new strides
    new_positions = sum(
        linear_indices[i] * new_strides[i] for i in range(len(shape))
    ).to(tensor.device)
    offset = -left
    new_positions += offset
    # Copy the original data to the new tensor
    new_tensor.view(-1).index_add_(0, new_positions, tensor.view(-1))
    new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
    return new_tensor
def is_broadcast(strides):
    """
    Check if strides indicate a broadcasted tensor
    Args:
        strides: Tensor strides or None
    Returns:
        bool: True if the tensor is broadcasted (has zero strides)
    """
    if strides is None:
        return False
    return any(s == 0 for s in strides)
def is_integer_dtype(dtype):
    """Check if dtype is integer type"""
    return dtype in [
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
        torch.bool,
    ]
def is_complex_dtype(dtype):
    """Check if dtype is complex type"""
    return dtype in [torch.complex64, torch.complex128]
def is_floating_dtype(dtype):
    """Check if dtype is floating-point type"""
    return dtype in [
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bfloat16,
    ]
def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
    """
    Get tolerance settings based on data type
    """
    tolerance = tolerance_map.get(
        tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
    )
    return tolerance["atol"], tolerance["rtol"]
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
    """
    Debug function to compare two tensors and print differences
    """
    # Handle complex types by converting to real representation for comparison
    if actual.is_complex() or desired.is_complex():
        actual = torch.view_as_real(actual)
        desired = torch.view_as_real(desired)
    elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
        actual = actual.to(torch.float32)
        desired = desired.to(torch.float32)
    print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
    import numpy as np
    np.testing.assert_allclose(
        actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
    )
def print_discrepancy(
    actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
):
    """Print detailed tensor differences"""
    if actual.shape != expected.shape:
        raise ValueError("Tensors must have the same shape to compare.")
    import torch
    import sys
    is_terminal = sys.stdout.isatty()
    actual_isnan = torch.isnan(actual)
    expected_isnan = torch.isnan(expected)
    # Calculate difference mask
    nan_mismatch = (
        actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
    )
    diff_mask = nan_mismatch | (
        torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
    )
    diff_indices = torch.nonzero(diff_mask, as_tuple=False)
    delta = actual - expected
    # Display formatting
    col_width = [18, 20, 20, 20]
    decimal_places = [0, 12, 12, 12]
    total_width = sum(col_width) + sum(decimal_places)
    def add_color(text, color_code):
        if is_terminal:
            return f"\033[{color_code}m{text}\033[0m"
        else:
            return text
    if verbose:
        for idx in diff_indices:
            index_tuple = tuple(idx.tolist())
            actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
            expected_str = (
                f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
            )
            delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
            print(
                f" > Index: {str(index_tuple):<{col_width[0]}}"
                f"actual: {add_color(actual_str, 31)}"
                f"expect: {add_color(expected_str, 32)}"
                f"delta: {add_color(delta_str, 33)}"
            )
    print(f" - Actual dtype: {actual.dtype}")
    print(f" - Desired dtype: {expected.dtype}")
    print(f" - Atol: {atol}")
    print(f" - Rtol: {rtol}")
    print(f" - Equal NaN: {equal_nan}")
    print(
        f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
    )
    print(
        f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
    )
    print(
        f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
    )
    print(
        f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
    )
    print("-" * total_width)
    return diff_indices
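A usage sketch for the comparison helpers above, with values chosen to trigger exactly one mismatch:

    import torch

    actual = torch.tensor([1.0, 2.0, 3.5])
    desired = torch.tensor([1.0, 2.0, 3.0])
    print_discrepancy(actual, desired, atol=0, rtol=1e-3)  # reports index (2,)
    # debug(actual, desired) would additionally raise via np.testing.assert_allclose.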
import json
import os
from datetime import datetime
def save_json_report(save_path, total_results):
"""
Saves the report list to a JSON file with specific custom formatting
(Compact for short lines, Expanded for long lines).
"""
directory, filename = os.path.split(save_path)
name, ext = os.path.splitext(filename)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
# Define Indentation
I4, I8, I12, I16, I20 = " " * 4, " " * 8, " " * 12, " " * 16, " " * 20
print(f"💾 Saving to: {final_path}")
# Helper for JSON stringify to avoid repetition
def _j(obj):
return json.dumps(obj, ensure_ascii=False)
try:
with open(final_path, "w", encoding="utf-8") as f:
f.write("[\n")
for i, entry in enumerate(total_results):
f.write(f"{I4}{{\n")
keys = list(entry.keys())
for j, key in enumerate(keys):
val = entry[key]
comma = "," if j < len(keys) - 1 else ""
# Special handling for 'testcases' list
if key == "testcases" and isinstance(val, list):
f.write(f'{I8}"{key}": [\n')
for c_idx, case_item in enumerate(val):
f.write(f"{I12}{{\n")
case_keys = list(case_item.keys())
# Filter out keys that we handle specially at the end
standard_keys = [
k
for k in case_keys
if k not in ["comparison_target", "tolerance"]
]
for k_idx, c_key in enumerate(standard_keys):
c_val = case_item[c_key]
# Determine comma logic
c_comma = (
","
if k_idx < len(standard_keys) - 1
or "comparison_target" in case_item
else ""
)
if c_key in ["kwargs", "inputs"]:
_write_smart_field(
f, c_key, c_val, I16, I20, close_comma=c_comma
)
else:
f.write(f'{I16}"{c_key}": {_j(c_val)}{c_comma}\n')
# Handle trailing comparison/tolerance fields uniformly
if "comparison_target" in case_item:
cmp = _j(case_item.get("comparison_target"))
tol = _j(case_item.get("tolerance"))
f.write(
f'{I16}"comparison_target": {cmp}, "tolerance": {tol}\n'
)
close_case = "," if c_idx < len(val) - 1 else ""
f.write(f"{I12}}}{close_case}\n")
f.write(f"{I8}]{comma}\n")
else:
# Standard top-level fields
f.write(f"{I8}{_j(key)}: {_j(val)}{comma}\n")
close_entry = "}," if i < len(total_results) - 1 else "}"
f.write(f"{I4}{close_entry}\n")
f.write("]\n")
print(f" ✅ Saved.")
except Exception as e:
import traceback
traceback.print_exc()
print(f" ❌ Save failed: {e}")
def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
"""
Internal Helper: Write a JSON field with smart wrapping.
"""
# 1. Try Compact Mode
compact_json = json.dumps(value, ensure_ascii=False)
if len(compact_json) <= 180:
f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
return
# 2. Fill/Flow Mode
is_dict = isinstance(value, dict)
open_char = "{" if is_dict else "["
close_char = "}" if is_dict else "]"
f.write(f'{indent}"{key}": {open_char}')
if is_dict:
items = list(value.items())
else:
items = value
current_len = len(indent) + len(f'"{key}": {open_char}')
for i, item in enumerate(items):
if is_dict:
k, v = item
val_str = json.dumps(v, ensure_ascii=False)
item_str = f'"{k}": {val_str}'
else:
item_str = json.dumps(item, ensure_ascii=False)
is_last = i == len(items) - 1
item_comma = "" if is_last else ", "
if current_len + len(item_str) + len(item_comma) > 180:
f.write(f"\n{sub_indent}")
current_len = len(sub_indent)
f.write(f"{item_str}{item_comma}")
current_len += len(item_str) + len(item_comma)
f.write(f"{close_char}{close_comma}\n")
import torch
import infinicore
from ..datatypes import to_infinicore_dtype, to_torch_dtype
# =================================================================
# Device & Synchronization
# =================================================================
def synchronize_device(torch_device):
"""Device synchronization"""
if torch_device == "cuda":
torch.cuda.synchronize()
elif torch_device == "npu":
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
# =================================================================
# Tensor Operations & Conversions
# =================================================================
def clone_torch_tensor(torch_tensor):
cloned = torch_tensor.clone().detach()
if not torch_tensor.is_contiguous():
cloned = rearrange_tensor(cloned, torch_tensor.stride())
return cloned
def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
return infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
else:
return infinicore.strided_from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
list(torch_tensor.stride()),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
def convert_infinicore_to_torch(infini_result):
"""
Convert infinicore tensor to PyTorch tensor for comparison
Args:
infini_result: infinicore tensor result
Returns:
torch.Tensor: PyTorch tensor with infinicore data
"""
torch_result_from_infini = torch.zeros(
infini_result.shape,
dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type,
)
if not infini_result.is_contiguous():
torch_result_from_infini = rearrange_tensor(
torch_result_from_infini, infini_result.stride()
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
return torch_result_from_infini
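A hedged round-trip sketch using the two converters above (assumes an infinicore build with CPU support):

    t = torch.randn(2, 3)
    it = infinicore_tensor_from_torch(t)      # zero-copy view over t's memory
    back = convert_infinicore_to_torch(it)
    assert torch.allclose(back, t)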
def rearrange_tensor(tensor, new_strides):
"""
Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides.
"""
import torch
shape = tensor.shape
new_size = [0] * len(shape)
left = 0
right = 0
for i in range(len(shape)):
if new_strides[i] > 0:
new_size[i] = (shape[i] - 1) * new_strides[i] + 1
right += new_strides[i] * (shape[i] - 1)
else: # TODO: Support negative strides in the future
# new_size[i] = (shape[i] - 1) * (-new_strides[i]) + 1
# left += new_strides[i] * (shape[i] - 1)
raise ValueError("Negative strides are not supported yet")
# Create a new tensor with zeros
new_tensor = torch.zeros(
(right - left + 1,), dtype=tensor.dtype, device=tensor.device
)
# Generate indices for original tensor based on original strides
indices = [torch.arange(s) for s in shape]
mesh = torch.meshgrid(*indices, indexing="ij")
# Flatten indices for linear indexing
linear_indices = [m.flatten() for m in mesh]
# Calculate new positions based on new strides
new_positions = sum(
linear_indices[i] * new_strides[i] for i in range(len(shape))
).to(tensor.device)
offset = -left
new_positions += offset
# Copy the original data to the new tensor
new_tensor.view(-1).index_add_(0, new_positions, tensor.view(-1))
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
return new_tensor
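A small worked example: re-laying out a 2x3 tensor with padded strides (6, 2) leaves gaps in storage while the logical values are unchanged:

    t = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # strides (3, 1)
    r = rearrange_tensor(t, (6, 2))
    assert r.stride() == (6, 2)
    assert torch.equal(r, t)   # same logical values, different layout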
def is_broadcast(strides):
"""
Check if strides indicate a broadcasted tensor
Args:
strides: Tensor strides or None
Returns:
bool: True if the tensor is broadcasted (has zero strides)
"""
if strides is None:
return False
return any(s == 0 for s in strides)
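Broadcast views in PyTorch expose zero strides, which is exactly what is_broadcast detects:

    x = torch.ones(3)
    y = x.unsqueeze(0).expand(4, 3)    # broadcast view, stride (0, 1)
    assert is_broadcast(y.stride())
    assert not is_broadcast(x.stride())
    assert not is_broadcast(None)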
# =================================================================
# Type Checks (Moved here to avoid circular imports in check.py)
# =================================================================
def is_integer_dtype(dtype):
"""Check if dtype is integer type"""
return dtype in [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.bool,
]
def is_complex_dtype(dtype):
"""Check if dtype is complex type"""
return dtype in [torch.complex64, torch.complex128]
def is_floating_dtype(dtype):
"""Check if dtype is floating-point type"""
return dtype in [
torch.float16,
torch.float32,
torch.float64,
torch.bfloat16,
]
......@@ -76,7 +76,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.inputs[0], TensorSpec)
and test_case.inputs[0].strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......
......@@ -222,7 +222,7 @@ class OpTest(BaseOperatorTest):
# Re-run operations with the same logits to get results for comparison
# prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
from framework.base import TestResult
from framework.base import CaseResult
from framework.utils import (
convert_infinicore_to_torch,
infinicore_tensor_from_torch,
......@@ -268,8 +268,8 @@ class OpTest(BaseOperatorTest):
# Check if indices are equal (standard case)
if ic_idx == ref_idx:
# Return a successful TestResult object
return TestResult(
# Return a successful CaseResult object
return CaseResult(
success=True,
return_code=0,
test_case=test_case,
......@@ -283,8 +283,8 @@ class OpTest(BaseOperatorTest):
logits_ic = logits_tensor[ic_idx].item()
if logits_ic == logits_ref:
# Valid: different indices but same logits value
# Return a successful TestResult object
return TestResult(
# Return a successful CaseResult object
return CaseResult(
success=True,
return_code=0,
test_case=test_case,
......
......@@ -180,7 +180,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.inputs[0], TensorSpec)
and test_case.inputs[0].strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......@@ -193,7 +193,7 @@ class OpTest(BaseOperatorTest):
)
for spec in output_specs:
if isinstance(spec, TensorSpec) and spec.strides is not None:
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......
......@@ -122,7 +122,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.inputs[0], TensorSpec)
and test_case.inputs[0].strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......@@ -135,7 +135,7 @@ class OpTest(BaseOperatorTest):
and isinstance(test_case.output_spec, TensorSpec)
and test_case.output_spec.strides is not None
):
return TestResult(
return CaseResult(
success=False,
return_code=-2,
test_case=test_case,
......