Commit 5aa850af authored by baominghelly's avatar baominghelly
Browse files

func behavior fix && rename class name

parent 0d58c820
from .base import TestConfig, TestRunner, BaseOperatorTest from .base import TestConfig, TestRunner, BaseOperatorTest
from .test_case import TestCase, TestResult from .entities import TestCase
from .benchmark import BenchmarkUtils, BenchmarkResult from .benchmark import BenchmarkUtils, BenchmarkResult
from .config import ( from .config import (
add_common_test_args, add_common_test_args,
...@@ -11,6 +11,9 @@ from .datatypes import to_torch_dtype, to_infinicore_dtype ...@@ -11,6 +11,9 @@ from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .runner import GenericTestRunner from .runner import GenericTestRunner
from .tensor import TensorSpec, TensorInitializer from .tensor import TensorSpec, TensorInitializer
from .types import TestTiming, OperatorTestResult, TestResult
from .driver import TestDriver
from .printer import ConsolePrinter
from .utils import ( from .utils import (
compare_results, compare_results,
create_test_comparator, create_test_comparator,
...@@ -38,6 +41,10 @@ __all__ = [ ...@@ -38,6 +41,10 @@ __all__ = [
"TestResult", "TestResult",
"TestRunner", "TestRunner",
"TestReporter", "TestReporter",
"TestTiming",
"OperatorTestResult",
"TestDriver",
"ConsolePrinter",
# Core functions # Core functions
"add_common_test_args", "add_common_test_args",
"compare_results", "compare_results",
......
...@@ -8,7 +8,8 @@ import infinicore ...@@ -8,7 +8,8 @@ import infinicore
import traceback import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from .test_case import TestCase, TestResult from .entities import TestCase
from .types import TestResult
from .datatypes import to_torch_dtype, to_infinicore_dtype from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer from .tensor import TensorSpec, TensorInitializer
......
...@@ -60,37 +60,3 @@ def to_infinicore_dtype(torch_dtype): ...@@ -60,37 +60,3 @@ def to_infinicore_dtype(torch_dtype):
return infinicore.complex128 return infinicore.complex128
else: else:
raise ValueError(f"Unsupported torch dtype: {torch_dtype}") raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
@dataclass
class TestTiming:
    """Stores performance testing timing metrics.

    Totals are accumulated across operator runs; the reporter prints them
    with an ``ms`` suffix, so values are presumably milliseconds — TODO confirm
    against the benchmark code that fills them in.
    """
    torch_host: float = 0.0      # cumulative PyTorch host-side time
    torch_device: float = 0.0    # cumulative PyTorch device-side time
    infini_host: float = 0.0     # cumulative InfiniCore host-side time
    infini_device: float = 0.0   # cumulative InfiniCore device-side time
    operators_tested: int = 0    # number of operators contributing to the totals
@dataclass
class SingleTestResult:
    """Execution outcome collected for one operator test file."""
    name: str
    success: bool = False
    return_code: int = -1
    error_message: str = ""
    stdout: str = ""
    stderr: str = ""
    timing: TestTiming = field(default_factory=TestTiming)

    # return_code -> (icon, label); any unlisted code is treated as a failure.
    # Codes: 0 passed, -2 skipped, -3 partial, anything else failed.
    _STATUS = {0: ("✅", "PASSED"), -2: ("⏭️", "SKIPPED"), -3: ("⚠️", "PARTIAL")}

    @property
    def status_icon(self):
        """Emoji summarising this result's return code."""
        icon, _ = self._STATUS.get(self.return_code, ("❌", "FAILED"))
        return icon

    @property
    def status_text(self):
        """Human-readable label for this result's return code."""
        _, label = self._STATUS.get(self.return_code, ("❌", "FAILED"))
        return label
...@@ -2,7 +2,7 @@ import sys ...@@ -2,7 +2,7 @@ import sys
import importlib.util import importlib.util
from io import StringIO from io import StringIO
from contextlib import contextmanager from contextlib import contextmanager
from .datatypes import SingleTestResult, TestTiming from .types import OperatorTestResult, TestTiming
@contextmanager @contextmanager
def capture_output(): def capture_output():
...@@ -15,9 +15,9 @@ def capture_output(): ...@@ -15,9 +15,9 @@ def capture_output():
finally: finally:
sys.stdout, sys.stderr = old_out, old_err sys.stdout, sys.stderr = old_out, old_err
class SingleTestExecutor: class TestDriver:
def run(self, file_path) -> SingleTestResult: def drive(self, file_path) -> OperatorTestResult:
result = SingleTestResult(name=file_path.stem) result = OperatorTestResult(name=file_path.stem)
try: try:
# 1. Dynamically import the module # 1. Dynamically import the module
...@@ -79,15 +79,22 @@ class SingleTestExecutor: ...@@ -79,15 +79,22 @@ class SingleTestExecutor:
def _analyze_return_code(self, result, test_results): def _analyze_return_code(self, result, test_results):
# Logic consistent with original code: determine if all passed, partially passed, or skipped # Logic consistent with original code: determine if all passed, partially passed, or skipped
if not result.success: if result.success:
result.return_code = -1 result.return_code = 0
return return
codes = [r.return_code for r in test_results] has_failures = any(r.return_code == -1 for r in test_results)
if -1 in codes: result.return_code = -1 has_partial = any(r.return_code == -3 for r in test_results)
elif -3 in codes: result.return_code = -3 has_skipped = any(r.return_code == -2 for r in test_results)
elif -2 in codes: result.return_code = -2
else: result.return_code = 0 if has_failures:
result.return_code = -1
elif has_partial:
result.return_code = -3
elif has_skipped:
result.return_code = -2
else:
result.return_code = -1
def _extract_timing(self, result, test_results): def _extract_timing(self, result, test_results):
# Accumulate timing # Accumulate timing
......
...@@ -7,21 +7,6 @@ from typing import List, Dict, Any, Optional, Tuple ...@@ -7,21 +7,6 @@ from typing import List, Dict, Any, Optional, Tuple
from .tensor import TensorSpec from .tensor import TensorSpec
@dataclass
class TestResult:
    """Test result data structure for a single test-case execution."""
    success: bool                    # True when the test case ran to completion
    return_code: int                 # 0: success, -1: failure, -2: skipped, -3: partial
    torch_host_time: float = 0.0     # PyTorch host-side elapsed time
    torch_device_time: float = 0.0   # PyTorch device-side elapsed time
    infini_host_time: float = 0.0    # InfiniCore host-side elapsed time
    infini_device_time: float = 0.0  # InfiniCore device-side elapsed time
    error_message: str = ""          # populated only when the test failed
    test_case: Any = None            # originating test case (kept opaque here)
    device: Any = None               # device the test ran on (kept opaque here)
class TestCase: class TestCase:
"""Test case with all configuration included""" """Test case with all configuration included"""
......
...@@ -21,6 +21,18 @@ class TestDiscoverer: ...@@ -21,6 +21,18 @@ class TestDiscoverer:
files = self.scan() files = self.scan()
return sorted([f.stem for f in files]) return sorted([f.stem for f in files])
def get_raw_python_files(self):
    """
    List every .py file name in the ops directory without validating contents,
    skipping run.py and dunder modules.

    Debugging aid: surfaces files that exist on disk but were rejected by the
    stricter scan() validation (e.g. wrong base class).
    """
    ops_dir = self.ops_dir
    if not ops_dir or not ops_dir.exists():
        return []
    candidates = (p.name for p in ops_dir.glob("*.py"))
    # run.py is the runner itself; __-prefixed files are package plumbing
    return [n for n in candidates if n != "run.py" and not n.startswith("__")]
def scan(self, specific_ops=None): def scan(self, specific_ops=None):
"""Scans and returns a list of Path objects that meet the criteria.""" """Scans and returns a list of Path objects that meet the criteria."""
if not self.ops_dir or not self.ops_dir.exists(): if not self.ops_dir or not self.ops_dir.exists():
...@@ -29,17 +41,23 @@ class TestDiscoverer: ...@@ -29,17 +41,23 @@ class TestDiscoverer:
# 1. Find all .py files # 1. Find all .py files
files = list(self.ops_dir.glob("*.py")) files = list(self.ops_dir.glob("*.py"))
target_ops_set = set(specific_ops) if specific_ops else None
# 2. Filter out non-test files (via content check) # 2. Filter out non-test files (via content check)
valid_files = [] valid_files = []
for f in files: for f in files:
# A. Basic Name Filtering
if f.name.startswith("_") or f.name == "run.py": if f.name.startswith("_") or f.name == "run.py":
continue continue
# B. Specific Ops Filtering
if target_ops_set and f.stem not in target_ops_set:
continue
# C. Content Check (Expensive I/O)
# Only perform this check if the file passed the name filters above.
if self._is_operator_test(f): if self._is_operator_test(f):
valid_files.append(f) valid_files.append(f)
# 3. If specific operators are specified, filter them
if specific_ops:
return [f for f in valid_files if f.stem in specific_ops]
return valid_files return valid_files
......
# lib/printer.py
import sys
from .types import OperatorTestResult, TestTiming
class ConsolePrinter:
    """
    Handles all console output logic.
    Acts as the 'View' in the application structure.

    Return-code convention used throughout: 0 passed, -1 failed,
    -2 skipped, -3 partial implementation.
    """
    def list_tests(self, discoverer):
        """
        Intelligently list available tests.
        If no valid operators are found, it falls back to listing raw Python files
        to assist with debugging (e.g., typos in class inheritance).
        """
        ops_dir = discoverer.ops_dir
        operators = discoverer.get_available_operators()
        if operators:
            print(f"Available operator test files in {ops_dir}:")
            for operator in operators:
                print(f" - {operator}")
            print(f"\nTotal: {len(operators)} operators")
        else:
            print(f"No valid operator tests found in {ops_dir}")
            # === Fallback Debug Logic ===
            # Unvalidated listing so the user can spot files that exist but
            # failed the discoverer's content checks.
            raw_files = discoverer.get_raw_python_files()
            if raw_files:
                print(f"\n💡 Debug Hint: Found the following Python files (but they are not valid tests):")
                print(f" {raw_files}")
                print(" (Ensure they inherit from 'BaseOperatorTest' and contain 'infinicore')")

    def print_header(self, ops_dir, count):
        # One-time banner printed before the execution loop begins.
        print(f"InfiniCore Operator Test Runner")
        print(f"Directory: {ops_dir}")
        print(f"Tests found: {count}\n")

    def print_live_result(self, result, verbose=False):
        """Print single-line result in real-time."""
        # Status line: icon + name + label + raw return code.
        print(f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})")
        # Only print details if verbose or if the test failed/had output
        if result.stdout:
            print(result.stdout.rstrip())
        if result.stderr:
            print("\nSTDERR:", result.stderr.rstrip())
        if result.error_message:
            print(f"💥 Error: {result.error_message}")
        # Separator only when any detail was emitted (or verbose requested).
        if result.stdout or result.stderr or verbose:
            print("-" * 40)

    def print_summary(
        self,
        results,
        cumulative_timing,
        ops_dir,
        total_expected=0,
        verbose=False,
        bench_mode="both"
    ):
        """Prints the final comprehensive test summary and statistics, ensuring consistency with original output.

        Returns True when no test failed; the caller maps this to the exit code.
        bench_mode is one of "host", "device" or "both" and controls which
        timing sections are printed.
        """
        print(f"\n{'='*80}\nCUMULATIVE TEST SUMMARY\n{'='*80}")
        # Bucket results by the return-code convention documented on the class.
        passed = [r for r in results if r.return_code == 0]
        failed = [r for r in results if r.return_code == -1]
        skipped = [r for r in results if r.return_code == -2]
        partial = [r for r in results if r.return_code == -3]
        total = len(results)
        print(f"Total tests run: {total}")
        if total_expected > 0 and total < total_expected:
            # Some tests never ran (e.g. the loop aborted early).
            print(f"Total tests expected: {total_expected}")
            print(f"Tests not executed: {total_expected - total}")
        print(f"Passed: {len(passed)}")
        print(f"Failed: {len(failed)}")
        if skipped: print(f"Skipped: {len(skipped)}")
        if partial: print(f"Partial: {len(partial)}")
        # 1. Print Benchmark data
        if cumulative_timing:
            # Call the internal helper method
            self._print_timing(cumulative_timing, bench_mode=bench_mode)
        # 2. Print Detailed Lists
        # PASSED
        if passed:
            self._print_op_list("✅ PASSED OPERATORS", passed)
        else:
            print(f"\n✅ PASSED OPERATORS: None")
        # FAILED
        if failed:
            self._print_op_list("❌ FAILED OPERATORS", failed)
        # SKIPPED
        if skipped:
            self._print_op_list("⏭️ SKIPPED OPERATORS", skipped)
        # PARTIAL
        if partial:
            self._print_op_list("⚠️ PARTIAL IMPLEMENTATIONS", partial)
        # 3. Restore Success Rate
        if total > 0:
            # Calculate success rate based on actually executed tests (excluding skipped)
            executed_tests = total - len(skipped)
            if executed_tests > 0:
                success_rate = len(passed) / executed_tests * 100
                print(f"\nSuccess rate: {success_rate:.1f}%")
        if not failed:
            if skipped or partial:
                print(f"\n⚠️ Tests completed with some operators not fully implemented")
            else:
                print(f"\n🎉 All tests passed!")
        else:
            print(f"\n{len(failed)} tests failed")
        if not failed and (skipped or partial):
            print(f"\n⚠️ Note: Some operators are not fully implemented")
            print(f" Run individual tests for details on missing implementations")
        if verbose and failed:
            print(f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:")
            # Show first 3 failed operators to avoid spamming
            for r in failed[:3]:
                # Construct file path: ops_dir / filename.py
                file_path = ops_dir / (r.name + ".py")
                print(f" python {file_path} --verbose")
            if len(failed) > 3:
                print(f" ... (and {len(failed) - 3} others)")
        return len(failed) == 0

    # --- Internal Helpers ---
    def _print_op_list(self, title, result_list):
        """Helper to print a formatted list of operator names."""
        print(f"\n{title} ({len(result_list)}):")
        names = [r.name for r in result_list]
        # Group by 10 per line
        for i in range(0, len(names), 10):
            print(" " + ", ".join(names[i : i + 10]))

    def _print_timing(self, t, bench_mode="both"):
        """Prints detailed timing breakdown for host and device, based on bench_mode.

        t is expected to be a TestTiming-like object; the hasattr guard keeps
        this tolerant of objects lacking the operators_tested field.
        """
        print(f"{'-'*40}")
        # Restore Operators Tested field using the dataclass field
        if hasattr(t, 'operators_tested') and t.operators_tested > 0:
            print(f"BENCHMARK SUMMARY:")
            print(f" Operators Tested: {t.operators_tested}")
        # Restore detailed Host/Device distinction
        if bench_mode in ["host", "both"]:
            print(
                f" PyTorch Host Total Time: {t.torch_host:12.3f} ms"
            )
            print(
                f" InfiniCore Host Total Time: {t.infini_host:12.3f} ms"
            )
        if bench_mode in ["device", "both"]:
            print(
                f" PyTorch Device Total Time: {t.torch_device:12.3f} ms"
            )
            print(
                f" InfiniCore Device Total Time: {t.infini_device:12.3f} ms"
            )
        print(f"{'-'*40}")
...@@ -229,110 +229,8 @@ class TestReporter: ...@@ -229,110 +229,8 @@ class TestReporter:
except Exception as e: except Exception as e:
import traceback; traceback.print_exc() import traceback; traceback.print_exc()
print(f" ❌ Save failed: {e}") print(f" ❌ Save failed: {e}")
@staticmethod
def print_header(ops_dir, count):
    # One-time banner printed before the execution loop begins.
    print(f"InfiniCore Operator Test Runner")
    print(f"Directory: {ops_dir}")
    print(f"Tests found: {count}\n")
@staticmethod
def print_live_result(result, verbose=False):
    """Print single-line result in real-time."""
    # Status line: icon + name + label + raw return code.
    print(f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})")
    # Echo any output the test captured.
    if result.stdout:
        print(result.stdout.rstrip())
    if result.stderr:
        print("\nSTDERR:", result.stderr.rstrip())
    if result.error_message:
        print(f"💥 Error: {result.error_message}")
    # Separator only when any detail was emitted (or verbose requested).
    if result.stdout or result.stderr or verbose:
        print("-" * 40)
@staticmethod
def print_summary(results, cumulative_timing, total_expected=0):
    """Prints the final comprehensive test summary and statistics, ensuring consistency with original output.

    Returns True when no test failed; the caller maps this to the exit code.
    NOTE(review): total_expected is accepted but never used in this body —
    TODO confirm whether a "tests not executed" section was intended here.
    """
    print(f"\n{'='*80}\nCUMULATIVE TEST SUMMARY\n{'='*80}")
    # Bucket results by return-code convention:
    # 0 passed, -1 failed, -2 skipped, -3 partial implementation.
    passed = [r for r in results if r.return_code == 0]
    failed = [r for r in results if r.return_code == -1]
    skipped = [r for r in results if r.return_code == -2]
    partial = [r for r in results if r.return_code == -3]
    total = len(results)
    print(f"Total tests run: {total}")
    print(f"Passed: {len(passed)}")
    print(f"Failed: {len(failed)}")
    if skipped: print(f"Skipped: {len(skipped)}")
    if partial: print(f"Partial: {len(partial)}")
    # 1. Print Benchmark data
    if cumulative_timing:
        # Assuming bench_mode is "both" for simplicity in this file, or passed via a config
        # We call the modified _print_timing to handle the display logic.
        TestReporter._print_timing(cumulative_timing, bench_mode="both")
    # 2. Restore PASSED OPERATORS list
    if passed:
        print(f"\n✅ PASSED OPERATORS ({len(passed)}):")
        # Print operators, grouped (assuming 10 per line as per the old pattern)
        operators = [r.name for r in passed]
        for i in range(0, len(operators), 10):
            print(" " + ", ".join(operators[i : i + 10]))
    else:
        print(f"\n✅ PASSED OPERATORS: None")
    # 3. Restore Success Rate
    if total > 0:
        # Calculate success rate based on actually executed tests (excluding skipped)
        executed_tests = total - len(skipped)
        if executed_tests > 0:
            success_rate = len(passed) / executed_tests * 100
            print(f"\nSuccess rate: {success_rate:.1f}%")
    if not failed:
        print(f"\n🎉 All tests passed!")
    else:
        print(f"\n{len(failed)} tests failed")
    return len(failed) == 0
# --- Internal Helpers --- # --- Internal Helpers ---
@staticmethod
def _print_timing(t, bench_mode="both"):
    """Prints detailed timing breakdown for host and device, based on bench_mode.

    t is a TestTiming-like object; bench_mode is one of "host", "device"
    or "both" and selects which timing sections are printed.
    """
    print(f"{'-'*40}")
    # Restore Operators Tested field using the new dataclass field
    # (hasattr guard keeps this tolerant of objects lacking the field).
    if hasattr(t, 'operators_tested'):
        print(f"BENCHMARK SUMMARY:")
        print(f" Operators Tested: {t.operators_tested}")
    # Restore detailed Host/Device distinction
    if bench_mode in ["host", "both"]:
        print(
            f" PyTorch Host Total Time: {t.torch_host:12.3f} ms"
        )
        print(
            f" InfiniCore Host Total Time: {t.infini_host:12.3f} ms"
        )
    if bench_mode in ["device", "both"]:
        print(
            f" PyTorch Device Total Time: {t.torch_device:12.3f} ms"
        )
        print(
            f" InfiniCore Device Total Time: {t.infini_device:12.3f} ms"
        )
    print(f"{'-'*40}")
@staticmethod @staticmethod
def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""): def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
""" """
......
from dataclasses import dataclass, field
from typing import Any
# TODO: Rename it, current class name is abstract.
@dataclass
class TestResult:
    """Test result data structure for a single test-case execution."""
    success: bool                    # True when the test case ran to completion
    return_code: int                 # 0: success, -1: failure, -2: skipped, -3: partial
    torch_host_time: float = 0.0     # PyTorch host-side elapsed time
    torch_device_time: float = 0.0   # PyTorch device-side elapsed time
    infini_host_time: float = 0.0    # InfiniCore host-side elapsed time
    infini_device_time: float = 0.0  # InfiniCore device-side elapsed time
    error_message: str = ""          # populated only when the test failed
    test_case: Any = None            # originating test case (kept opaque here)
    device: Any = None               # device the test ran on (kept opaque here)
@dataclass
class TestTiming:
    """Stores performance timing metrics.

    Totals are accumulated across operator runs; the reporter prints them
    with an ``ms`` suffix, so values are presumably milliseconds — TODO confirm
    against the benchmark code that fills them in.
    """
    torch_host: float = 0.0      # cumulative PyTorch host-side time
    torch_device: float = 0.0    # cumulative PyTorch device-side time
    infini_host: float = 0.0     # cumulative InfiniCore host-side time
    infini_device: float = 0.0   # cumulative InfiniCore device-side time
    # Counts the operators contributing to the totals; consumed by
    # ConsolePrinter._print_timing when rendering the benchmark summary.
    operators_tested: int = 0
@dataclass
class OperatorTestResult:
    """Aggregated outcome of running one operator test file."""
    name: str
    success: bool = False
    return_code: int = -1
    error_message: str = ""
    stdout: str = ""
    stderr: str = ""
    timing: TestTiming = field(default_factory=TestTiming)

    # Display tables keyed by return code (0 passed, -2 skipped, -3 partial);
    # any other code renders as a failure.
    _ICONS = {0: "✅", -2: "⏭️", -3: "⚠️"}
    _LABELS = {0: "PASSED", -2: "SKIPPED", -3: "PARTIAL"}

    @property
    def status_icon(self):
        """Emoji summarising this result's return code."""
        return self._ICONS.get(self.return_code, "❌")

    @property
    def status_text(self):
        """Human-readable label for this result's return code."""
        return self._LABELS.get(self.return_code, "FAILED")
...@@ -4,14 +4,93 @@ from pathlib import Path ...@@ -4,14 +4,93 @@ from pathlib import Path
# Import components from the unified framework package # Import components from the unified framework package
from framework.loader import TestDiscoverer from framework.loader import TestDiscoverer
from framework.executor import SingleTestExecutor from framework.driver import TestDriver
from framework.reporter import TestReporter from framework.printer import ConsolePrinter
from framework.datatypes import TestTiming from framework.types import TestTiming
from framework import get_hardware_args_group, add_common_test_args from framework import get_hardware_args_group, add_common_test_args
def generate_help_epilog(ops_dir=None):
    """
    Generate dynamic help epilog containing available operators and hardware platforms.
    Maintains the original output format for backward compatibility.
    """
    # === Adapter: Use TestDiscoverer to get operator list ===
    # Temporarily instantiate a Discoverer just to fetch the list
    discoverer = TestDiscoverer(ops_dir)
    operators = discoverer.get_available_operators()

    # Each example is a (description, command) pair; rendered as
    # description line, command line, blank line — matching the original text.
    examples = [
        (" # Run all operator tests on CPU",
         " python run.py --cpu"),
        (" # Run specific operators",
         " python run.py --ops add matmul --nvidia"),
        (" # Run with debug mode on multiple devices",
         " python run.py --cpu --nvidia --debug"),
        (" # Run with verbose mode to stop on first error with full traceback",
         " python run.py --cpu --nvidia --verbose"),
        (" # Run with benchmarking (both host and device timing)",
         " python run.py --cpu --bench"),
        (" # Run with host timing only",
         " python run.py --nvidia --bench host"),
        (" # Run with device timing only",
         " python run.py --nvidia --bench device"),
        (" # List available tests without running",
         " python run.py --list"),
    ]
    epilog_parts = ["Examples:"]
    for description, command in examples:
        epilog_parts.append(description)
        epilog_parts.append(command)
        epilog_parts.append("")

    # Available operators section
    if operators:
        epilog_parts.append("Available Operators:")
        # Group operators for better display, four per line
        operators_per_line = 4
        for i in range(0, len(operators), operators_per_line):
            epilog_parts.append(f" {', '.join(operators[i : i + operators_per_line])}")
    else:
        epilog_parts.append("Available Operators: (none detected)")
    epilog_parts.append("")

    # Additional notes
    epilog_parts.append("Note:")
    epilog_parts.extend([
        " - Use '--' to pass additional arguments to individual test scripts",
        " - Operators are automatically discovered from the ops directory",
        " - --bench mode now shows cumulative timing across all operators",
        " - --bench host/device/both controls host/device timing measurement",
        " - --verbose mode stops execution on first error and shows full traceback",
    ])
    return "\n".join(epilog_parts)
def main(): def main():
"""Main entry point for the InfiniCore Operator Test Runner.""" """Main entry point for the InfiniCore Operator Test Runner."""
parser = argparse.ArgumentParser(description="Run InfiniCore operator tests across multiple hardware platforms") parser = argparse.ArgumentParser(
description="Run InfiniCore operator tests across multiple hardware platforms",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=generate_help_epilog()
)
parser.add_argument("--ops-dir", type=str, help="Path to the ops directory (default: auto-detect)") parser.add_argument("--ops-dir", type=str, help="Path to the ops directory (default: auto-detect)")
parser.add_argument("--ops", nargs="+", help="Run specific operators only (e.g., --ops add matmul)") parser.add_argument("--ops", nargs="+", help="Run specific operators only (e.g., --ops add matmul)")
parser.add_argument("--list", action="store_true", help="List all available test files without running them") parser.add_argument("--list", action="store_true", help="List all available test files without running them")
...@@ -20,7 +99,10 @@ def main(): ...@@ -20,7 +99,10 @@ def main():
add_common_test_args(parser) add_common_test_args(parser)
get_hardware_args_group(parser) get_hardware_args_group(parser)
args, _ = parser.parse_known_args() args, unknown_args = parser.parse_known_args()
# Show what extra arguments will be passed
if unknown_args:
print(f"Passing extra arguments to test scripts: {unknown_args}")
# 1. Discovery # 1. Discovery
discoverer = TestDiscoverer(args.ops_dir) discoverer = TestDiscoverer(args.ops_dir)
...@@ -28,25 +110,62 @@ def main(): ...@@ -28,25 +110,62 @@ def main():
print("Available operators:", discoverer.get_available_operators()) print("Available operators:", discoverer.get_available_operators())
return return
test_files = discoverer.scan(args.ops) if args.verbose:
print(f"Verbose mode: ENABLED (will stop on first error with full traceback)")
if args.bench:
bench_mode = args.bench if args.bench != "both" else "both"
print(f"Benchmark mode: {bench_mode.upper()} timing")
target_ops = None
if args.ops:
# Get all available operator names
available_ops = set(discoverer.get_available_operators())
requested_ops = set(args.ops)
# Classify using set operations
valid_ops = list(requested_ops & available_ops) # Intersection: Valid ops
invalid_ops = list(requested_ops - available_ops) # Difference: Invalid ops
# Warn if there are invalid operators
if invalid_ops:
print(f"⚠️ Warning: The following requested operators were not found:")
print(f" {', '.join(invalid_ops)}")
print(f" (Use --list to see available operators)")
if not valid_ops:
# Case A: User input provided, but ALL were invalid.
print(f"⚠️ No valid operators remained from your list.")
print(f"🔄 Fallback: Proceeding to run ALL available tests...")
target_ops = None
else:
# Case B: At least some valid operators found.
print(f"🎯 Targeted operators: {', '.join(valid_ops)}")
target_ops = valid_ops
target_ops = valid_ops
test_files = discoverer.scan(target_ops)
if not test_files: if not test_files:
print("No tests found.") print("No tests found.")
sys.exit(0) sys.exit(0)
# 2. Preparation # 2. Preparation
executor = SingleTestExecutor() dirver = TestDriver()
cumulative_timing = TestTiming() cumulative_timing = TestTiming()
printer = ConsolePrinter()
results = [] results = []
TestReporter.print_header(discoverer.ops_dir, len(test_files)) printer.print_header(discoverer.ops_dir, len(test_files))
# 3. Execution Loop # 3. Execution Loop
for f in test_files: for f in test_files:
result = executor.run(f) result = dirver.drive(f)
results.append(result) results.append(result)
# Real-time reporting and printing of stdout # Real-time reporting and printing of stdout
TestReporter.print_live_result(result, verbose=args.verbose) printer.print_live_result(result, verbose=args.verbose)
# Accumulate timing # Accumulate timing
if result.success: if result.success:
...@@ -61,10 +180,12 @@ def main(): ...@@ -61,10 +180,12 @@ def main():
break break
# 4. Final Report & Save # 4. Final Report & Save
all_passed = TestReporter.print_summary( all_passed = printer.print_summary(
results, results,
cumulative_timing if args.bench else None, cumulative_timing if args.bench else None,
total_expected=len(test_files) ops_dir=discoverer.ops_dir,
total_expected=len(test_files),
verbose=args.verbose
) )
sys.exit(0 if all_passed else 1) sys.exit(0 if all_passed else 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment