Unverified Commit b7d9252b authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #595 from InfiniTensor/issue/593

Issue/593 & Issue/594 & Issue/598
parents 9b0b89c5 1e6ccdc9
...@@ -182,9 +182,9 @@ pip install . -e ...@@ -182,9 +182,9 @@ pip install . -e
```bash ```bash
# 测试单算子 # 测试单算子
python test/infinicore/ops/[operator].py [--bench | --debug] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun | --Hygon] python test/infinicore/ops/[operator].py [--bench | --debug | --verbose] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun | --Hygon]
# 测试全部算子 # 测试全部算子
python test/infinicore/run.py [--bench | --debug] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun] python test/infinicore/run.py [--bench | --debug | --verbose] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun]
``` ```
使用 -h 查看更多参数。 使用 -h 查看更多参数。
......
import torch import torch
import infinicore import infinicore
import traceback
from dataclasses import dataclass
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional, Tuple
from .datatypes import to_torch_dtype, to_infinicore_dtype from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map from .devices import InfiniDeviceNames, torch_device_map
...@@ -11,11 +12,21 @@ from .utils import ( ...@@ -11,11 +12,21 @@ from .utils import (
create_test_comparator, create_test_comparator,
infinicore_tensor_from_torch, infinicore_tensor_from_torch,
profile_operation, profile_operation,
synchronize_device,
convert_infinicore_to_torch,
) )
@dataclass
class TestResult:
    """Record describing the outcome of one executed test.

    The first two fields are required; every other field has a default so a
    result can be created early and filled in as the test progresses.
    """

    # Whether the test is considered successful.
    success: bool
    # Numeric status: 0 = success, -1 = failure, -2 = skipped, -3 = partial.
    return_code: int
    # Benchmark timing recorded for the PyTorch operator (0.0 when not measured).
    torch_time: float = 0.0
    # Benchmark timing recorded for the InfiniCore operator (0.0 when not measured).
    infini_time: float = 0.0
    # Human-readable reason for a failure; empty when there is none.
    error_message: str = ""
    # The test case this result belongs to, if known.
    test_case: Any = None
    # The device the test ran on, if known.
    device: Any = None
class TestCase: class TestCase:
"""Test case with all configuration included""" """Test case with all configuration included"""
...@@ -24,11 +35,11 @@ class TestCase: ...@@ -24,11 +35,11 @@ class TestCase:
inputs, inputs,
kwargs=None, kwargs=None,
output_spec=None, output_spec=None,
output_specs=None,
comparison_target=None, comparison_target=None,
description="", description="",
tolerance=None, tolerance=None,
output_count=1, output_count=1,
output_specs=None,
): ):
""" """
Initialize a test case with complete configuration Initialize a test case with complete configuration
...@@ -216,14 +227,19 @@ class TestCase: ...@@ -216,14 +227,19 @@ class TestCase:
class TestConfig: class TestConfig:
"""Test configuration""" """Test configuration"""
def __init__(self, debug=False, bench=False, num_prerun=10, num_iterations=1000): def __init__(
self,
debug=False,
bench=False,
num_prerun=10,
num_iterations=1000,
verbose=False,
):
self.debug = debug self.debug = debug
self.bench = bench self.bench = bench
self.num_prerun = num_prerun self.num_prerun = num_prerun
self.num_iterations = num_iterations self.num_iterations = num_iterations
self.verbose = verbose
# In base.py - update the TestRunner class
class TestRunner: class TestRunner:
...@@ -238,6 +254,14 @@ class TestRunner: ...@@ -238,6 +254,14 @@ class TestRunner:
self.passed_tests = ( self.passed_tests = (
[] []
) # Track passed tests (both operators implemented and passed) ) # Track passed tests (both operators implemented and passed)
# Add benchmark timing statistics
self.benchmark_times = {
"torch_total": 0.0,
"infinicore_total": 0.0,
"per_test_case": {}, # Store timing per test case
}
# Store test results
self.test_results = []
def run_tests(self, devices, test_func, test_type="Test"): def run_tests(self, devices, test_func, test_type="Test"):
""" """
...@@ -260,30 +284,30 @@ class TestRunner: ...@@ -260,30 +284,30 @@ class TestRunner:
try: try:
print(f"{test_case}") print(f"{test_case}")
# Execute test and get result status # Execute test and get TestResult object
success, status = test_func(device, test_case, self.config) test_result = test_func(device, test_case, self.config)
self.test_results.append(test_result)
# Handle different test statuses # Handle different test statuses based on return_code
if status == "passed": if test_result.return_code == 0: # Success
self.passed_tests.append( self.passed_tests.append(
f"{test_case} - {InfiniDeviceNames[device]}" f"{test_case} - {InfiniDeviceNames[device]}"
) )
print(f"\033[92m✓\033[0m Passed") print(f"\033[92m✓\033[0m Passed")
elif status == "skipped": elif test_result.return_code == -1:
# Test was skipped due to both operators not being implemented fail_msg = f"{test_case} - {InfiniDeviceNames[device]} - Test terminated in verbose mode."
self.failed_tests.append(fail_msg)
elif test_result.return_code == -2: # Skipped
skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented" skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented"
self.skipped_tests.append(skip_msg) self.skipped_tests.append(skip_msg)
print( print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped")
f"\033[93m⚠\033[0m Skipped - both operators not implemented" elif test_result.return_code == -3: # Partial
)
elif status == "partial":
# Test was partially executed (one operator not implemented)
partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented" partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented"
self.partial_tests.append(partial_msg) self.partial_tests.append(partial_msg)
print( print(f"\033[93m⚠\033[0m One operator not implemented - running single operator without comparison")
f"\033[93m⚠\033[0m Partial - one operator not implemented"
) if self.config.verbose and test_result.return_code != 0:
# Failed tests are handled in the exception handler below return False
except Exception as e: except Exception as e:
error_msg = ( error_msg = (
...@@ -291,11 +315,25 @@ class TestRunner: ...@@ -291,11 +315,25 @@ class TestRunner:
) )
print(f"\033[91m✗\033[0m {error_msg}") print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg) self.failed_tests.append(error_msg)
# Create a failed TestResult
failed_result = TestResult(
success=False,
return_code=-1,
error_message=str(e),
test_case=test_case,
device=device
)
self.test_results.append(failed_result)
# In verbose mode, print full traceback and stop execution
if self.config.verbose:
traceback.print_exc()
return False # Stop test execution immediately
if self.config.debug: if self.config.debug:
raise raise
# Return True if no tests failed (skipped/partial tests don't count as failures) return len(self.failed_tests) == 0 and len(self.skipped_tests) == 0 and len(self.partial_tests) == 0
return len(self.failed_tests) == 0
def print_summary(self): def print_summary(self):
""" """
...@@ -312,34 +350,16 @@ class TestRunner: ...@@ -312,34 +350,16 @@ class TestRunner:
print(f"\n{'='*60}") print(f"\n{'='*60}")
print("TEST SUMMARY") print("TEST SUMMARY")
print(f"{'='*60}")
print(f"Total tests: {total_tests}") print(f"Total tests: {total_tests}")
print(f"\033[92mPassed: {passed_count}\033[0m") print(f"\033[92mPassed: {passed_count}\033[0m")
# Display partial tests (one operator not implemented) result = True
if self.partial_tests:
print(
f"\033[93mPartial (one operator not implemented): {partial_count}\033[0m"
)
for test in self.partial_tests:
print(f" - {test}")
# Display skipped tests (both operators not implemented)
if self.skipped_tests:
print(
f"\033[93mSkipped (both operators not implemented): {skipped_count}\033[0m"
)
for test in self.skipped_tests:
print(f" - {test}")
# Display failed tests # Display failed tests
if self.failed_tests: if self.failed_tests:
print(f"\033[91mFailed: {failed_count}\033[0m") print(f"\033[91mFailed: {failed_count}\033[0m")
for failure in self.failed_tests:
print(f" - {failure}")
# Return False only if there are actual test failures # Return False only if there are actual test failures
return False result = False
else: else:
# Calculate success rate based on actual executed tests # Calculate success rate based on actual executed tests
executed_tests = passed_count + partial_count + failed_count executed_tests = passed_count + partial_count + failed_count
...@@ -352,10 +372,41 @@ class TestRunner: ...@@ -352,10 +372,41 @@ class TestRunner:
print( print(
f"\n\033[93mTests completed with some implementations missing\033[0m" f"\n\033[93mTests completed with some implementations missing\033[0m"
) )
return True # Skipped/partial tests don't count as failures
else: else:
print(f"\n\033[92mAll tests passed!\033[0m") print(f"\n\033[92mAll tests passed!\033[0m")
return True
# Print benchmark summary if benchmarking was enabled
if self.config.bench and (
self.benchmark_times["torch_total"] > 0
or self.benchmark_times["infinicore_total"] > 0
):
self._print_benchmark_summary()
print(f"{'='*60}")
return result
def _print_benchmark_summary(self):
    """Print the accumulated benchmark totals held in ``self.benchmark_times``.

    Reads the ``"torch_total"`` and ``"infinicore_total"`` entries (treated as
    seconds and displayed in milliseconds). A total is printed only when it is
    greater than zero, and the speedup ratio is printed only when both totals
    are greater than zero.
    """
    print(f"{'-'*60}")
    print("BENCHMARK SUMMARY")

    torch_total = self.benchmark_times["torch_total"]
    infinicore_total = self.benchmark_times["infinicore_total"]

    if torch_total > 0:
        print(f"PyTorch Total Time: {torch_total * 1000:.3f} ms")
    if infinicore_total > 0:
        print(f"InfiniCore Total Time: {infinicore_total * 1000:.3f} ms")

    # Both totals are strictly positive here, so the division is always safe;
    # no separate zero-denominator fallback is needed.
    if torch_total > 0 and infinicore_total > 0:
        speedup = torch_total / infinicore_total
        print(f"Speedup (PyTorch/InfiniCore): {speedup:.2f}x")
def get_test_results(self):
    """Return the results recorded in ``self.test_results``.

    The stored object itself is returned, not a copy.
    """
    collected = self.test_results
    return collected
class BaseOperatorTest(ABC): class BaseOperatorTest(ABC):
...@@ -460,11 +511,17 @@ class BaseOperatorTest(ABC): ...@@ -460,11 +511,17 @@ class BaseOperatorTest(ABC):
config: Test configuration config: Test configuration
Returns: Returns:
tuple: (success, status) where: TestResult: Test result object containing status and timing information
success: bool indicating if test passed
status: str describing test status ("passed", "skipped", "partial")
""" """
device_str = torch_device_map[device] device_str = torch_device_map[device]
# Initialize test result
test_result = TestResult(
success=False,
return_code=-1, # Default to failure
test_case=test_case,
device=device
)
# Prepare inputs and kwargs with actual tensors # Prepare inputs and kwargs with actual tensors
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device) inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
...@@ -537,6 +594,12 @@ class BaseOperatorTest(ABC): ...@@ -537,6 +594,12 @@ class BaseOperatorTest(ABC):
if torch_result is None: if torch_result is None:
torch_implemented = False torch_implemented = False
except NotImplementedError: except NotImplementedError:
if config.verbose:
traceback.print_exc()
# Return test result immediately in verbose mode
test_result.return_code = -1
test_result.error_message = "torch_operator not implemented"
return test_result
torch_implemented = False torch_implemented = False
torch_result = None torch_result = None
...@@ -545,25 +608,26 @@ class BaseOperatorTest(ABC): ...@@ -545,25 +608,26 @@ class BaseOperatorTest(ABC):
if infini_result is None: if infini_result is None:
infini_implemented = False infini_implemented = False
except NotImplementedError: except NotImplementedError:
if config.verbose:
traceback.print_exc()
# Return test result immediately in verbose mode
test_result.return_code = -1
test_result.error_message = "infinicore_operator not implemented"
return test_result
infini_implemented = False infini_implemented = False
infini_result = None infini_result = None
# Skip if neither operator is implemented # Skip if neither operator is implemented
if not torch_implemented and not infini_implemented: if not torch_implemented and not infini_implemented:
print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped") test_result.return_code = -2 # Skipped
return False, "skipped" return test_result
# Single operator execution without comparison # Single operator execution without comparison
if not torch_implemented or not infini_implemented: if not torch_implemented or not infini_implemented:
missing_op = ( test_result.return_code = -3 # Partial
"torch_operator" if not torch_implemented else "infinicore_operator" # Run benchmarking for partial tests if enabled
)
print(
f"\033[93m⚠\033[0m {missing_op} not implemented - running single operator without comparison"
)
if config.bench: if config.bench:
self._run_benchmarking( torch_time, infini_time = self._run_benchmarking(
config, config,
device_str, device_str,
torch_implemented, torch_implemented,
...@@ -575,8 +639,9 @@ class BaseOperatorTest(ABC): ...@@ -575,8 +639,9 @@ class BaseOperatorTest(ABC):
test_case.output_count, test_case.output_count,
comparison_target, comparison_target,
) )
return False, "partial" test_result.torch_time = torch_time
test_result.infini_time = infini_time
return test_result
# ========================================================================== # ==========================================================================
# MULTIPLE OUTPUTS COMPARISON LOGIC # MULTIPLE OUTPUTS COMPARISON LOGIC
# ========================================================================== # ==========================================================================
...@@ -685,7 +750,7 @@ class BaseOperatorTest(ABC): ...@@ -685,7 +750,7 @@ class BaseOperatorTest(ABC):
# UNIFIED BENCHMARKING LOGIC # UNIFIED BENCHMARKING LOGIC
# ========================================================================== # ==========================================================================
if config.bench: if config.bench:
self._run_benchmarking( torch_time, infini_time = self._run_benchmarking(
config, config,
device_str, device_str,
True, True,
...@@ -697,9 +762,13 @@ class BaseOperatorTest(ABC): ...@@ -697,9 +762,13 @@ class BaseOperatorTest(ABC):
test_case.output_count, test_case.output_count,
comparison_target, comparison_target,
) )
test_result.torch_time = torch_time
test_result.infini_time = infini_time
# Test passed successfully # Test passed successfully
return True, "passed" test_result.success = True
test_result.return_code = 0
return test_result
def _run_benchmarking( def _run_benchmarking(
self, self,
...@@ -715,8 +784,15 @@ class BaseOperatorTest(ABC): ...@@ -715,8 +784,15 @@ class BaseOperatorTest(ABC):
comparison_target, comparison_target,
): ):
""" """
Unified benchmarking logic Unified benchmarking logic with timing accumulation
Returns:
tuple: (torch_time, infini_time) timing results
""" """
# Initialize timing variables
torch_time = 0.0
infini_time = 0.0
if torch_implemented: if torch_implemented:
if output_count > 1: if output_count > 1:
# For multiple outputs, just call the operator # For multiple outputs, just call the operator
...@@ -739,12 +815,13 @@ class BaseOperatorTest(ABC): ...@@ -739,12 +815,13 @@ class BaseOperatorTest(ABC):
else inputs[comparison_target] else inputs[comparison_target]
) )
profile_operation( torch_time = profile_operation(
"PyTorch ", "PyTorch ",
torch_op, torch_op,
device_str, device_str,
config.num_prerun, config.num_prerun,
config.num_iterations, config.num_iterations,
total=True,
) )
if infini_implemented: if infini_implemented:
...@@ -763,10 +840,19 @@ class BaseOperatorTest(ABC): ...@@ -763,10 +840,19 @@ class BaseOperatorTest(ABC):
else infini_inputs[comparison_target] else infini_inputs[comparison_target]
) )
profile_operation( infini_time = profile_operation(
"InfiniCore", "InfiniCore",
infini_op, infini_op,
device_str, device_str,
config.num_prerun, config.num_prerun,
config.num_iterations, config.num_iterations,
total=True,
) )
# Store timing information in the test runner
if hasattr(config, "_test_runner") and config._test_runner:
# Accumulate total times
config._test_runner.benchmark_times["torch_total"] += torch_time
config._test_runner.benchmark_times["infinicore_total"] += infini_time
return torch_time, infini_time
import argparse import argparse
from .devices import InfiniDeviceEnum from .devices import InfiniDeviceEnum
# hardware_info.py
""" """
Shared hardware platform information for the InfiniCore testing framework Shared hardware platform information for the InfiniCore testing framework
""" """
...@@ -61,6 +60,9 @@ Examples: ...@@ -61,6 +60,9 @@ Examples:
# Run with debug mode on multiple devices # Run with debug mode on multiple devices
python test_operator.py --cpu --nvidia --debug python test_operator.py --cpu --nvidia --debug
# Run with verbose mode to stop on first error with full traceback
python test_operator.py --cpu --nvidia --verbose
# Run performance profiling with custom iterations # Run performance profiling with custom iterations
python test_operator.py --nvidia --bench --num_prerun 50 --num_iterations 5000 python test_operator.py --nvidia --bench --num_prerun 50 --num_iterations 5000
...@@ -90,11 +92,17 @@ Examples: ...@@ -90,11 +92,17 @@ Examples:
action="store_true", action="store_true",
help="Enable debug mode for detailed tensor comparison", help="Enable debug mode for detailed tensor comparison",
) )
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose mode to stop on first error with full traceback",
)
# Device options using shared hardware info # Device options using shared hardware info
hardware_group = get_hardware_args_group(parser) hardware_group = get_hardware_args_group(parser)
args, unknown = parser.parse_known_args()
return parser.parse_args() return args
def get_test_devices(args): def get_test_devices(args):
......
...@@ -21,16 +21,23 @@ class GenericTestRunner: ...@@ -21,16 +21,23 @@ class GenericTestRunner:
"""Execute the complete test suite """Execute the complete test suite
Returns: Returns:
bool: True if all tests passed or were skipped/partial, False if any tests failed tuple: (success, test_runner) where:
success: bool indicating if all tests passed or were skipped/partial
test_runner: TestRunner instance with test results
""" """
config = TestConfig( config = TestConfig(
debug=self.args.debug, debug=self.args.debug,
bench=self.args.bench, bench=self.args.bench,
num_prerun=self.args.num_prerun, num_prerun=self.args.num_prerun,
num_iterations=self.args.num_iterations, num_iterations=self.args.num_iterations,
verbose=self.args.verbose, # Pass verbose flag to TestConfig
) )
runner = TestRunner(self.operator_test.test_cases, config) runner = TestRunner(self.operator_test.test_cases, config)
# Pass the test runner instance to config for benchmark timing accumulation
config._test_runner = runner
devices = get_test_devices(self.args) devices = get_test_devices(self.args)
# Run unified tests - returns True if no tests failed # Run unified tests - returns True if no tests failed
...@@ -46,7 +53,7 @@ class GenericTestRunner: ...@@ -46,7 +53,7 @@ class GenericTestRunner:
# Both conditions must be True for overall success # Both conditions must be True for overall success
# - has_no_failures: no test failures during execution # - has_no_failures: no test failures during execution
# - summary_passed: summary confirms no failures # - summary_passed: summary confirms no failures
return has_no_failures and summary_passed return (has_no_failures and summary_passed), runner
def run_and_exit(self): def run_and_exit(self):
"""Run tests and exit with appropriate status code """Run tests and exit with appropriate status code
...@@ -55,5 +62,5 @@ class GenericTestRunner: ...@@ -55,5 +62,5 @@ class GenericTestRunner:
0: All tests passed or were skipped/partial (no failures) 0: All tests passed or were skipped/partial (no failures)
1: One or more tests failed 1: One or more tests failed
""" """
success = self.run() success, runner = self.run()
sys.exit(0 if success else 1) sys.exit(0 if success else 1)
...@@ -22,10 +22,12 @@ def timed_op(func, num_iterations, device): ...@@ -22,10 +22,12 @@ def timed_op(func, num_iterations, device):
for _ in range(num_iterations): for _ in range(num_iterations):
func() func()
synchronize_device(device) synchronize_device(device)
return (time.time() - start) / num_iterations return time.time() - start
def profile_operation(desc, func, torch_device, num_prerun, num_iterations): def profile_operation(
desc, func, torch_device, num_prerun, num_iterations, total=False
):
""" """
Performance profiling workflow Performance profiling workflow
""" """
...@@ -35,7 +37,11 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations): ...@@ -35,7 +37,11 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
# Timed execution # Timed execution
elapsed = timed_op(lambda: func(), num_iterations, torch_device) elapsed = timed_op(lambda: func(), num_iterations, torch_device)
print(f" {desc} time: {elapsed * 1000 :6f} ms") print(f" {desc} time: {elapsed / num_iterations * 1000 :6f} ms")
if total:
return elapsed
else:
return elapsed / num_iterations
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
......
...@@ -133,9 +133,9 @@ class OpTest(BaseOperatorTest): ...@@ -133,9 +133,9 @@ class OpTest(BaseOperatorTest):
"""PyTorch ELU implementation""" """PyTorch ELU implementation"""
return torch.nn.functional.elu(*args, **kwargs) return torch.nn.functional.elu(*args, **kwargs)
def infinicore_operator(self, x, alpha=1.0, out=None, **kwargs): # def infinicore_operator(self, x, alpha=1.0, out=None, **kwargs):
"""InfiniCore ELU implementation""" # """InfiniCore ELU implementation"""
return None # return None
def main(): def main():
......
...@@ -103,7 +103,7 @@ def parse_test_cases(): ...@@ -103,7 +103,7 @@ def parse_test_cases():
return test_cases return test_cases
class MultiMarginLossOpTest(BaseOperatorTest): class OpTest(BaseOperatorTest):
"""MultiMarginLoss operator test with device handling""" """MultiMarginLoss operator test with device handling"""
def __init__(self): def __init__(self):
...@@ -116,9 +116,9 @@ class MultiMarginLossOpTest(BaseOperatorTest): ...@@ -116,9 +116,9 @@ class MultiMarginLossOpTest(BaseOperatorTest):
"""PyTorch multi_margin_loss implementation with device handling""" """PyTorch multi_margin_loss implementation with device handling"""
return F.multi_margin_loss(*args, **kwargs) return F.multi_margin_loss(*args, **kwargs)
def infinicore_operator(self, *args, **kwargs): # def infinicore_operator(self, *args, **kwargs):
"""InfiniCore multi_margin_loss implementation""" # """InfiniCore multi_margin_loss implementation"""
return None # return None
def main(): def main():
......
import os import os
import sys import sys
import subprocess
import argparse import argparse
from pathlib import Path from pathlib import Path
from typing import Dict, Tuple, List import importlib.util
from framework import get_hardware_args_group
def find_ops_directory(location=None): def find_ops_directory(location=None):
...@@ -58,9 +59,59 @@ def get_available_operators(ops_dir): ...@@ -58,9 +59,59 @@ def get_available_operators(ops_dir):
return sorted(operators) return sorted(operators)
def _is_operator_test_class(candidate):
    """Return True if *candidate* is a class deriving from ``BaseOperatorTest``.

    The whole MRO (excluding the class itself) is inspected, so indirect
    subclasses are recognised while the base class itself is not.
    """
    if not isinstance(candidate, type):
        return False
    return any(
        ancestor.__name__ == "BaseOperatorTest" for ancestor in candidate.__mro__[1:]
    )


def import_operator_test(test_file_path):
    """
    Import an operator test module and return an instance of its test class.

    The file is loaded under the module name ``op_test_<file stem>``. On
    success the module stays registered in ``sys.modules`` under that name; on
    any failure the registration is rolled back so no half-initialised module
    is left behind.

    Args:
        test_file_path: Path to the test file (``str`` or ``pathlib.Path``)

    Returns:
        tuple: (success, test_instance_or_error) where the second element is
        the instantiated test class when ``success`` is True, otherwise an
        error message string.
    """
    module_name = None
    previous_module = None
    registered = False
    loaded = False
    try:
        test_file_path = Path(test_file_path)

        # Create a unique module name
        module_name = f"op_test_{test_file_path.stem}"

        # Load the module from file
        spec = importlib.util.spec_from_file_location(module_name, test_file_path)
        if spec is None or spec.loader is None:
            return False, f"Could not load module from {test_file_path}"

        module = importlib.util.module_from_spec(spec)

        # Register before executing so the module can be found by name while
        # its own top-level code runs; remember any entry being shadowed.
        previous_module = sys.modules.get(module_name)
        sys.modules[module_name] = module
        registered = True

        spec.loader.exec_module(module)

        # Pick the first attribute (in dir() order) that is a test class
        test_class = next(
            (
                attr
                for attr in (getattr(module, attr_name) for attr_name in dir(module))
                if _is_operator_test_class(attr)
            ),
            None,
        )
        if test_class is None:
            return False, f"No test class found in {test_file_path}"

        test_instance = test_class()
        loaded = True
        return True, test_instance

    except Exception as e:
        # Boundary handler: every import problem is reported to the caller as
        # an error message instead of propagating.
        return False, f"Error importing {test_file_path}: {str(e)}"

    finally:
        # Roll back the sys.modules registration when the import did not succeed.
        if registered and not loaded:
            if previous_module is None:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous_module
def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False):
""" """
Run all operator test scripts in the ops directory. Run all operator test scripts in the ops directory using direct import.
Args: Args:
ops_dir (str, optional): Path to the ops directory. If None, uses auto-detection. ops_dir (str, optional): Path to the ops directory. If None, uses auto-detection.
...@@ -68,7 +119,7 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None): ...@@ -68,7 +119,7 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
extra_args (list, optional): Extra command line arguments to pass to test scripts. extra_args (list, optional): Extra command line arguments to pass to test scripts.
Returns: Returns:
dict: Results dictionary with test names as keys and (success, return_code, stdout, stderr) as values. dict: Results dictionary with test names as keys and (success, test_runner, stdout, stderr) as values.
""" """
if ops_dir is None: if ops_dir is None:
ops_dir = find_ops_directory() ops_dir = find_ops_directory()
...@@ -122,92 +173,184 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None): ...@@ -122,92 +173,184 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
results = {} results = {}
cumulative_timing = {
"total_torch_time": 0.0,
"total_infinicore_time": 0.0,
"operators_tested": 0,
}
for test_file in operator_test_files: for test_file in operator_test_files:
test_name = test_file.stem test_name = test_file.stem
try: try:
# Run the test script - use the absolute path and run from current directory # Import and run the test directly
cmd = [sys.executable, str(test_file.absolute())] success, test_instance_or_error = import_operator_test(test_file)
# Add extra arguments if provided if not success:
if extra_args: print(f"💥 {test_name}: ERROR - {test_instance_or_error}")
cmd.extend(extra_args) results[test_name] = {
"success": False,
result = subprocess.run( "return_code": -1,
cmd, "torch_time": 0.0,
capture_output=True, # Capture output to analyze "infini_time": 0.0,
text=True, "error_message": test_instance_or_error,
) "test_runner": None,
"stdout": "",
# Analyze output to determine test status "stderr": test_instance_or_error,
stdout_lower = result.stdout.lower() }
stderr_lower = result.stderr.lower() continue
# Check for operator not implemented patterns # Get the test runner class from the module
if ( test_module = sys.modules[f"op_test_{test_file.stem}"]
"all tests passed!" in stdout_lower if not hasattr(test_module, "GenericTestRunner"):
and "success rate: 100.0%" in stdout_lower print(f"💥 {test_name}: ERROR - No GenericTestRunner found")
): results[test_name] = {
success = True "success": False,
returncode = 0 "return_code": -1,
elif "both operators not implemented" in stdout_lower: "torch_time": 0.0,
# Both operators not implemented - skipped test "infini_time": 0.0,
success = False # Not a failure, but skipped "error_message": "No GenericTestRunner found",
returncode = -2 # Special code for skipped "test_runner": None,
elif "one operator not implemented" in stdout_lower: "stdout": "",
# One operator not implemented - partial test "stderr": "No GenericTestRunner found",
success = False # Not fully successful }
returncode = -3 # Special code for partial continue
else:
success = False # Create and run the test runner
returncode = -1 test_runner_class = test_module.GenericTestRunner
runner_instance = test_runner_class(test_instance_or_error.__class__)
results[test_name] = (
success, # Temporarily redirect stdout to capture output
returncode, from io import StringIO
result.stdout,
result.stderr, stdout_capture = StringIO()
) stderr_capture = StringIO()
# Print the output from the test script old_stdout = sys.stdout
print(f"\n{'='*60}") old_stderr = sys.stderr
print(f"TEST: {test_name}") sys.stdout = stdout_capture
print(f"{'='*60}") sys.stderr = stderr_capture
if result.stdout: try:
print(result.stdout.rstrip()) # Run the test
test_success, test_runner = runner_instance.run()
if result.stderr:
print("\nSTDERR:") # Get captured output
print(result.stderr.rstrip()) stdout_output = stdout_capture.getvalue()
stderr_output = stderr_capture.getvalue()
# Enhanced status display
if returncode == -2: # Restore stdout/stderr
status_icon = "⏭️" sys.stdout = old_stdout
status_text = "SKIPPED" sys.stderr = old_stderr
elif returncode == -3:
status_icon = "⚠️" # Print the captured output
status_text = "PARTIAL" if stdout_output:
elif success: print(stdout_output.rstrip())
status_icon = "✅" if stderr_output:
status_text = "PASSED" print("\nSTDERR:")
else: print(stderr_output.rstrip())
status_icon = "❌"
status_text = "FAILED" # Analyze test results
test_results = test_runner.get_test_results() if test_runner else []
print(
f"{status_icon} {test_name}: {status_text} (return code: {returncode})" # Determine overall test status
) if test_success:
return_code = 0
status_icon = "✅"
status_text = "PASSED"
else:
# Check if there are any failed tests
has_failures = any(
result.return_code == -1 for result in test_results
)
has_partial = any(
result.return_code == -3 for result in test_results
)
has_skipped = any(
result.return_code == -2 for result in test_results
)
if has_failures:
return_code = -1
status_icon = "❌"
status_text = "FAILED"
elif has_partial:
return_code = -3
status_icon = "⚠️"
status_text = "PARTIAL"
elif has_skipped:
return_code = -2
status_icon = "⏭️"
status_text = "SKIPPED"
else:
return_code = -1
status_icon = "❌"
status_text = "FAILED"
# Calculate timing
torch_time = sum(result.torch_time for result in test_results)
infini_time = sum(result.infini_time for result in test_results)
results[test_name] = {
"success": test_success,
"return_code": return_code,
"torch_time": torch_time,
"infini_time": infini_time,
"error_message": "",
"test_runner": test_runner,
"stdout": stdout_output,
"stderr": stderr_output,
}
print(
f"{status_icon} {test_name}: {status_text} (return code: {return_code})"
)
# Extract benchmark timing if in bench mode
if bench and test_success and return_code == 0:
cumulative_timing["total_torch_time"] += torch_time
cumulative_timing["total_infinicore_time"] += infini_time
cumulative_timing["operators_tested"] += 1
except Exception as e:
# Restore stdout/stderr in case of exception
sys.stdout = old_stdout
sys.stderr = old_stderr
raise e
# In verbose mode, stop execution on first failure
if verbose and not test_success and return_code != 0:
break
except Exception as e: except Exception as e:
print(f"💥 {test_name}: ERROR - {str(e)}") print(f"💥 {test_name}: ERROR - {str(e)}")
results[test_name] = (False, -1, "", str(e)) results[test_name] = {
"success": False,
return results "return_code": -1,
"torch_time": 0.0,
"infini_time": 0.0,
def print_summary(results): "error_message": str(e),
"""Print a comprehensive summary of test results.""" "test_runner": None,
"stdout": "",
"stderr": str(e),
}
# In verbose mode, stop execution on any exception
if verbose:
print(f"\n{'!'*60}")
print(
f"VERBOSE MODE: Stopping execution due to exception in {test_name}"
)
print(f"{'!'*60}")
break
return results, cumulative_timing
def print_summary(
results, verbose=False, total_expected_tests=0, cumulative_timing=None
):
"""Print a comprehensive summary of test results including benchmark data."""
print(f"\n{'='*80}") print(f"\n{'='*80}")
print("CUMULATIVE TEST SUMMARY") print("CUMULATIVE TEST SUMMARY")
print(f"{'='*80}") print(f"{'='*80}")
...@@ -226,14 +369,15 @@ def print_summary(results): ...@@ -226,14 +369,15 @@ def print_summary(results):
skipped_operators = [] # Store skipped operator names skipped_operators = [] # Store skipped operator names
partial_operators = [] # Store partial operator names partial_operators = [] # Store partial operator names
for test_name, (success, returncode, stdout, stderr) in results.items(): for test_name, result_data in results.items():
if success: return_code = result_data["return_code"]
if return_code == 0:
passed += 1 passed += 1
passed_operators.append(test_name) passed_operators.append(test_name)
elif returncode == -2: # Special code for skipped tests elif return_code == -2: # Special code for skipped tests
skipped += 1 skipped += 1
skipped_operators.append(test_name) skipped_operators.append(test_name)
elif returncode == -3: # Special code for partial tests elif return_code == -3: # Special code for partial tests
partial += 1 partial += 1
partial_operators.append(test_name) partial_operators.append(test_name)
else: else:
...@@ -242,7 +386,11 @@ def print_summary(results): ...@@ -242,7 +386,11 @@ def print_summary(results):
total = len(results) total = len(results)
print(f"Total tests: {total}") print(f"Total tests run: {total}")
if total_expected_tests > 0 and total < total_expected_tests:
print(f"Total tests expected: {total_expected_tests}")
print(f"Tests not executed: {total_expected_tests - total}")
print(f"Passed: {passed}") print(f"Passed: {passed}")
print(f"Failed: {failed}") print(f"Failed: {failed}")
...@@ -252,6 +400,19 @@ def print_summary(results): ...@@ -252,6 +400,19 @@ def print_summary(results):
if partial > 0: if partial > 0:
print(f"Partial: {partial}") print(f"Partial: {partial}")
# Print benchmark summary if cumulative_timing data is available
if cumulative_timing and cumulative_timing["operators_tested"] > 0:
print(f"{'-'*40}")
print("BENCHMARK SUMMARY:")
print(f" Operators Tested: {cumulative_timing['operators_tested']}")
print(
f" PyTorch Total Time: {cumulative_timing['total_torch_time'] * 1000:12.3f} ms"
)
print(
f" InfiniCore Total Time: {cumulative_timing['total_infinicore_time'] * 1000:12.3f} ms"
)
print(f"{'-'*40}")
# Display passed operators # Display passed operators
if passed_operators: if passed_operators:
print(f"\n✅ PASSED OPERATORS ({len(passed_operators)}):") print(f"\n✅ PASSED OPERATORS ({len(passed_operators)}):")
...@@ -284,12 +445,16 @@ def print_summary(results): ...@@ -284,12 +445,16 @@ def print_summary(results):
print(" " + ", ".join(line_ops)) print(" " + ", ".join(line_ops))
if total > 0: if total > 0:
# Calculate success rate based on executed tests only # Calculate success rate based on actual executed tests
executed_tests = passed + failed + partial executed_tests = passed + failed + partial
if executed_tests > 0: if executed_tests > 0:
success_rate = passed / executed_tests * 100 success_rate = passed / executed_tests * 100
print(f"\nSuccess rate: {success_rate:.1f}%") print(f"\nSuccess rate: {success_rate:.1f}%")
if verbose and total < total_expected_tests:
print(f"\n💡 Verbose mode: Execution stopped after first failure")
print(f" {total_expected_tests - total} tests were not executed")
if failed == 0: if failed == 0:
if skipped > 0 or partial > 0: if skipped > 0 or partial > 0:
print(f"\n⚠️ Tests completed with some operators not implemented") print(f"\n⚠️ Tests completed with some operators not implemented")
...@@ -358,6 +523,14 @@ def generate_help_epilog(ops_dir): ...@@ -358,6 +523,14 @@ def generate_help_epilog(ops_dir):
epilog_parts.append(" # Run with debug mode on multiple devices") epilog_parts.append(" # Run with debug mode on multiple devices")
epilog_parts.append(" python run.py --cpu --nvidia --debug") epilog_parts.append(" python run.py --cpu --nvidia --debug")
epilog_parts.append("") epilog_parts.append("")
epilog_parts.append(
" # Run with verbose mode to stop on first error with full traceback"
)
epilog_parts.append(" python run.py --cpu --nvidia --verbose")
epilog_parts.append("")
epilog_parts.append(" # Run with benchmarking to get cumulative timing")
epilog_parts.append(" python run.py --cpu --bench")
epilog_parts.append("")
epilog_parts.append(" # List available tests without running") epilog_parts.append(" # List available tests without running")
epilog_parts.append(" python run.py --list") epilog_parts.append(" python run.py --list")
epilog_parts.append("") epilog_parts.append("")
...@@ -384,7 +557,13 @@ def generate_help_epilog(ops_dir): ...@@ -384,7 +557,13 @@ def generate_help_epilog(ops_dir):
" - Operators are automatically discovered from the ops directory" " - Operators are automatically discovered from the ops directory"
) )
epilog_parts.append( epilog_parts.append(
" - --bench option is disabled in batch mode (run individual tests for benchmarking)" " - --bench mode now shows cumulative timing across all operators"
)
epilog_parts.append(
" - --verbose mode stops execution on first error and shows full traceback"
)
epilog_parts.append(
" - In verbose mode, subsequent tests are skipped after first failure"
) )
return "\n".join(epilog_parts) return "\n".join(epilog_parts)
...@@ -413,15 +592,21 @@ def main(): ...@@ -413,15 +592,21 @@ def main():
action="store_true", action="store_true",
help="List all available test files without running them", help="List all available test files without running them",
) )
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose mode to stop on first error with full traceback",
)
parser.add_argument(
"--bench",
action="store_true",
help="Enable bench mode to show performance data",
)
from framework import get_hardware_args_group get_hardware_args_group(parser)
if "-h" in sys.argv or "--help" in sys.argv:
get_hardware_args_group(parser)
# Parse known args first, leave the rest for the test scripts # Parse known args first, leave the rest for the test scripts
args, unknown_args = parser.parse_known_args() args, unknown_args = parser.parse_known_args()
get_hardware_args_group(parser)
# Handle list command # Handle list command
if args.list: if args.list:
...@@ -453,6 +638,9 @@ def main(): ...@@ -453,6 +638,9 @@ def main():
print(f"Operating directory: {ops_dir}") print(f"Operating directory: {ops_dir}")
print(f"Available operators: {len(available_operators)}") print(f"Available operators: {len(available_operators)}")
if args.verbose:
print(f"Verbose mode: ENABLED (will stop on first error with full traceback)")
if args.ops: if args.ops:
# Validate requested operators # Validate requested operators
valid_ops = [] valid_ops = []
...@@ -469,32 +657,50 @@ def main(): ...@@ -469,32 +657,50 @@ def main():
if valid_ops: if valid_ops:
print(f"Testing operators: {', '.join(valid_ops)}") print(f"Testing operators: {', '.join(valid_ops)}")
total_expected_tests = len(valid_ops)
else: else:
print("No valid operators specified. Running all available tests.") print("No valid operators specified. Running all available tests.")
total_expected_tests = len(available_operators)
else: else:
print("Testing all available operators") print("Testing all available operators")
total_expected_tests = len(available_operators)
print() print()
# Run all tests # Run all tests
results = run_all_op_tests( results, cumulative_timing = run_all_op_tests(
ops_dir=ops_dir, ops_dir=ops_dir,
specific_ops=args.ops, specific_ops=args.ops,
extra_args=unknown_args, bench=args.bench,
verbose=args.verbose,
) )
# Print summary and exit with appropriate code # Print summary and exit with appropriate code
all_passed = print_summary(results) all_passed = print_summary(
results, args.verbose, total_expected_tests, cumulative_timing
)
# Check if there were any tests with missing implementations # Check if there were any tests with missing implementations
has_missing_implementations = any( has_missing_implementations = any(
returncode in [-2, -3] for _, (_, returncode, _, _) in results.items() result_data["return_code"] in [-2, -3] for result_data in results.values()
) )
if all_passed and has_missing_implementations: if all_passed and has_missing_implementations:
print(f"\n⚠️ Note: Some operators are not fully implemented") print(f"\n⚠️ Note: Some operators are not fully implemented")
print(f" Run individual tests for details on missing implementations") print(f" Run individual tests for details on missing implementations")
if args.verbose and not all_passed:
print(
f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:"
)
failed_ops = [
name
for name, result_data in results.items()
if result_data["return_code"] == -1
]
for op in failed_ops[:3]: # Show first 3 failed operators
print(f" python {ops_dir / (op + '.py')} --verbose")
sys.exit(0 if all_passed else 1) sys.exit(0 if all_passed else 1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment