Commit a8875c9a authored by baominghelly

Issue/716 - change test case indentation in report && add save feature support in run.py

parent 293c7906
@@ -2,6 +2,7 @@ from .base import TestConfig, TestRunner, BaseOperatorTest
 from .test_case import TestCase, TestResult
 from .benchmark import BenchmarkUtils, BenchmarkResult
 from .config import (
+    add_common_test_args,
     get_args,
     get_hardware_args_group,
     get_test_devices,
@@ -38,6 +39,7 @@ __all__ = [
     "TestRunner",
     "TestReporter",
     # Core functions
+    "add_common_test_args",
    "compare_results",
    "convert_infinicore_to_torch",
    "create_test_comparator",
...
@@ -44,6 +44,42 @@ def get_hardware_args_group(parser):
     return hardware_group
 
 
+def add_common_test_args(parser: argparse.ArgumentParser):
+    """
+    Adds common test/execution arguments to the passed parser object.
+    Includes: bench, debug, verbose, save args.
+    """
+    # Create an argument group to make help info clearer
+    group = parser.add_argument_group("Common Execution Options")
+    group.add_argument(
+        "--bench",
+        nargs="?",
+        const="both",
+        choices=["host", "device", "both"],
+        help="Enable performance benchmarking mode. "
+        "Options: host (CPU time only), device (GPU time only), both (default)",
+    )
+    group.add_argument(
+        "--debug",
+        action="store_true",
+        help="Enable debug mode for detailed tensor comparison",
+    )
+    group.add_argument(
+        "--verbose",
+        action="store_true",
+        help="Enable verbose mode to stop on first error with full traceback",
+    )
+    group.add_argument(
+        "--save",
+        nargs="?",
+        const="test_report.json",
+        default=None,
+        help="Save test results to a JSON file. Defaults to 'test_report.json' if no filename provided.",
+    )
+
+
 def get_args():
     """Parse command line arguments for operator testing"""
@@ -77,14 +113,6 @@ Examples:
     )
 
     # Core testing options
-    parser.add_argument(
-        "--bench",
-        nargs="?",
-        const="both",
-        choices=["host", "device", "both"],
-        help="Enable performance benchmarking mode. "
-        "Options: host (CPU time only), device (GPU time only), both (default)",
-    )
     parser.add_argument(
         "--num_prerun",
         type=lambda x: max(0, int(x)),
@@ -97,24 +125,9 @@ Examples:
         default=1000,
         help="Number of iterations for benchmarking (default: 1000)",
     )
-    parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug mode for detailed tensor comparison",
-    )
-    parser.add_argument(
-        "--verbose",
-        action="store_true",
-        help="Enable verbose mode to stop on first error with full traceback",
-    )
-    parser.add_argument(
-        "--save",
-        nargs="?",
-        const="test_report.json",
-        default=None,
-        help="Save test results to a JSON file. Defaults to 'test_report.json' if no filename provided.",
-    )
+    # Call the common method to add arguments
+    add_common_test_args(parser)
 
     # Device options using shared hardware info
     hardware_group = get_hardware_args_group(parser)
...
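
A minimal sketch of the shared group in use; the behavior follows directly from the argparse definitions above (with nargs="?" plus const, each flag may be given bare):

    import argparse
    from framework import add_common_test_args  # re-exported by the package (first hunk above)

    parser = argparse.ArgumentParser()
    add_common_test_args(parser)

    # Bare flags fall back to their const values.
    args = parser.parse_args(["--bench", "--save"])
    print(args.bench)  # "both"
    print(args.save)   # "test_report.json"

    # Without --save, args.save stays None and no report is written.
    print(parser.parse_args([]).save)  # None
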
 import json
-import time
 import os
-from typing import List, Dict, Any
+from datetime import datetime
+from typing import List, Dict, Any, Union
 from dataclasses import is_dataclass
 from .base import TensorSpec
 from .devices import InfiniDeviceEnum
@@ -17,121 +17,246 @@ class TestReporter:
         test_cases: List[Any],
         args: Any,
         op_paths: Dict[str, str],
-        device: str,
         results_list: List[Any]
-    ) -> Dict[str, Any]:
+    ) -> List[Dict[str, Any]]:
         """
         Combines static test case info with dynamic execution results.
         """
-        # Map results by index
+        # 1. Normalize results
         results_map = {}
         if isinstance(results_list, list):
             results_map = {i: res for i, res in enumerate(results_list)}
         elif isinstance(results_list, dict):
             results_map = results_list
         else:
-            results_map = {0: results_list}
+            results_map = {0: results_list} if results_list else {}
 
-        processed_cases = []
+        # 2. Global Args
+        global_args = {
+            k: getattr(args, k)
+            for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
+            if hasattr(args, k)
+        }
+
+        grouped_entries: Dict[int, Dict[str, Any]] = {}
+
+        # 3. Iterate Test Cases
         for idx, tc in enumerate(test_cases):
-            # 1. Reconstruct case dict (Static info)
+            res = results_map.get(idx)
+            dev_id = getattr(res, "device", 0) if res else 0
+
+            # --- A. Initialize Group ---
+            if dev_id not in grouped_entries:
+                device_id_map = {v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")}
+                dev_str = device_id_map.get(dev_id, str(dev_id))
+                grouped_entries[dev_id] = {
+                    "operator": op_name,
+                    "device": dev_str,
+                    "torch_op": op_paths.get("torch") or "unknown",
+                    "infinicore_op": op_paths.get("infinicore") or "unknown",
+                    "args": global_args,
+                    "testcases": []
+                }
+
+            # --- B. Build Kwargs ---
+            display_kwargs = {}
+
+            # B1. Process existing kwargs
+            for k, v in tc.kwargs.items():
+                # Handle Inplace: "out": index -> "out": "input_name"
+                if k == "out" and isinstance(v, int):
+                    if 0 <= v < len(tc.inputs):
+                        display_kwargs[k] = tc.inputs[v].name
+                    else:
+                        display_kwargs[k] = f"Invalid_Index_{v}"
+                else:
+                    display_kwargs[k] = (TestReporter._spec_to_dict(v) if isinstance(v, TensorSpec) else v)
+
+            # B2. Inject Outputs into Kwargs
+            if hasattr(tc, "output_specs") and tc.output_specs:
+                for i, spec in enumerate(tc.output_specs):
+                    display_kwargs[f"out_{i}"] = TestReporter._spec_to_dict(spec)
+            elif tc.output_spec:
+                if "out" not in display_kwargs:
+                    display_kwargs["out"] = TestReporter._spec_to_dict(tc.output_spec)
+
+            # --- C. Build Test Case Dictionary ---
             case_data = {
                 "description": tc.description,
                 "inputs": [TestReporter._spec_to_dict(i) for i in tc.inputs],
-                "kwargs": {
-                    k: (
-                        TestReporter._spec_to_dict(v) if isinstance(v, TensorSpec) else v
-                    )
-                    for k, v in tc.kwargs.items()
-                },
+                "kwargs": display_kwargs,
                 "comparison_target": tc.comparison_target,
                 "tolerance": tc.tolerance,
             }
-            if tc.output_spec:
-                case_data["output_spec"] = TestReporter._spec_to_dict(tc.output_spec)
-            if hasattr(tc, "output_specs") and tc.output_specs:
-                case_data["output_specs"] = [
-                    TestReporter._spec_to_dict(s) for s in tc.output_specs
-                ]
 
-            # 2. Inject Result (Dynamic info) directly into the case
-            res = results_map.get(idx)
+            # --- D. Inject Result ---
             if res:
                 case_data["result"] = TestReporter._fmt_result(res)
             else:
                 case_data["result"] = {"status": {"success": False, "error": "No result"}}
-            processed_cases.append(case_data)
 
-        # Global Arguments
-        global_args = {
-            k: getattr(args, k)
-            for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
-            if hasattr(args, k)
-        }
+            grouped_entries[dev_id]["testcases"].append(case_data)
 
-        return {
-            "operator": op_name,
-            "device": device,
-            "torch_op": op_paths.get("torch") or "unknown",
-            "infinicore_op": op_paths.get("infinicore") or "unknown",
-            "args": global_args,
-            "testcases": processed_cases
-        }
+        return list(grouped_entries.values())
 
     @staticmethod
     def save_all_results(save_path: str, total_results: List[Dict[str, Any]]):
         """
-        Saves the report list to a JSON file with compact formatting.
+        Saves the report list to a JSON file with specific custom formatting.
         """
         directory, filename = os.path.split(save_path)
         name, ext = os.path.splitext(filename)
-        timestamp = time.strftime("%Y%m%d_%H%M%S")
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
         final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
 
+        # Define indentation levels for cleaner code
+        indent_4 = ' ' * 4
+        indent_8 = ' ' * 8
+        indent_12 = ' ' * 12
+        indent_16 = ' ' * 16
+        indent_20 = ' ' * 20
+
         print(f"💾 Saving to: {final_path}")
         try:
             with open(final_path, "w", encoding="utf-8") as f:
                 f.write("[\n")
                 for i, entry in enumerate(total_results):
-                    f.write("    {\n")
+                    f.write(f"{indent_4}{{\n")
                     keys = list(entry.keys())
                     for j, key in enumerate(keys):
-                        # Special Handling for list fields: vertical expansion
-                        if key in ["testcases"] and isinstance(entry[key], list):
-                            f.write(f'        "{key}": [\n')
-                            sub_list = entry[key]
-                            for c_idx, c_item in enumerate(sub_list):
-                                c_str = json.dumps(c_item, ensure_ascii=False)
-                                comma = "," if c_idx < len(sub_list) - 1 else ""
-                                f.write(f"            {c_str}{comma}\n")
-                            list_comma = "," if j < len(keys) - 1 else ""
-                            f.write(f"        ]{list_comma}\n")
+                        val = entry[key]
+                        comma = "," if j < len(keys) - 1 else ""
+
+                        # -------------------------------------------------
+                        # Special Handling for 'testcases' list formatting
+                        # -------------------------------------------------
+                        if key == "testcases" and isinstance(val, list):
+                            f.write(f'{indent_8}"{key}": [\n')
+                            for c_idx, case_item in enumerate(val):
+                                f.write(f"{indent_12}{{\n")
+                                case_keys = list(case_item.keys())
+                                for k_idx, c_key in enumerate(case_keys):
+                                    c_val = case_item[c_key]
+
+                                    # [Logic A] Skip fields we merged manually after 'kwargs'
+                                    if c_key in ["comparison_target", "tolerance"]:
+                                        continue
+
+                                    # Check comma for standard logic (might be overridden below)
+                                    c_comma = "," if k_idx < len(case_keys) - 1 else ""
+
+                                    # [Logic B] Handle 'kwargs' + Grouped Fields
+                                    if c_key == "kwargs":
+                                        # 1. Use Helper for kwargs (Fill/Flow logic)
+                                        TestReporter._write_smart_field(
+                                            f, c_key, c_val, indent_16, indent_20, close_comma=","
+                                        )
+                                        # 2. Write subsequent comparison_target and tolerance (on a new line)
+                                        cmp_v = json.dumps(case_item.get("comparison_target"), ensure_ascii=False)
+                                        tol_v = json.dumps(case_item.get("tolerance"), ensure_ascii=False)
+                                        remaining_keys = [k for k in case_keys[k_idx+1:] if k not in ("comparison_target", "tolerance")]
+                                        line_comma = "," if remaining_keys else ""
+                                        f.write(f'{indent_16}"comparison_target": {cmp_v}, "tolerance": {tol_v}{line_comma}\n')
+                                        continue
+
+                                    # [Logic C] Handle 'inputs' (Smart Wrap)
+                                    if c_key == "inputs" and isinstance(c_val, list):
+                                        TestReporter._write_smart_field(
+                                            f, c_key, c_val, indent_16, indent_20, close_comma=c_comma
+                                        )
+                                        continue
+
+                                    # [Logic D] Standard fields (description, result, output_spec, etc.)
+                                    else:
+                                        c_val_str = json.dumps(c_val, ensure_ascii=False)
+                                        f.write(f'{indent_16}"{c_key}": {c_val_str}{c_comma}\n')
+
+                                close_comma = "," if c_idx < len(val) - 1 else ""
+                                f.write(f"{indent_12}}}{close_comma}\n")
+                            f.write(f"{indent_8}]{comma}\n")
+
+                        # -------------------------------------------------
+                        # Standard top-level fields (operator, args, etc.)
+                        # -------------------------------------------------
                         else:
-                            # Standard compact formatting
                             k_str = json.dumps(key, ensure_ascii=False)
-                            v_str = json.dumps(entry[key], ensure_ascii=False)
-                            comma = "," if j < len(keys) - 1 else ""
-                            f.write(f"        {k_str}: {v_str}{comma}\n")
+                            v_str = json.dumps(val, ensure_ascii=False)
+                            f.write(f"{indent_8}{k_str}: {v_str}{comma}\n")
                     if i < len(total_results) - 1:
-                        f.write("    },\n")
+                        f.write(f"{indent_4}}},\n")
                     else:
-                        f.write("    }\n")
+                        f.write(f"{indent_4}}}\n")
                 f.write("]\n")
             print(f" ✅ Saved (Structure Matched).")
         except Exception as e:
+            import traceback; traceback.print_exc()
             print(f" ❌ Save failed: {e}")
 
     # --- Internal Helpers ---
+    @staticmethod
+    def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
+        """
+        Helper to write a JSON field (List or Dict) with smart wrapping.
+        - If compact length <= 180: Write on one line.
+        - If > 180: Use 'Fill/Flow' mode (multiple items per line, wrap when line is full).
+        """
+        # 1. Try Compact Mode
+        compact_json = json.dumps(value, ensure_ascii=False)
+        if len(compact_json) <= 180:
+            f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
+            return
+
+        # 2. Fill/Flow Mode
+        is_dict = isinstance(value, dict)
+        open_char = '{' if is_dict else '['
+        close_char = '}' if is_dict else ']'
+        f.write(f'{indent}"{key}": {open_char}')
+
+        # Normalize items for iteration
+        if is_dict:
+            items = list(value.items())
+        else:
+            items = value  # List
+
+        # Initialize current line length tracking
+        # Length includes indent + "key": [
+        current_len = len(indent) + len(f'"{key}": {open_char}')
+
+        for i, item in enumerate(items):
+            # Format individual item string
+            if is_dict:
+                k, v = item
+                val_str = json.dumps(v, ensure_ascii=False)
+                item_str = f'"{k}": {val_str}'
+            else:
+                item_str = json.dumps(item, ensure_ascii=False)
+
+            is_last = (i == len(items) - 1)
+            item_comma = "" if is_last else ", "
+
+            # Predict new length: current + item + comma
+            if current_len + len(item_str) + len(item_comma) > 180:
+                # Wrap to new line
+                f.write(f'\n{sub_indent}')
+                current_len = len(sub_indent)
+
+            f.write(f'{item_str}{item_comma}')
+            current_len += len(item_str) + len(item_comma)
+
        f.write(f'{close_char}{close_comma}\n')
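
A quick way to see the fill/flow rule in isolation: _write_smart_field only needs an object with a write() method, so it can be exercised against an in-memory buffer (the spec fields below are invented for illustration):

    import io
    from framework import TestReporter  # exported by the package (first hunk above)

    buf = io.StringIO()
    # Six ~50-character items push the compact JSON well past 180 chars,
    # so the helper switches from compact mode to fill/flow wrapping.
    items = [{"name": f"t{i}", "shape": [128, 128], "dtype": "float32"} for i in range(6)]
    TestReporter._write_smart_field(buf, "inputs", items, " " * 16, " " * 20, close_comma=",")
    print(buf.getvalue())  # items packed several per line, wrapping onto the 20-space sub-indent
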
     @staticmethod
     def _spec_to_dict(s):
@@ -147,14 +272,7 @@ class TestReporter:
         if not (is_dataclass(res) or hasattr(res, "success")):
             return str(res)
 
-        device_id_map = {
-            v: k
-            for k, v in vars(InfiniDeviceEnum).items()
-            if not k.startswith("_")
-        }
-        raw_id = getattr(res, "device", None)
-        dev_str = device_id_map.get(raw_id, str(raw_id))
+        get_time = lambda k: round(getattr(res, k, 0.0), 4)
 
         return {
             "status": {
@@ -163,13 +281,12 @@ class TestReporter:
             },
             "perf_ms": {
                 "torch": {
-                    "host": round(getattr(res, "torch_host_time", 0.0), 4),
-                    "device": round(getattr(res, "torch_device_time", 0.0), 4),
+                    "host": get_time("torch_host_time"),
+                    "device": get_time("torch_device_time"),
                 },
                 "infinicore": {
-                    "host": round(getattr(res, "infini_host_time", 0.0), 4),
-                    "device": round(getattr(res, "infini_device_time", 0.0), 4),
+                    "host": get_time("infini_host_time"),
+                    "device": get_time("infini_device_time"),
                 },
             },
-            "device": dev_str,
         }
...
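
With the grouping and the custom writer combined, a saved file looks roughly like the sketch below (all values invented; one top-level entry per device, and the field names inside "inputs" depend on _spec_to_dict, which is not shown in this diff):

    [
        {
            "operator": "add",
            "device": "CPU",
            "torch_op": "torch.add",
            "infinicore_op": "unknown",
            "args": {"bench": "both", "num_prerun": 10, "num_iterations": 1000, "verbose": false, "debug": false},
            "testcases": [
                {
                    "description": "basic add",
                    "inputs": [{"name": "a", "shape": [2, 3], "dtype": "float32"}, {"name": "b", "shape": [2, 3], "dtype": "float32"}],
                    "kwargs": {"out": "a"},
                    "comparison_target": null, "tolerance": {"atol": 1e-05, "rtol": 1e-05},
                    "result": {"status": {"success": true}, "perf_ms": {"torch": {"host": 0.12, "device": 0.05}, "infinicore": {"host": 0.1, "device": 0.04}}}
                }
            ]
        }
    ]
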
@@ -53,6 +53,9 @@ class GenericTestRunner:
         # summary_passed returns True if no tests failed (skipped/partial are OK)
         summary_passed = runner.print_summary()
 
+        if getattr(self.args, 'save', None):
+            self._save_report(runner)
+
         # Both conditions must be True for overall success
         # - has_no_failures: no test failures during execution
         # - summary_passed: summary confirms no failures
@@ -65,10 +68,7 @@ class GenericTestRunner:
             0: All tests passed or were skipped/partial (no failures)
             1: One or more tests failed
         """
         success, runner = self.run()
-
-        if getattr(self.args, 'save', None):
-            self._save_report(runner)
-
         sys.exit(0 if success else 1)
@@ -77,21 +77,8 @@ class GenericTestRunner:
         Helper method to collect metadata and trigger report saving.
         """
         try:
-            # 1. Infer active device string dynamically
-            from .devices import InfiniDeviceEnum
-
-            # Get actual device IDs used (e.g. [0, 1])
-            device_ids = get_test_devices(self.args)
-
-            # Map IDs to Names (e.g. {0: "CPU", 1: "NVIDIA"})
-            id_to_name = {v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith('_')}
-
-            # Convert list of IDs to list of Names
-            device_names = [id_to_name.get(d_id, str(d_id)) for d_id in device_ids]
-            device_str = ", ".join(device_names) if device_names else "CPU"
-
-            # 2. Prepare metadata (Paths)
-            # Try to infer from source code first
+            # 1. Prepare metadata (Paths)
             t_path = self._infer_op_path(self.operator_test.torch_operator, "torch")
             i_path = self._infer_op_path(self.operator_test.infinicore_operator, "infinicore")
@@ -100,18 +87,17 @@ class GenericTestRunner:
                 "infinicore": i_path
             }
 
-            # 3. Generate Report Entry
-            entry = TestReporter.prepare_report_entry(
+            # 2. Generate Report Entries
+            entries = TestReporter.prepare_report_entry(
                 op_name=self.operator_test.operator_name,
                 test_cases=self.operator_test.test_cases,
                 args=self.args,
                 op_paths=op_paths,
-                device=device_str,
                 results_list=runner.test_results
             )
 
-            # 4. Save to File
-            TestReporter.save_all_results(self.args.save, [entry])
+            # 3. Save to File
+            TestReporter.save_all_results(self.args.save, entries)
         except Exception as e:
             import traceback; traceback.print_exc()
...
@@ -5,7 +5,7 @@ import traceback
 from pathlib import Path
 import importlib.util
 
-from framework import get_hardware_args_group
+from framework import get_hardware_args_group, add_common_test_args
 
 def find_ops_directory(location=None):
@@ -650,24 +650,9 @@ def main():
         action="store_true",
         help="List all available test files without running them",
     )
-    parser.add_argument(
-        "--verbose",
-        action="store_true",
-        help="Enable verbose mode to stop on first error with full traceback",
-    )
-    parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug mode to debug value mismatches",
-    )
-    parser.add_argument(
-        "--bench",
-        nargs="?",
-        const="both",
-        choices=["host", "device", "both"],
-        help="Enable performance benchmarking mode. "
-        "Options: host (CPU time only), device (GPU time only), both (default)",
-    )
+
+    # Call common method to add shared arguments (bench, debug, verbose, save...)
+    add_common_test_args(parser)
 
     get_hardware_args_group(parser)
...
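
Taken together, run.py now accepts the same execution flags as the per-operator tests; a few illustrative invocations (hardware flags from get_hardware_args_group are omitted since they are not part of this diff):

    python run.py --bench --save              # bench host+device, save to test_report_<timestamp>.json
    python run.py --bench device --debug      # device timing only, with detailed tensor comparison
    python run.py --save results/report.json  # custom path; a timestamp is inserted before the extension
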