"superbench/benchmarks/vscode:/vscode.git/clone" did not exist on "80dcc8aaec4ffcee8bc6d9ce51daefd8962af2a0"
Commit 0d58c820 authored by baominghelly's avatar baominghelly
Browse files

issue/787 - Split run ops test logic and fix kwargs name in report

parent 726eacf8
import torch
import infinicore
from dataclasses import dataclass, field
def to_torch_dtype(infini_dtype):
"""Convert infinicore data type to PyTorch data type"""
......@@ -60,3 +60,37 @@ def to_infinicore_dtype(torch_dtype):
return infinicore.complex128
else:
raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
@dataclass
class TestTiming:
    """Aggregated timing metrics for one test run.

    Values are accumulated across operators; the reporter prints them
    with an " ms" suffix, so they are presumed to be milliseconds.
    """

    # Host-side (CPU) wall time totals for each backend.
    torch_host: float = 0.0
    # Device-side time totals for each backend.
    torch_device: float = 0.0
    infini_host: float = 0.0
    infini_device: float = 0.0
    # How many operators contributed to the totals above.
    operators_tested: int = 0
@dataclass
class SingleTestResult:
    """Execution outcome of a single operator-test file.

    return_code convention: 0 passed, -1 failed, -2 skipped, -3 partial.
    """

    name: str
    success: bool = False
    return_code: int = -1
    error_message: str = ""
    stdout: str = ""
    stderr: str = ""
    # Timing totals accumulated while this file's tests ran.
    timing: TestTiming = field(default_factory=TestTiming)

    @property
    def status_icon(self):
        """Emoji marker for return_code; anything unrecognized is a failure."""
        return {0: "✅", -2: "⏭️", -3: "⚠️"}.get(self.return_code, "❌")

    @property
    def status_text(self):
        """Human-readable status for return_code; defaults to FAILED."""
        return {0: "PASSED", -2: "SKIPPED", -3: "PARTIAL"}.get(self.return_code, "FAILED")
import sys
import importlib.util
from io import StringIO
from contextlib import contextmanager
from .datatypes import SingleTestResult, TestTiming
@contextmanager
def capture_output():
    """Temporarily redirect stdout and stderr into in-memory buffers.

    Yields:
        A ``(stdout_buffer, stderr_buffer)`` pair of StringIO objects.

    The real streams are always restored on exit, even if the body raises.
    """
    captured = (StringIO(), StringIO())
    saved = (sys.stdout, sys.stderr)
    try:
        sys.stdout, sys.stderr = captured
        yield captured
    finally:
        sys.stdout, sys.stderr = saved
class SingleTestExecutor:
    """Imports one operator-test module and drives its GenericTestRunner."""

    def run(self, file_path) -> SingleTestResult:
        """Execute the test file at *file_path* and return a populated result.

        Any exception raised along the way is converted into a failed result
        (return_code -1) rather than propagated.
        """
        outcome = SingleTestResult(name=file_path.stem)
        try:
            # Dynamically import the module from its file path.
            module = self._import_module(file_path)

            # The module must expose a GenericTestRunner entry point.
            if not hasattr(module, "GenericTestRunner"):
                raise ImportError("No GenericTestRunner found in module")

            # Locate the BaseOperatorTest subclass the module defines.
            test_class = self._find_test_class(module)
            if not test_class:
                raise ImportError("No BaseOperatorTest subclass found")

            test_instance = test_class()
            runner = module.GenericTestRunner(test_instance.__class__)

            # Run while capturing everything written to stdout/stderr.
            with capture_output() as (captured_out, captured_err):
                success, internal_runner = runner.run()

            outcome.success = success
            outcome.stdout = captured_out.getvalue()
            outcome.stderr = captured_err.getvalue()

            # Pull per-operator details out of the internal runner, if any.
            details = internal_runner.get_test_results() if internal_runner else []
            self._analyze_return_code(outcome, details)
            self._extract_timing(outcome, details)
        except Exception as exc:
            outcome.success = False
            outcome.error_message = str(exc)
            outcome.stderr += f"\nExecutor Error: {str(exc)}"
            outcome.return_code = -1
        return outcome

    def _import_module(self, path):
        """Load *path* as a uniquely named module and register it in sys.modules."""
        module_name = f"op_test_{path.stem}"
        spec = importlib.util.spec_from_file_location(module_name, path)
        if not spec or not spec.loader:
            raise ImportError(f"Could not load spec from {path}")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module

    def _find_test_class(self, module):
        """Return the first class whose bases mention BaseOperatorTest, else None."""
        for attr_name in dir(module):
            candidate = getattr(module, attr_name)
            if not isinstance(candidate, type):
                continue
            # Name-based check keeps this decoupled from the base class import.
            if any("BaseOperatorTest" in str(base) for base in candidate.__bases__):
                return candidate
        return None

    def _analyze_return_code(self, result, test_results):
        """Collapse per-test return codes into one aggregate code on *result*.

        Priority: any failure (-1) > partial (-3) > skipped (-2) > passed (0).
        """
        if not result.success:
            result.return_code = -1
            return
        codes = {r.return_code for r in test_results}
        for sentinel in (-1, -3, -2):
            if sentinel in codes:
                result.return_code = sentinel
                return
        result.return_code = 0

    def _extract_timing(self, result, test_results):
        """Sum per-test timings into the result's TestTiming fields."""
        timing = result.timing
        timing.torch_host = sum(r.torch_host_time for r in test_results)
        timing.torch_device = sum(r.torch_device_time for r in test_results)
        timing.infini_host = sum(r.infini_host_time for r in test_results)
        timing.infini_device = sum(r.infini_device_time for r in test_results)
from pathlib import Path
class TestDiscoverer:
    """Discovers operator test files in an 'ops' directory."""

    def __init__(self, ops_dir_path=None):
        # Resolve the directory once; may be None if nothing suitable exists.
        self.ops_dir = self._resolve_dir(ops_dir_path)

    def _resolve_dir(self, path):
        """Return a usable ops directory Path, or None when none exists.

        Falls back to the 'ops' directory under the grandparent of this
        file (this module lives in 'framework/'); passing an explicit path
        from run.py is recommended.
        """
        if path:
            p = Path(path)
            if p.exists():
                return p
        fallback = Path(__file__).parent.parent / "ops"
        return fallback if fallback.exists() else None

    def get_available_operators(self):
        """Return the sorted list of stems of all available operator tests."""
        if not self.ops_dir:
            return []
        return sorted(f.stem for f in self.scan())

    def scan(self, specific_ops=None):
        """Scan and return Path objects for valid operator test files.

        Args:
            specific_ops: optional collection of operator stems; when given,
                only files whose stem is in it are returned.
        """
        if not self.ops_dir or not self.ops_dir.exists():
            return []
        # Keep .py files that are not private/runner files and whose content
        # looks like an operator test.
        valid_files = [
            f for f in self.ops_dir.glob("*.py")
            if not f.name.startswith("_")
            and f.name != "run.py"
            and self._is_operator_test(f)
        ]
        if specific_ops:
            return [f for f in valid_files if f.stem in specific_ops]
        return valid_files

    def _is_operator_test(self, file_path):
        """Check whether the file content contains operator-test markers."""
        try:
            content = file_path.read_text(encoding="utf-8")
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit and
        # genuine bugs are not silently swallowed; unreadable or undecodable
        # files are simply treated as non-tests.
        except (OSError, UnicodeDecodeError):
            return False
        return "infinicore" in content and (
            "BaseOperatorTest" in content or "GenericTestRunner" in content
        )
......@@ -61,35 +61,61 @@ class TestReporter:
# --- B. Build Kwargs ---
display_kwargs = {}
# B1. Process existing kwargs
for k, v in tc.kwargs.items():
# Handle Inplace: "out": index -> "out": "input_name"
# 1. Handle Inplace output index: "out": 0 -> "out": "in_0" / "a_spec"
if k == "out" and isinstance(v, int):
if 0 <= v < len(tc.inputs):
display_kwargs[k] = tc.inputs[v].name
# Prioritize the input's name; otherwise, default to index-based name
display_kwargs[k] = getattr(tc.inputs[v], "name", None) or f"in_{v}"
else:
display_kwargs[k] = f"Invalid_Index_{v}"
# 2. Handle TensorSpec objects
elif isinstance(v, TensorSpec):
spec_dict = TestReporter._spec_to_dict(v)
# If the object has a name, explicitly overwrite it; otherwise, keep original
if getattr(v, "name", None):
spec_dict["name"] = v.name
display_kwargs[k] = spec_dict
# 3. Direct assignment for other types
else:
display_kwargs[k] = (TestReporter._spec_to_dict(v) if isinstance(v, TensorSpec) else v)
display_kwargs[k] = v
# B2. Inject Outputs into Kwargs
if hasattr(tc, "output_specs") and tc.output_specs:
# --- B2. Inject Outputs ---
# Handle output list (output_specs)
if getattr(tc, "output_specs", None):
for i, spec in enumerate(tc.output_specs):
display_kwargs[f"out_{i}"] = TestReporter._spec_to_dict(spec)
elif tc.output_spec:
if "out" not in display_kwargs:
display_kwargs["out"] = TestReporter._spec_to_dict(tc.output_spec)
out_dict = TestReporter._spec_to_dict(spec)
# Prioritize intrinsic name; otherwise, default to "out_i"
out_dict["name"] = getattr(spec, "name", None) or f"out_{i}"
display_kwargs[f"out_{i}"] = out_dict
# Handle single output (output_spec), preventing overwrite of existing "out"
elif tc.output_spec and "out" not in display_kwargs:
out_dict = TestReporter._spec_to_dict(tc.output_spec)
# Prioritize intrinsic name; otherwise, default to "out" (fixes null issue)
out_dict["name"] = getattr(tc.output_spec, "name", "out")
display_kwargs["out"] = out_dict
# --- C. Build Inputs ---
# Iterate inputs: prioritize original name, fallback to "in_i"
processed_inputs = []
for i, inp in enumerate(tc.inputs):
inp_dict = TestReporter._spec_to_dict(inp)
# Simplified logic: Use "name" attribute if present and non-empty, else use f"in_{i}"
inp_dict["name"] = getattr(inp, "name", None) or f"in_{i}"
processed_inputs.append(inp_dict)
# --- C. Build Test Case Dictionary ---
case_data = {
"description": tc.description,
"inputs": [TestReporter._spec_to_dict(i) for i in tc.inputs],
"inputs": processed_inputs,
"kwargs": display_kwargs,
"comparison_target": tc.comparison_target,
"tolerance": tc.tolerance,
}
# --- D. Inject Result ---
if res:
case_data["result"] = TestReporter._fmt_result(res)
......@@ -117,7 +143,7 @@ class TestReporter:
indent_12 = ' ' * 12
indent_16 = ' ' * 16
indent_20 = ' ' * 20
print(f"💾 Saving to: {final_path}")
try:
with open(final_path, "w", encoding="utf-8") as f:
......@@ -125,8 +151,8 @@ class TestReporter:
for i, entry in enumerate(total_results):
f.write(f"{indent_4}{{\n")
keys = list(entry.keys())
keys = list(entry.keys())
for j, key in enumerate(keys):
val = entry[key]
comma = "," if j < len(keys) - 1 else ""
......@@ -204,7 +230,109 @@ class TestReporter:
import traceback; traceback.print_exc()
print(f" ❌ Save failed: {e}")
@staticmethod
def print_header(ops_dir, count):
    """Print the run banner: title, scanned directory, and test count."""
    banner = (
        "InfiniCore Operator Test Runner",
        f"Directory: {ops_dir}",
        f"Tests found: {count}\n",
    )
    print("\n".join(banner))
@staticmethod
def print_live_result(result, verbose=False):
    """Print one result line in real time, plus any captured output.

    A trailing divider is emitted when there was output or verbose is set.
    """
    header = f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})"
    print(header)
    if result.stdout:
        print(result.stdout.rstrip())
    if result.stderr:
        print("\nSTDERR:", result.stderr.rstrip())
    if result.error_message:
        print(f"💥 Error: {result.error_message}")
    had_output = bool(result.stdout) or bool(result.stderr)
    if had_output or verbose:
        print("-" * 40)
@staticmethod
def print_summary(results, cumulative_timing, total_expected=0):
    """Print the final cumulative summary; return True when nothing failed."""
    banner = "=" * 80
    print(f"\n{banner}\nCUMULATIVE TEST SUMMARY\n{banner}")

    # Bucket results by their aggregate return code.
    tally = {code: [r for r in results if r.return_code == code] for code in (0, -1, -2, -3)}
    passed, failed = tally[0], tally[-1]
    skipped, partial = tally[-2], tally[-3]
    total = len(results)

    print(f"Total tests run: {total}")
    print(f"Passed: {len(passed)}")
    print(f"Failed: {len(failed)}")
    if skipped:
        print(f"Skipped: {len(skipped)}")
    if partial:
        print(f"Partial: {len(partial)}")

    # Benchmark timing block; bench_mode is fixed to "both" in this file.
    if cumulative_timing:
        TestReporter._print_timing(cumulative_timing, bench_mode="both")

    # Roster of passed operators, 10 names per line.
    if passed:
        print(f"\n✅ PASSED OPERATORS ({len(passed)}):")
        names = [r.name for r in passed]
        for start in range(0, len(names), 10):
            print(" " + ", ".join(names[start:start + 10]))
    else:
        print(f"\n✅ PASSED OPERATORS: None")

    # Success rate over tests that actually executed (skipped excluded).
    if total > 0:
        executed = total - len(skipped)
        if executed > 0:
            print(f"\nSuccess rate: {len(passed) / executed * 100:.1f}%")

    if not failed:
        print(f"\n🎉 All tests passed!")
    else:
        print(f"\n{len(failed)} tests failed")
    return len(failed) == 0
# --- Internal Helpers ---
@staticmethod
def _print_timing(t, bench_mode="both"):
    """Print host/device timing totals, filtered by *bench_mode*.

    bench_mode is one of "host", "device", or "both"; rows outside the
    selected mode are omitted.
    """
    rule = "-" * 40
    print(rule)
    # Only objects carrying the newer operators_tested field get a summary header.
    if hasattr(t, 'operators_tested'):
        print(f"BENCHMARK SUMMARY:")
        print(f" Operators Tested: {t.operators_tested}")
    if bench_mode in ["host", "both"]:
        print(f" PyTorch Host Total Time: {t.torch_host:12.3f} ms")
        print(f" InfiniCore Host Total Time: {t.infini_host:12.3f} ms")
    if bench_mode in ["device", "both"]:
        print(f" PyTorch Device Total Time: {t.torch_device:12.3f} ms")
        print(f" InfiniCore Device Total Time: {t.infini_device:12.3f} ms")
    print(rule)
@staticmethod
def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
"""
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment