Commit 293c7906 authored by baominghelly's avatar baominghelly
Browse files

Add save feature for existing test cases

parent f73d6237
......@@ -36,6 +36,7 @@ __all__ = [
"TestConfig",
"TestResult",
"TestRunner",
"TestReporter",
# Core functions
"compare_results",
"convert_infinicore_to_torch",
......
......@@ -108,6 +108,14 @@ Examples:
help="Enable verbose mode to stop on first error with full traceback",
)
parser.add_argument(
"--save",
nargs="?",
const="test_report.json",
default=None,
help="Save test results to a JSON file. Defaults to 'test_report.json' if no filename provided.",
)
# Device options using shared hardware info
hardware_group = get_hardware_args_group(parser)
args, unknown = parser.parse_known_args()
......
import json
import time
import os
from typing import List, Dict, Any
from dataclasses import is_dataclass
from .base import TensorSpec
from .devices import InfiniDeviceEnum
class TestReporter:
"""
Handles report generation and file saving for test results.
"""
@staticmethod
def prepare_report_entry(
op_name: str,
test_cases: List[Any],
args: Any,
op_paths: Dict[str, str],
device: str,
results_list: List[Any]
) -> Dict[str, Any]:
"""
Combines static test case info with dynamic execution results.
"""
# Map results by index
results_map = {}
if isinstance(results_list, list):
results_map = {i: res for i, res in enumerate(results_list)}
elif isinstance(results_list, dict):
results_map = results_list
else:
results_map = {0: results_list}
processed_cases = []
for idx, tc in enumerate(test_cases):
# 1. Reconstruct case dict (Static info)
case_data = {
"description": tc.description,
"inputs": [TestReporter._spec_to_dict(i) for i in tc.inputs],
"kwargs": {
k: (
TestReporter._spec_to_dict(v) if isinstance(v, TensorSpec) else v
)
for k, v in tc.kwargs.items()
},
"comparison_target": tc.comparison_target,
"tolerance": tc.tolerance,
}
if tc.output_spec:
case_data["output_spec"] = TestReporter._spec_to_dict(tc.output_spec)
if hasattr(tc, "output_specs") and tc.output_specs:
case_data["output_specs"] = [
TestReporter._spec_to_dict(s) for s in tc.output_specs
]
# 2. Inject Result (Dynamic info) directly into the case
res = results_map.get(idx)
if res:
case_data["result"] = TestReporter._fmt_result(res)
else:
case_data["result"] = {"status": {"success": False, "error": "No result"}}
processed_cases.append(case_data)
# Global Arguments
global_args = {
k: getattr(args, k)
for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
if hasattr(args, k)
}
return {
"operator": op_name,
"device": device,
"torch_op": op_paths.get("torch") or "unknown",
"infinicore_op": op_paths.get("infinicore") or "unknown",
"args": global_args,
"testcases": processed_cases
}
@staticmethod
def save_all_results(save_path: str, total_results: List[Dict[str, Any]]):
"""
Saves the report list to a JSON file with compact formatting.
"""
directory, filename = os.path.split(save_path)
name, ext = os.path.splitext(filename)
timestamp = time.strftime("%Y%m%d_%H%M%S")
final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
print(f"💾 Saving to: {final_path}")
try:
with open(final_path, "w", encoding="utf-8") as f:
f.write("[\n")
for i, entry in enumerate(total_results):
f.write(" {\n")
keys = list(entry.keys())
for j, key in enumerate(keys):
# Special Handling for list fields: vertical expansion
if key in ["testcases"] and isinstance(entry[key], list):
f.write(f' "{key}": [\n')
sub_list = entry[key]
for c_idx, c_item in enumerate(sub_list):
c_str = json.dumps(c_item, ensure_ascii=False)
comma = "," if c_idx < len(sub_list) - 1 else ""
f.write(f" {c_str}{comma}\n")
list_comma = "," if j < len(keys) - 1 else ""
f.write(f" ]{list_comma}\n")
else:
# Standard compact formatting
k_str = json.dumps(key, ensure_ascii=False)
v_str = json.dumps(entry[key], ensure_ascii=False)
comma = "," if j < len(keys) - 1 else ""
f.write(f" {k_str}: {v_str}{comma}\n")
if i < len(total_results) - 1:
f.write(" },\n")
else:
f.write(" }\n")
f.write("]\n")
print(f" ✅ Saved (Structure Matched).")
except Exception as e:
print(f" ❌ Save failed: {e}")
# --- Internal Helpers ---
@staticmethod
def _spec_to_dict(s):
return {
"name": getattr(s, "name", "unknown"),
"shape": list(s.shape) if s.shape else None,
"dtype": str(s.dtype).split(".")[-1],
"strides": list(s.strides) if s.strides else None,
}
@staticmethod
def _fmt_result(res):
if not (is_dataclass(res) or hasattr(res, "success")):
return str(res)
device_id_map = {
v: k
for k, v in vars(InfiniDeviceEnum).items()
if not k.startswith("_")
}
raw_id = getattr(res, "device", None)
dev_str = device_id_map.get(raw_id, str(raw_id))
return {
"status": {
"success": getattr(res, "success", False),
"error": getattr(res, "error_message", ""),
},
"perf_ms": {
"torch": {
"host": round(getattr(res, "torch_host_time", 0.0), 4),
"device": round(getattr(res, "torch_device_time", 0.0), 4),
},
"infinicore": {
"host": round(getattr(res, "infini_host_time", 0.0), 4),
"device": round(getattr(res, "infini_device_time", 0.0), 4),
},
},
"device": dev_str,
}
......@@ -3,19 +3,22 @@ Generic test runner that handles the common execution flow for all operators
"""
import sys
import os
import inspect
import re
from . import TestConfig, TestRunner, get_args, get_test_devices
from .reporter import TestReporter
class GenericTestRunner:
"""Generic test runner that handles the common execution flow"""
def __init__(self, operator_test_class):
def __init__(self, operator_test_class, args=None):
"""
Args:
operator_test_class: A class that implements BaseOperatorTest interface
"""
self.operator_test = operator_test_class()
self.args = get_args()
self.args = args or get_args()
def run(self):
"""Execute the complete test suite
......@@ -63,4 +66,74 @@ class GenericTestRunner:
1: One or more tests failed
"""
success, runner = self.run()
if getattr(self.args, 'save', None):
self._save_report(runner)
sys.exit(0 if success else 1)
def _save_report(self, runner):
"""
Helper method to collect metadata and trigger report saving.
"""
try:
# 1. Infer active device string dynamically
from .devices import InfiniDeviceEnum
# Get actual device IDs used (e.g. [0, 1])
device_ids = get_test_devices(self.args)
# Map IDs to Names (e.g. {0: "CPU", 1: "NVIDIA"})
id_to_name = {v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith('_')}
# Convert list of IDs to list of Names
device_names = [id_to_name.get(d_id, str(d_id)) for d_id in device_ids]
device_str = ", ".join(device_names) if device_names else "CPU"
# 2. Prepare metadata (Paths)
# Try to infer from source code first
t_path = self._infer_op_path(self.operator_test.torch_operator, "torch")
i_path = self._infer_op_path(self.operator_test.infinicore_operator, "infinicore")
op_paths = {
"torch": t_path,
"infinicore": i_path
}
# 3. Generate Report Entry
entry = TestReporter.prepare_report_entry(
op_name=self.operator_test.operator_name,
test_cases=self.operator_test.test_cases,
args=self.args,
op_paths=op_paths,
device=device_str,
results_list=runner.test_results
)
# 4. Save to File
TestReporter.save_all_results(self.args.save, [entry])
except Exception as e:
import traceback; traceback.print_exc()
print(f"⚠️ Failed to save report: {e}")
def _infer_op_path(self, method, lib_prefix):
"""
Introspects the method source code to find calls like 'torch.add' or 'infinicore.mul'.
Returns the full path string (e.g., 'torch.add') or None if not found.
"""
try:
source = inspect.getsource(method)
# Regex to find 'lib.func' or 'lib.submodule.func'
# Matches: 'torch.add', 'torch.nn.functional.relu'
pattern = re.compile(rf"\b{lib_prefix}\.([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)", re.IGNORECASE)
match = pattern.search(source)
if match:
# Return the matched string exactly as found in source code
# or normalize it (e.g. lowercase lib_prefix + match)
return f"{lib_prefix}.{match.group(1)}"
except Exception:
# Handle cases where source is not available (e.g. compiled modules)
pass
return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment