Commit 1e6ccdc9 authored by wooway777's avatar wooway777
Browse files

issue/598 - optimize run.py performance

parent 5c88cbbd
...@@ -182,9 +182,9 @@ pip install . -e ...@@ -182,9 +182,9 @@ pip install . -e
```bash ```bash
# 测试单算子 # 测试单算子
python test/infinicore/ops/[operator].py [--bench | --debug] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun | --Hygon] python test/infinicore/ops/[operator].py [--bench | --debug | --verbose] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun | --Hygon]
# 测试全部算子 # 测试全部算子
python test/infinicore/run.py [--bench | --debug] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun] python test/infinicore/run.py [--bench | --debug | --verbose] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun]
``` ```
使用 -h 查看更多参数。 使用 -h 查看更多参数。
......
import torch import torch
import infinicore import infinicore
import traceback import traceback
from dataclasses import dataclass
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional, Tuple
from .datatypes import to_torch_dtype, to_infinicore_dtype from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map from .devices import InfiniDeviceNames, torch_device_map
...@@ -15,6 +15,18 @@ from .utils import ( ...@@ -15,6 +15,18 @@ from .utils import (
) )
@dataclass
class TestResult:
"""Test result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_time: float = 0.0
infini_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
class TestCase: class TestCase:
"""Test case with all configuration included""" """Test case with all configuration included"""
...@@ -23,11 +35,11 @@ class TestCase: ...@@ -23,11 +35,11 @@ class TestCase:
inputs, inputs,
kwargs=None, kwargs=None,
output_spec=None, output_spec=None,
output_specs=None,
comparison_target=None, comparison_target=None,
description="", description="",
tolerance=None, tolerance=None,
output_count=1, output_count=1,
output_specs=None,
): ):
""" """
Initialize a test case with complete configuration Initialize a test case with complete configuration
...@@ -248,6 +260,8 @@ class TestRunner: ...@@ -248,6 +260,8 @@ class TestRunner:
"infinicore_total": 0.0, "infinicore_total": 0.0,
"per_test_case": {}, # Store timing per test case "per_test_case": {}, # Store timing per test case
} }
# Store test results
self.test_results = []
def run_tests(self, devices, test_func, test_type="Test"): def run_tests(self, devices, test_func, test_type="Test"):
""" """
...@@ -270,25 +284,30 @@ class TestRunner: ...@@ -270,25 +284,30 @@ class TestRunner:
try: try:
print(f"{test_case}") print(f"{test_case}")
# Execute test and get result status # Execute test and get TestResult object
success, status = test_func(device, test_case, self.config) test_result = test_func(device, test_case, self.config)
self.test_results.append(test_result)
# Handle different test statuses # Handle different test statuses based on return_code
if status == "passed": if test_result.return_code == 0: # Success
self.passed_tests.append( self.passed_tests.append(
f"{test_case} - {InfiniDeviceNames[device]}" f"{test_case} - {InfiniDeviceNames[device]}"
) )
print(f"\033[92m✓\033[0m Passed") print(f"\033[92m✓\033[0m Passed")
elif status == "skipped": elif test_result.return_code == -1:
# Test was skipped due to both operators not being implemented fail_msg = f"{test_case} - {InfiniDeviceNames[device]} - Test terminated in verbose mode."
self.failed_tests.append(fail_msg)
elif test_result.return_code == -2: # Skipped
skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented" skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented"
self.skipped_tests.append(skip_msg) self.skipped_tests.append(skip_msg)
elif status == "partial": print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped")
# Test was partially executed (one operator not implemented) elif test_result.return_code == -3: # Partial
partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented" partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented"
self.partial_tests.append(partial_msg) self.partial_tests.append(partial_msg)
print(f"\033[93m⚠\033[0m One operator not implemented - running single operator without comparison")
# Failed tests are handled in the exception handler below if self.config.verbose and test_result.return_code != 0:
return False
except Exception as e: except Exception as e:
error_msg = ( error_msg = (
...@@ -296,7 +315,16 @@ class TestRunner: ...@@ -296,7 +315,16 @@ class TestRunner:
) )
print(f"\033[91m✗\033[0m {error_msg}") print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg) self.failed_tests.append(error_msg)
# Create a failed TestResult
failed_result = TestResult(
success=False,
return_code=-1,
error_message=str(e),
test_case=test_case,
device=device
)
self.test_results.append(failed_result)
# In verbose mode, print full traceback and stop execution # In verbose mode, print full traceback and stop execution
if self.config.verbose: if self.config.verbose:
traceback.print_exc() traceback.print_exc()
...@@ -305,8 +333,7 @@ class TestRunner: ...@@ -305,8 +333,7 @@ class TestRunner:
if self.config.debug: if self.config.debug:
raise raise
# Return True if no tests failed (skipped/partial tests don't count as failures) return len(self.failed_tests) == 0 and len(self.skipped_tests) == 0 and len(self.partial_tests) == 0
return len(self.failed_tests) == 0
def print_summary(self): def print_summary(self):
""" """
...@@ -377,6 +404,10 @@ class TestRunner: ...@@ -377,6 +404,10 @@ class TestRunner:
) )
print(f"Speedup (PyTorch/InfiniCore): {speedup:.2f}x") print(f"Speedup (PyTorch/InfiniCore): {speedup:.2f}x")
def get_test_results(self):
"""Get all test results"""
return self.test_results
class BaseOperatorTest(ABC): class BaseOperatorTest(ABC):
"""Base operator test""" """Base operator test"""
...@@ -480,11 +511,17 @@ class BaseOperatorTest(ABC): ...@@ -480,11 +511,17 @@ class BaseOperatorTest(ABC):
config: Test configuration config: Test configuration
Returns: Returns:
tuple: (success, status) where: TestResult: Test result object containing status and timing information
success: bool indicating if test passed
status: str describing test status ("passed", "skipped", "partial")
""" """
device_str = torch_device_map[device] device_str = torch_device_map[device]
# Initialize test result
test_result = TestResult(
success=False,
return_code=-1, # Default to failure
test_case=test_case,
device=device
)
# Prepare inputs and kwargs with actual tensors # Prepare inputs and kwargs with actual tensors
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device) inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
...@@ -559,7 +596,10 @@ class BaseOperatorTest(ABC): ...@@ -559,7 +596,10 @@ class BaseOperatorTest(ABC):
except NotImplementedError: except NotImplementedError:
if config.verbose: if config.verbose:
traceback.print_exc() traceback.print_exc()
return False # Stop test execution immediately # Return test result immediately in verbose mode
test_result.return_code = -1
test_result.error_message = "torch_operator not implemented"
return test_result
torch_implemented = False torch_implemented = False
torch_result = None torch_result = None
...@@ -570,26 +610,24 @@ class BaseOperatorTest(ABC): ...@@ -570,26 +610,24 @@ class BaseOperatorTest(ABC):
except NotImplementedError: except NotImplementedError:
if config.verbose: if config.verbose:
traceback.print_exc() traceback.print_exc()
return False # Stop test execution immediately # Return test result immediately in verbose mode
test_result.return_code = -1
test_result.error_message = "infinicore_operator not implemented"
return test_result
infini_implemented = False infini_implemented = False
infini_result = None infini_result = None
# Skip if neither operator is implemented # Skip if neither operator is implemented
if not torch_implemented and not infini_implemented: if not torch_implemented and not infini_implemented:
print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped") test_result.return_code = -2 # Skipped
return False, "skipped" return test_result
# Single operator execution without comparison # Single operator execution without comparison
if not torch_implemented or not infini_implemented: if not torch_implemented or not infini_implemented:
missing_op = ( test_result.return_code = -3 # Partial
"torch_operator" if not torch_implemented else "infinicore_operator" # Run benchmarking for partial tests if enabled
)
print(
f"\033[93m⚠\033[0m {missing_op} not implemented - running single operator without comparison"
)
if config.bench: if config.bench:
self._run_benchmarking( torch_time, infini_time = self._run_benchmarking(
config, config,
device_str, device_str,
torch_implemented, torch_implemented,
...@@ -601,8 +639,9 @@ class BaseOperatorTest(ABC): ...@@ -601,8 +639,9 @@ class BaseOperatorTest(ABC):
test_case.output_count, test_case.output_count,
comparison_target, comparison_target,
) )
return False, "partial" test_result.torch_time = torch_time
test_result.infini_time = infini_time
return test_result
# ========================================================================== # ==========================================================================
# MULTIPLE OUTPUTS COMPARISON LOGIC # MULTIPLE OUTPUTS COMPARISON LOGIC
# ========================================================================== # ==========================================================================
...@@ -711,7 +750,7 @@ class BaseOperatorTest(ABC): ...@@ -711,7 +750,7 @@ class BaseOperatorTest(ABC):
# UNIFIED BENCHMARKING LOGIC # UNIFIED BENCHMARKING LOGIC
# ========================================================================== # ==========================================================================
if config.bench: if config.bench:
self._run_benchmarking( torch_time, infini_time = self._run_benchmarking(
config, config,
device_str, device_str,
True, True,
...@@ -723,9 +762,13 @@ class BaseOperatorTest(ABC): ...@@ -723,9 +762,13 @@ class BaseOperatorTest(ABC):
test_case.output_count, test_case.output_count,
comparison_target, comparison_target,
) )
test_result.torch_time = torch_time
test_result.infini_time = infini_time
# Test passed successfully # Test passed successfully
return True, "passed" test_result.success = True
test_result.return_code = 0
return test_result
def _run_benchmarking( def _run_benchmarking(
self, self,
...@@ -742,8 +785,10 @@ class BaseOperatorTest(ABC): ...@@ -742,8 +785,10 @@ class BaseOperatorTest(ABC):
): ):
""" """
Unified benchmarking logic with timing accumulation Unified benchmarking logic with timing accumulation
"""
Returns:
tuple: (torch_time, infini_time) timing results
"""
# Initialize timing variables # Initialize timing variables
torch_time = 0.0 torch_time = 0.0
infini_time = 0.0 infini_time = 0.0
...@@ -809,3 +854,5 @@ class BaseOperatorTest(ABC): ...@@ -809,3 +854,5 @@ class BaseOperatorTest(ABC):
# Accumulate total times # Accumulate total times
config._test_runner.benchmark_times["torch_total"] += torch_time config._test_runner.benchmark_times["torch_total"] += torch_time
config._test_runner.benchmark_times["infinicore_total"] += infini_time config._test_runner.benchmark_times["infinicore_total"] += infini_time
return torch_time, infini_time
...@@ -100,8 +100,9 @@ Examples: ...@@ -100,8 +100,9 @@ Examples:
# Device options using shared hardware info # Device options using shared hardware info
hardware_group = get_hardware_args_group(parser) hardware_group = get_hardware_args_group(parser)
args, unknown = parser.parse_known_args()
return parser.parse_args() return args
def get_test_devices(args): def get_test_devices(args):
......
...@@ -21,7 +21,9 @@ class GenericTestRunner: ...@@ -21,7 +21,9 @@ class GenericTestRunner:
"""Execute the complete test suite """Execute the complete test suite
Returns: Returns:
bool: True if all tests passed or were skipped/partial, False if any tests failed tuple: (success, test_runner) where:
success: bool indicating if all tests passed or were skipped/partial
test_runner: TestRunner instance with test results
""" """
config = TestConfig( config = TestConfig(
debug=self.args.debug, debug=self.args.debug,
...@@ -51,7 +53,7 @@ class GenericTestRunner: ...@@ -51,7 +53,7 @@ class GenericTestRunner:
# Both conditions must be True for overall success # Both conditions must be True for overall success
# - has_no_failures: no test failures during execution # - has_no_failures: no test failures during execution
# - summary_passed: summary confirms no failures # - summary_passed: summary confirms no failures
return has_no_failures and summary_passed return (has_no_failures and summary_passed), runner
def run_and_exit(self): def run_and_exit(self):
"""Run tests and exit with appropriate status code """Run tests and exit with appropriate status code
...@@ -60,5 +62,5 @@ class GenericTestRunner: ...@@ -60,5 +62,5 @@ class GenericTestRunner:
0: All tests passed or were skipped/partial (no failures) 0: All tests passed or were skipped/partial (no failures)
1: One or more tests failed 1: One or more tests failed
""" """
success = self.run() success, runner = self.run()
sys.exit(0 if success else 1) sys.exit(0 if success else 1)
...@@ -133,9 +133,9 @@ class OpTest(BaseOperatorTest): ...@@ -133,9 +133,9 @@ class OpTest(BaseOperatorTest):
"""PyTorch ELU implementation""" """PyTorch ELU implementation"""
return torch.nn.functional.elu(*args, **kwargs) return torch.nn.functional.elu(*args, **kwargs)
def infinicore_operator(self, x, alpha=1.0, out=None, **kwargs): # def infinicore_operator(self, x, alpha=1.0, out=None, **kwargs):
"""InfiniCore ELU implementation""" # """InfiniCore ELU implementation"""
return None # return None
def main(): def main():
......
...@@ -103,7 +103,7 @@ def parse_test_cases(): ...@@ -103,7 +103,7 @@ def parse_test_cases():
return test_cases return test_cases
class MultiMarginLossOpTest(BaseOperatorTest): class OpTest(BaseOperatorTest):
"""MultiMarginLoss operator test with device handling""" """MultiMarginLoss operator test with device handling"""
def __init__(self): def __init__(self):
...@@ -116,9 +116,9 @@ class MultiMarginLossOpTest(BaseOperatorTest): ...@@ -116,9 +116,9 @@ class MultiMarginLossOpTest(BaseOperatorTest):
"""PyTorch multi_margin_loss implementation with device handling""" """PyTorch multi_margin_loss implementation with device handling"""
return F.multi_margin_loss(*args, **kwargs) return F.multi_margin_loss(*args, **kwargs)
def infinicore_operator(self, *args, **kwargs): # def infinicore_operator(self, *args, **kwargs):
"""InfiniCore multi_margin_loss implementation""" # """InfiniCore multi_margin_loss implementation"""
return None # return None
def main(): def main():
......
import os import os
import sys import sys
import subprocess
import argparse import argparse
from pathlib import Path from pathlib import Path
from typing import Dict, Tuple, List import importlib.util
from framework import get_hardware_args_group
def find_ops_directory(location=None): def find_ops_directory(location=None):
...@@ -58,9 +59,59 @@ def get_available_operators(ops_dir): ...@@ -58,9 +59,59 @@ def get_available_operators(ops_dir):
return sorted(operators) return sorted(operators)
def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None): def import_operator_test(test_file_path):
""" """
Run all operator test scripts in the ops directory. Import an operator test module and return the test class instance.
Args:
test_file_path: Path to the test file
Returns:
tuple: (success, test_instance_or_error)
"""
try:
# Create a unique module name
module_name = f"op_test_{test_file_path.stem}"
# Load the module from file
spec = importlib.util.spec_from_file_location(module_name, test_file_path)
if spec is None or spec.loader is None:
return False, f"Could not load module from {test_file_path}"
module = importlib.util.module_from_spec(spec)
# Add the module to sys.modules
sys.modules[module_name] = module
# Execute the module
spec.loader.exec_module(module)
# Find the test class (usually named OpTest)
test_class = None
for attr_name in dir(module):
attr = getattr(module, attr_name)
if (
isinstance(attr, type)
and hasattr(attr, "__bases__")
and any("BaseOperatorTest" in str(base) for base in attr.__bases__)
):
test_class = attr
break
if test_class is None:
return False, f"No test class found in {test_file_path}"
# Create an instance
test_instance = test_class()
return True, test_instance
except Exception as e:
return False, f"Error importing {test_file_path}: {str(e)}"
def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False):
"""
Run all operator test scripts in the ops directory using direct import.
Args: Args:
ops_dir (str, optional): Path to the ops directory. If None, uses auto-detection. ops_dir (str, optional): Path to the ops directory. If None, uses auto-detection.
...@@ -68,7 +119,7 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None): ...@@ -68,7 +119,7 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
extra_args (list, optional): Extra command line arguments to pass to test scripts. extra_args (list, optional): Extra command line arguments to pass to test scripts.
Returns: Returns:
dict: Results dictionary with test names as keys and (success, return_code, stdout, stderr) as values. dict: Results dictionary with test names as keys and (success, test_runner, stdout, stderr) as values.
""" """
if ops_dir is None: if ops_dir is None:
ops_dir = find_ops_directory() ops_dir = find_ops_directory()
...@@ -122,11 +173,6 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None): ...@@ -122,11 +173,6 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
results = {} results = {}
# Check if verbose mode is enabled
verbose_mode = extra_args and "--verbose" in extra_args
# Check if bench mode is enabled for cumulative timing
bench_mode = extra_args and "--bench" in extra_args
cumulative_timing = { cumulative_timing = {
"total_torch_time": 0.0, "total_torch_time": 0.0,
"total_infinicore_time": 0.0, "total_infinicore_time": 0.0,
...@@ -137,117 +183,160 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None): ...@@ -137,117 +183,160 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
test_name = test_file.stem test_name = test_file.stem
try: try:
# Run the test script - use the absolute path and run from current directory # Import and run the test directly
cmd = [sys.executable, str(test_file.absolute())] success, test_instance_or_error = import_operator_test(test_file)
if not success:
print(f"💥 {test_name}: ERROR - {test_instance_or_error}")
results[test_name] = {
"success": False,
"return_code": -1,
"torch_time": 0.0,
"infini_time": 0.0,
"error_message": test_instance_or_error,
"test_runner": None,
"stdout": "",
"stderr": test_instance_or_error,
}
continue
# Get the test runner class from the module
test_module = sys.modules[f"op_test_{test_file.stem}"]
if not hasattr(test_module, "GenericTestRunner"):
print(f"💥 {test_name}: ERROR - No GenericTestRunner found")
results[test_name] = {
"success": False,
"return_code": -1,
"torch_time": 0.0,
"infini_time": 0.0,
"error_message": "No GenericTestRunner found",
"test_runner": None,
"stdout": "",
"stderr": "No GenericTestRunner found",
}
continue
# Create and run the test runner
test_runner_class = test_module.GenericTestRunner
runner_instance = test_runner_class(test_instance_or_error.__class__)
# Temporarily redirect stdout to capture output
from io import StringIO
stdout_capture = StringIO()
stderr_capture = StringIO()
old_stdout = sys.stdout
old_stderr = sys.stderr
sys.stdout = stdout_capture
sys.stderr = stderr_capture
try:
# Run the test
test_success, test_runner = runner_instance.run()
# Get captured output
stdout_output = stdout_capture.getvalue()
stderr_output = stderr_capture.getvalue()
# Restore stdout/stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
# Print the captured output
if stdout_output:
print(stdout_output.rstrip())
if stderr_output:
print("\nSTDERR:")
print(stderr_output.rstrip())
# Analyze test results
test_results = test_runner.get_test_results() if test_runner else []
# Determine overall test status
if test_success:
return_code = 0
status_icon = "✅"
status_text = "PASSED"
else:
# Check if there are any failed tests
has_failures = any(
result.return_code == -1 for result in test_results
)
has_partial = any(
result.return_code == -3 for result in test_results
)
has_skipped = any(
result.return_code == -2 for result in test_results
)
if has_failures:
return_code = -1
status_icon = "❌"
status_text = "FAILED"
elif has_partial:
return_code = -3
status_icon = "⚠️"
status_text = "PARTIAL"
elif has_skipped:
return_code = -2
status_icon = "⏭️"
status_text = "SKIPPED"
else:
return_code = -1
status_icon = "❌"
status_text = "FAILED"
# Calculate timing
torch_time = sum(result.torch_time for result in test_results)
infini_time = sum(result.infini_time for result in test_results)
results[test_name] = {
"success": test_success,
"return_code": return_code,
"torch_time": torch_time,
"infini_time": infini_time,
"error_message": "",
"test_runner": test_runner,
"stdout": stdout_output,
"stderr": stderr_output,
}
# Add extra arguments if provided print(
if extra_args: f"{status_icon} {test_name}: {status_text} (return code: {return_code})"
cmd.extend(extra_args) )
result = subprocess.run(
cmd,
capture_output=True, # Capture output to analyze
text=True,
)
# Analyze output to determine test status
stdout_lower = result.stdout.lower()
stderr_lower = result.stderr.lower()
# Check for operator not implemented patterns
if (
"all tests passed!" in stdout_lower
and "success rate: 100.0%" in stdout_lower
):
success = True
returncode = 0
elif "both operators not implemented" in stdout_lower:
# Both operators not implemented - skipped test
success = False # Not a failure, but skipped
returncode = -2 # Special code for skipped
elif "operator not implemented" in stdout_lower:
# One operator not implemented - partial test
success = False # Not fully successful
returncode = -3 # Special code for partial
else:
success = False
returncode = -1
results[test_name] = (
success,
returncode,
result.stdout,
result.stderr,
)
# Print the output from the test script
print(f"\n{'='*60}")
print(f"TEST: {test_name}")
print(f"{'='*60}")
if result.stdout:
print(result.stdout.rstrip())
if result.stderr:
print("\nSTDERR:")
print(result.stderr.rstrip())
# Enhanced status display
if returncode == -2:
status_icon = "⏭️"
status_text = "SKIPPED"
elif returncode == -3:
status_icon = "⚠️"
status_text = "PARTIAL"
elif success:
status_icon = "✅"
status_text = "PASSED"
else:
status_icon = "❌"
status_text = "FAILED"
print( # Extract benchmark timing if in bench mode
f"{status_icon} {test_name}: {status_text} (return code: {returncode})" if bench and test_success and return_code == 0:
) cumulative_timing["total_torch_time"] += torch_time
cumulative_timing["total_infinicore_time"] += infini_time
cumulative_timing["operators_tested"] += 1
# Extract benchmark timing if in bench mode except Exception as e:
if bench_mode and success: # Restore stdout/stderr in case of exception
# Look for benchmark summary in stdout sys.stdout = old_stdout
lines = result.stdout.split("\n") sys.stderr = old_stderr
torch_time = 0.0 raise e
infini_time = 0.0
for line in lines:
if "PyTorch Total Time:" in line:
try:
# Extract time value (e.g., "PyTorch Total Time: 123.456 ms")
torch_time = (
float(line.split(":")[1].strip().split()[0]) / 1000.0
) # Convert to seconds
except:
pass
elif "InfiniCore Total Time:" in line:
try:
infini_time = (
float(line.split(":")[1].strip().split()[0]) / 1000.0
) # Convert to seconds
except:
pass
cumulative_timing["total_torch_time"] += torch_time
cumulative_timing["total_infinicore_time"] += infini_time
cumulative_timing["operators_tested"] += 1
# In verbose mode, stop execution on first failure # In verbose mode, stop execution on first failure
if verbose_mode and not success and returncode not in [-2, -3]: if verbose and not test_success and return_code != 0:
break break
except Exception as e: except Exception as e:
print(f"💥 {test_name}: ERROR - {str(e)}") print(f"💥 {test_name}: ERROR - {str(e)}")
results[test_name] = (False, -1, "", str(e)) results[test_name] = {
"success": False,
"return_code": -1,
"torch_time": 0.0,
"infini_time": 0.0,
"error_message": str(e),
"test_runner": None,
"stdout": "",
"stderr": str(e),
}
# In verbose mode, stop execution on any exception # In verbose mode, stop execution on any exception
if verbose_mode: if verbose:
print(f"\n{'!'*60}") print(f"\n{'!'*60}")
print( print(
f"VERBOSE MODE: Stopping execution due to exception in {test_name}" f"VERBOSE MODE: Stopping execution due to exception in {test_name}"
...@@ -259,7 +348,7 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None): ...@@ -259,7 +348,7 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
def print_summary( def print_summary(
results, verbose_mode=False, total_expected_tests=0, cumulative_timing=None results, verbose=False, total_expected_tests=0, cumulative_timing=None
): ):
"""Print a comprehensive summary of test results including benchmark data.""" """Print a comprehensive summary of test results including benchmark data."""
print(f"\n{'='*80}") print(f"\n{'='*80}")
...@@ -280,14 +369,15 @@ def print_summary( ...@@ -280,14 +369,15 @@ def print_summary(
skipped_operators = [] # Store skipped operator names skipped_operators = [] # Store skipped operator names
partial_operators = [] # Store partial operator names partial_operators = [] # Store partial operator names
for test_name, (success, returncode, stdout, stderr) in results.items(): for test_name, result_data in results.items():
if success: return_code = result_data["return_code"]
if return_code == 0:
passed += 1 passed += 1
passed_operators.append(test_name) passed_operators.append(test_name)
elif returncode == -2: # Special code for skipped tests elif return_code == -2: # Special code for skipped tests
skipped += 1 skipped += 1
skipped_operators.append(test_name) skipped_operators.append(test_name)
elif returncode == -3: # Special code for partial tests elif return_code == -3: # Special code for partial tests
partial += 1 partial += 1
partial_operators.append(test_name) partial_operators.append(test_name)
else: else:
...@@ -316,10 +406,10 @@ def print_summary( ...@@ -316,10 +406,10 @@ def print_summary(
print("BENCHMARK SUMMARY:") print("BENCHMARK SUMMARY:")
print(f" Operators Tested: {cumulative_timing['operators_tested']}") print(f" Operators Tested: {cumulative_timing['operators_tested']}")
print( print(
f" Total PyTorch Time: {cumulative_timing['total_torch_time'] * 1000:.3f} ms" f" PyTorch Total Time: {cumulative_timing['total_torch_time'] * 1000:12.3f} ms"
) )
print( print(
f" Total InfiniCore Time: {cumulative_timing['total_infinicore_time'] * 1000:.3f} ms" f" InfiniCore Total Time: {cumulative_timing['total_infinicore_time'] * 1000:12.3f} ms"
) )
print(f"{'-'*40}") print(f"{'-'*40}")
...@@ -361,7 +451,7 @@ def print_summary( ...@@ -361,7 +451,7 @@ def print_summary(
success_rate = passed / executed_tests * 100 success_rate = passed / executed_tests * 100
print(f"\nSuccess rate: {success_rate:.1f}%") print(f"\nSuccess rate: {success_rate:.1f}%")
if verbose_mode and total < total_expected_tests: if verbose and total < total_expected_tests:
print(f"\n💡 Verbose mode: Execution stopped after first failure") print(f"\n💡 Verbose mode: Execution stopped after first failure")
print(f" {total_expected_tests - total} tests were not executed") print(f" {total_expected_tests - total} tests were not executed")
...@@ -505,17 +595,18 @@ def main(): ...@@ -505,17 +595,18 @@ def main():
parser.add_argument( parser.add_argument(
"--verbose", "--verbose",
action="store_true", action="store_true",
help="Enable verbose mode to stop on first error with full traceback (passed to individual tests)", help="Enable verbose mode to stop on first error with full traceback",
)
parser.add_argument(
"--bench",
action="store_true",
help="Enable bench mode to show performance data",
) )
from framework import get_hardware_args_group get_hardware_args_group(parser)
if "-h" in sys.argv or "--help" in sys.argv:
get_hardware_args_group(parser)
# Parse known args first, leave the rest for the test scripts # Parse known args first, leave the rest for the test scripts
args, unknown_args = parser.parse_known_args() args, unknown_args = parser.parse_known_args()
get_hardware_args_group(parser)
# Handle list command # Handle list command
if args.list: if args.list:
...@@ -536,10 +627,6 @@ def main(): ...@@ -536,10 +627,6 @@ def main():
print(f"Error: Ops directory '{ops_dir}' does not exist.") print(f"Error: Ops directory '{ops_dir}' does not exist.")
sys.exit(1) sys.exit(1)
# Add verbose flag to extra arguments if specified
if args.verbose and "--verbose" not in unknown_args:
unknown_args.append("--verbose")
# Show what extra arguments will be passed # Show what extra arguments will be passed
if unknown_args: if unknown_args:
print(f"Passing extra arguments to test scripts: {unknown_args}") print(f"Passing extra arguments to test scripts: {unknown_args}")
...@@ -584,7 +671,8 @@ def main(): ...@@ -584,7 +671,8 @@ def main():
results, cumulative_timing = run_all_op_tests( results, cumulative_timing = run_all_op_tests(
ops_dir=ops_dir, ops_dir=ops_dir,
specific_ops=args.ops, specific_ops=args.ops,
extra_args=unknown_args, bench=args.bench,
verbose=args.verbose,
) )
# Print summary and exit with appropriate code # Print summary and exit with appropriate code
...@@ -594,7 +682,7 @@ def main(): ...@@ -594,7 +682,7 @@ def main():
# Check if there were any tests with missing implementations # Check if there were any tests with missing implementations
has_missing_implementations = any( has_missing_implementations = any(
returncode in [-2, -3] for _, (_, returncode, _, _) in results.items() result_data["return_code"] in [-2, -3] for result_data in results.values()
) )
if all_passed and has_missing_implementations: if all_passed and has_missing_implementations:
...@@ -607,8 +695,8 @@ def main(): ...@@ -607,8 +695,8 @@ def main():
) )
failed_ops = [ failed_ops = [
name name
for name, (success, _, _, _) in results.items() for name, result_data in results.items()
if not success and name in results if result_data["return_code"] == -1
] ]
for op in failed_ops[:3]: # Show first 3 failed operators for op in failed_ops[:3]: # Show first 3 failed operators
print(f" python {ops_dir / (op + '.py')} --verbose") print(f" python {ops_dir / (op + '.py')} --verbose")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment