Commit 1e6ccdc9 authored by wooway777
Browse files

issue/598 - optimize run.py performance

parent 5c88cbbd
......@@ -182,9 +182,9 @@ pip install . -e
```bash
# 测试单算子
python test/infinicore/ops/[operator].py [--bench | --debug] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun | --Hygon]
python test/infinicore/ops/[operator].py [--bench | --debug | --verbose] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun | --Hygon]
# 测试全部算子
python test/infinicore/run.py [--bench | --debug] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun]
python test/infinicore/run.py [--bench | --debug | --verbose] [--cpu | --nvidia | --cambricon | --ascend | --iluvatar | --metax | --moore | --kunlun]
```
使用 -h 查看更多参数。
......
import torch
import infinicore
import traceback
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Tuple
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
......@@ -15,6 +15,18 @@ from .utils import (
)
@dataclass
class TestResult:
    """Outcome record for one test case executed on one device."""

    # True only once the test has been marked as passed.
    success: bool
    # Status code: 0 = success, -1 = failure, -2 = skipped, -3 = partial.
    return_code: int
    # Benchmark timings for the PyTorch and InfiniCore operators; they keep
    # their 0.0 default unless a benchmarking run fills them in.
    torch_time: float = 0.0
    infini_time: float = 0.0
    # Human-readable reason for a non-success result (empty when none given).
    error_message: str = ""
    # The test case and device this result belongs to.
    test_case: Any = None
    device: Any = None
class TestCase:
"""Test case with all configuration included"""
......@@ -23,11 +35,11 @@ class TestCase:
inputs,
kwargs=None,
output_spec=None,
output_specs=None,
comparison_target=None,
description="",
tolerance=None,
output_count=1,
output_specs=None,
):
"""
Initialize a test case with complete configuration
......@@ -248,6 +260,8 @@ class TestRunner:
"infinicore_total": 0.0,
"per_test_case": {}, # Store timing per test case
}
# Store test results
self.test_results = []
def run_tests(self, devices, test_func, test_type="Test"):
"""
......@@ -270,25 +284,30 @@ class TestRunner:
try:
print(f"{test_case}")
# Execute test and get result status
success, status = test_func(device, test_case, self.config)
# Execute test and get TestResult object
test_result = test_func(device, test_case, self.config)
self.test_results.append(test_result)
# Handle different test statuses
if status == "passed":
# Handle different test statuses based on return_code
if test_result.return_code == 0: # Success
self.passed_tests.append(
f"{test_case} - {InfiniDeviceNames[device]}"
)
print(f"\033[92m✓\033[0m Passed")
elif status == "skipped":
# Test was skipped due to both operators not being implemented
elif test_result.return_code == -1:
fail_msg = f"{test_case} - {InfiniDeviceNames[device]} - Test terminated in verbose mode."
self.failed_tests.append(fail_msg)
elif test_result.return_code == -2: # Skipped
skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented"
self.skipped_tests.append(skip_msg)
elif status == "partial":
# Test was partially executed (one operator not implemented)
print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped")
elif test_result.return_code == -3: # Partial
partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented"
self.partial_tests.append(partial_msg)
print(f"\033[93m⚠\033[0m One operator not implemented - running single operator without comparison")
# Failed tests are handled in the exception handler below
if self.config.verbose and test_result.return_code != 0:
return False
except Exception as e:
error_msg = (
......@@ -297,6 +316,15 @@ class TestRunner:
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
# Create a failed TestResult
failed_result = TestResult(
success=False,
return_code=-1,
error_message=str(e),
test_case=test_case,
device=device
)
self.test_results.append(failed_result)
# In verbose mode, print full traceback and stop execution
if self.config.verbose:
traceback.print_exc()
......@@ -305,8 +333,7 @@ class TestRunner:
if self.config.debug:
raise
# Return True if no tests failed (skipped/partial tests don't count as failures)
return len(self.failed_tests) == 0
return len(self.failed_tests) == 0 and len(self.skipped_tests) == 0 and len(self.partial_tests) == 0
def print_summary(self):
"""
......@@ -377,6 +404,10 @@ class TestRunner:
)
print(f"Speedup (PyTorch/InfiniCore): {speedup:.2f}x")
def get_test_results(self):
    """
    Return the test results recorded by this runner.

    Returns:
        The ``self.test_results`` list itself (not a copy), so later
        results appended by the runner are visible through it.
    """
    collected = self.test_results
    return collected
class BaseOperatorTest(ABC):
"""Base operator test"""
......@@ -480,12 +511,18 @@ class BaseOperatorTest(ABC):
config: Test configuration
Returns:
tuple: (success, status) where:
success: bool indicating if test passed
status: str describing test status ("passed", "skipped", "partial")
TestResult: Test result object containing status and timing information
"""
device_str = torch_device_map[device]
# Initialize test result
test_result = TestResult(
success=False,
return_code=-1, # Default to failure
test_case=test_case,
device=device
)
# Prepare inputs and kwargs with actual tensors
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
......@@ -559,7 +596,10 @@ class BaseOperatorTest(ABC):
except NotImplementedError:
if config.verbose:
traceback.print_exc()
return False # Stop test execution immediately
# Return test result immediately in verbose mode
test_result.return_code = -1
test_result.error_message = "torch_operator not implemented"
return test_result
torch_implemented = False
torch_result = None
......@@ -570,26 +610,24 @@ class BaseOperatorTest(ABC):
except NotImplementedError:
if config.verbose:
traceback.print_exc()
return False # Stop test execution immediately
# Return test result immediately in verbose mode
test_result.return_code = -1
test_result.error_message = "infinicore_operator not implemented"
return test_result
infini_implemented = False
infini_result = None
# Skip if neither operator is implemented
if not torch_implemented and not infini_implemented:
print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped")
return False, "skipped"
test_result.return_code = -2 # Skipped
return test_result
# Single operator execution without comparison
if not torch_implemented or not infini_implemented:
missing_op = (
"torch_operator" if not torch_implemented else "infinicore_operator"
)
print(
f"\033[93m⚠\033[0m {missing_op} not implemented - running single operator without comparison"
)
test_result.return_code = -3 # Partial
# Run benchmarking for partial tests if enabled
if config.bench:
self._run_benchmarking(
torch_time, infini_time = self._run_benchmarking(
config,
device_str,
torch_implemented,
......@@ -601,8 +639,9 @@ class BaseOperatorTest(ABC):
test_case.output_count,
comparison_target,
)
return False, "partial"
test_result.torch_time = torch_time
test_result.infini_time = infini_time
return test_result
# ==========================================================================
# MULTIPLE OUTPUTS COMPARISON LOGIC
# ==========================================================================
......@@ -711,7 +750,7 @@ class BaseOperatorTest(ABC):
# UNIFIED BENCHMARKING LOGIC
# ==========================================================================
if config.bench:
self._run_benchmarking(
torch_time, infini_time = self._run_benchmarking(
config,
device_str,
True,
......@@ -723,9 +762,13 @@ class BaseOperatorTest(ABC):
test_case.output_count,
comparison_target,
)
test_result.torch_time = torch_time
test_result.infini_time = infini_time
# Test passed successfully
return True, "passed"
test_result.success = True
test_result.return_code = 0
return test_result
def _run_benchmarking(
self,
......@@ -742,8 +785,10 @@ class BaseOperatorTest(ABC):
):
"""
Unified benchmarking logic with timing accumulation
"""
Returns:
tuple: (torch_time, infini_time) timing results
"""
# Initialize timing variables
torch_time = 0.0
infini_time = 0.0
......@@ -809,3 +854,5 @@ class BaseOperatorTest(ABC):
# Accumulate total times
config._test_runner.benchmark_times["torch_total"] += torch_time
config._test_runner.benchmark_times["infinicore_total"] += infini_time
return torch_time, infini_time
......@@ -100,8 +100,9 @@ Examples:
# Device options using shared hardware info
hardware_group = get_hardware_args_group(parser)
args, unknown = parser.parse_known_args()
return parser.parse_args()
return args
def get_test_devices(args):
......
......@@ -21,7 +21,9 @@ class GenericTestRunner:
"""Execute the complete test suite
Returns:
bool: True if all tests passed or were skipped/partial, False if any tests failed
tuple: (success, test_runner) where:
success: bool indicating if all tests passed or were skipped/partial
test_runner: TestRunner instance with test results
"""
config = TestConfig(
debug=self.args.debug,
......@@ -51,7 +53,7 @@ class GenericTestRunner:
# Both conditions must be True for overall success
# - has_no_failures: no test failures during execution
# - summary_passed: summary confirms no failures
return has_no_failures and summary_passed
return (has_no_failures and summary_passed), runner
def run_and_exit(self):
"""Run tests and exit with appropriate status code
......@@ -60,5 +62,5 @@ class GenericTestRunner:
0: All tests passed or were skipped/partial (no failures)
1: One or more tests failed
"""
success = self.run()
success, runner = self.run()
sys.exit(0 if success else 1)
......@@ -133,9 +133,9 @@ class OpTest(BaseOperatorTest):
"""PyTorch ELU implementation"""
return torch.nn.functional.elu(*args, **kwargs)
def infinicore_operator(self, x, alpha=1.0, out=None, **kwargs):
"""InfiniCore ELU implementation"""
return None
# def infinicore_operator(self, x, alpha=1.0, out=None, **kwargs):
# """InfiniCore ELU implementation"""
# return None
def main():
......
......@@ -103,7 +103,7 @@ def parse_test_cases():
return test_cases
class MultiMarginLossOpTest(BaseOperatorTest):
class OpTest(BaseOperatorTest):
"""MultiMarginLoss operator test with device handling"""
def __init__(self):
......@@ -116,9 +116,9 @@ class MultiMarginLossOpTest(BaseOperatorTest):
"""PyTorch multi_margin_loss implementation with device handling"""
return F.multi_margin_loss(*args, **kwargs)
def infinicore_operator(self, *args, **kwargs):
"""InfiniCore multi_margin_loss implementation"""
return None
# def infinicore_operator(self, *args, **kwargs):
# """InfiniCore multi_margin_loss implementation"""
# return None
def main():
......
import os
import sys
import subprocess
import argparse
from pathlib import Path
from typing import Dict, Tuple, List
import importlib.util
from framework import get_hardware_args_group
def find_ops_directory(location=None):
......@@ -58,9 +59,59 @@ def get_available_operators(ops_dir):
return sorted(operators)
def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
def import_operator_test(test_file_path):
    """
    Import an operator test module from a file and instantiate its test class.

    The module is registered in ``sys.modules`` as ``op_test_<file stem>``.
    If executing the module raises, that entry is removed again so no
    half-initialised module is left behind.

    Args:
        test_file_path: Path to the test file (``str`` or ``pathlib.Path``).

    Returns:
        tuple: (success, test_instance_or_error) where the second element is
        the instantiated test class on success, or an error message string
        on failure.
    """
    try:
        # Accept plain strings as well as Path objects.
        file_path = Path(test_file_path)

        # Create a unique module name
        module_name = f"op_test_{file_path.stem}"

        # Load the module from file
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            return False, f"Could not load module from {test_file_path}"

        module = importlib.util.module_from_spec(spec)

        # Register before executing so the module can be found by name while
        # it runs; undo the registration if execution fails.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            sys.modules.pop(module_name, None)
            raise

        # Collect every class that inherits (directly or indirectly) from a
        # class whose name contains "BaseOperatorTest". The class itself is
        # excluded by skipping the first MRO entry.
        candidates = []
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if isinstance(attr, type) and any(
                "BaseOperatorTest" in base.__name__ for base in attr.__mro__[1:]
            ):
                candidates.append(attr)

        if not candidates:
            return False, f"No test class found in {test_file_path}"

        # Prefer a class defined in the test file itself over one that was
        # merely imported into it; otherwise fall back to the first match.
        test_class = next(
            (cls for cls in candidates if cls.__module__ == module_name),
            candidates[0],
        )

        # Create an instance
        return True, test_class()

    except Exception as e:
        # Any import/instantiation problem is reported to the caller as a
        # message rather than raised.
        return False, f"Error importing {test_file_path}: {str(e)}"
def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False):
"""
Run all operator test scripts in the ops directory.
Run all operator test scripts in the ops directory using direct import.
Args:
ops_dir (str, optional): Path to the ops directory. If None, uses auto-detection.
......@@ -68,7 +119,7 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
extra_args (list, optional): Extra command line arguments to pass to test scripts.
Returns:
dict: Results dictionary with test names as keys and (success, return_code, stdout, stderr) as values.
dict: Results dictionary with test names as keys and (success, test_runner, stdout, stderr) as values.
"""
if ops_dir is None:
ops_dir = find_ops_directory()
......@@ -122,11 +173,6 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
results = {}
# Check if verbose mode is enabled
verbose_mode = extra_args and "--verbose" in extra_args
# Check if bench mode is enabled for cumulative timing
bench_mode = extra_args and "--bench" in extra_args
cumulative_timing = {
"total_torch_time": 0.0,
"total_infinicore_time": 0.0,
......@@ -137,117 +183,160 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
test_name = test_file.stem
try:
# Run the test script - use the absolute path and run from current directory
cmd = [sys.executable, str(test_file.absolute())]
# Import and run the test directly
success, test_instance_or_error = import_operator_test(test_file)
if not success:
print(f"💥 {test_name}: ERROR - {test_instance_or_error}")
results[test_name] = {
"success": False,
"return_code": -1,
"torch_time": 0.0,
"infini_time": 0.0,
"error_message": test_instance_or_error,
"test_runner": None,
"stdout": "",
"stderr": test_instance_or_error,
}
continue
# Add extra arguments if provided
if extra_args:
cmd.extend(extra_args)
# Get the test runner class from the module
test_module = sys.modules[f"op_test_{test_file.stem}"]
if not hasattr(test_module, "GenericTestRunner"):
print(f"💥 {test_name}: ERROR - No GenericTestRunner found")
results[test_name] = {
"success": False,
"return_code": -1,
"torch_time": 0.0,
"infini_time": 0.0,
"error_message": "No GenericTestRunner found",
"test_runner": None,
"stdout": "",
"stderr": "No GenericTestRunner found",
}
continue
result = subprocess.run(
cmd,
capture_output=True, # Capture output to analyze
text=True,
)
# Create and run the test runner
test_runner_class = test_module.GenericTestRunner
runner_instance = test_runner_class(test_instance_or_error.__class__)
# Analyze output to determine test status
stdout_lower = result.stdout.lower()
stderr_lower = result.stderr.lower()
# Temporarily redirect stdout to capture output
from io import StringIO
# Check for operator not implemented patterns
if (
"all tests passed!" in stdout_lower
and "success rate: 100.0%" in stdout_lower
):
success = True
returncode = 0
elif "both operators not implemented" in stdout_lower:
# Both operators not implemented - skipped test
success = False # Not a failure, but skipped
returncode = -2 # Special code for skipped
elif "operator not implemented" in stdout_lower:
# One operator not implemented - partial test
success = False # Not fully successful
returncode = -3 # Special code for partial
else:
success = False
returncode = -1
results[test_name] = (
success,
returncode,
result.stdout,
result.stderr,
)
stdout_capture = StringIO()
stderr_capture = StringIO()
# Print the output from the test script
print(f"\n{'='*60}")
print(f"TEST: {test_name}")
print(f"{'='*60}")
old_stdout = sys.stdout
old_stderr = sys.stderr
sys.stdout = stdout_capture
sys.stderr = stderr_capture
try:
# Run the test
test_success, test_runner = runner_instance.run()
if result.stdout:
print(result.stdout.rstrip())
# Get captured output
stdout_output = stdout_capture.getvalue()
stderr_output = stderr_capture.getvalue()
if result.stderr:
# Restore stdout/stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
# Print the captured output
if stdout_output:
print(stdout_output.rstrip())
if stderr_output:
print("\nSTDERR:")
print(result.stderr.rstrip())
print(stderr_output.rstrip())
# Enhanced status display
if returncode == -2:
status_icon = "⏭️"
status_text = "SKIPPED"
elif returncode == -3:
status_icon = "⚠️"
status_text = "PARTIAL"
elif success:
# Analyze test results
test_results = test_runner.get_test_results() if test_runner else []
# Determine overall test status
if test_success:
return_code = 0
status_icon = "✅"
status_text = "PASSED"
else:
# Check if there are any failed tests
has_failures = any(
result.return_code == -1 for result in test_results
)
has_partial = any(
result.return_code == -3 for result in test_results
)
has_skipped = any(
result.return_code == -2 for result in test_results
)
if has_failures:
return_code = -1
status_icon = "❌"
status_text = "FAILED"
elif has_partial:
return_code = -3
status_icon = "⚠️"
status_text = "PARTIAL"
elif has_skipped:
return_code = -2
status_icon = "⏭️"
status_text = "SKIPPED"
else:
return_code = -1
status_icon = "❌"
status_text = "FAILED"
# Calculate timing
torch_time = sum(result.torch_time for result in test_results)
infini_time = sum(result.infini_time for result in test_results)
results[test_name] = {
"success": test_success,
"return_code": return_code,
"torch_time": torch_time,
"infini_time": infini_time,
"error_message": "",
"test_runner": test_runner,
"stdout": stdout_output,
"stderr": stderr_output,
}
print(
f"{status_icon} {test_name}: {status_text} (return code: {returncode})"
f"{status_icon} {test_name}: {status_text} (return code: {return_code})"
)
# Extract benchmark timing if in bench mode
if bench_mode and success:
# Look for benchmark summary in stdout
lines = result.stdout.split("\n")
torch_time = 0.0
infini_time = 0.0
for line in lines:
if "PyTorch Total Time:" in line:
try:
# Extract time value (e.g., "PyTorch Total Time: 123.456 ms")
torch_time = (
float(line.split(":")[1].strip().split()[0]) / 1000.0
) # Convert to seconds
except:
pass
elif "InfiniCore Total Time:" in line:
try:
infini_time = (
float(line.split(":")[1].strip().split()[0]) / 1000.0
) # Convert to seconds
except:
pass
if bench and test_success and return_code == 0:
cumulative_timing["total_torch_time"] += torch_time
cumulative_timing["total_infinicore_time"] += infini_time
cumulative_timing["operators_tested"] += 1
except Exception as e:
# Restore stdout/stderr in case of exception
sys.stdout = old_stdout
sys.stderr = old_stderr
raise e
# In verbose mode, stop execution on first failure
if verbose_mode and not success and returncode not in [-2, -3]:
if verbose and not test_success and return_code != 0:
break
except Exception as e:
print(f"💥 {test_name}: ERROR - {str(e)}")
results[test_name] = (False, -1, "", str(e))
results[test_name] = {
"success": False,
"return_code": -1,
"torch_time": 0.0,
"infini_time": 0.0,
"error_message": str(e),
"test_runner": None,
"stdout": "",
"stderr": str(e),
}
# In verbose mode, stop execution on any exception
if verbose_mode:
if verbose:
print(f"\n{'!'*60}")
print(
f"VERBOSE MODE: Stopping execution due to exception in {test_name}"
......@@ -259,7 +348,7 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
def print_summary(
results, verbose_mode=False, total_expected_tests=0, cumulative_timing=None
results, verbose=False, total_expected_tests=0, cumulative_timing=None
):
"""Print a comprehensive summary of test results including benchmark data."""
print(f"\n{'='*80}")
......@@ -280,14 +369,15 @@ def print_summary(
skipped_operators = [] # Store skipped operator names
partial_operators = [] # Store partial operator names
for test_name, (success, returncode, stdout, stderr) in results.items():
if success:
for test_name, result_data in results.items():
return_code = result_data["return_code"]
if return_code == 0:
passed += 1
passed_operators.append(test_name)
elif returncode == -2: # Special code for skipped tests
elif return_code == -2: # Special code for skipped tests
skipped += 1
skipped_operators.append(test_name)
elif returncode == -3: # Special code for partial tests
elif return_code == -3: # Special code for partial tests
partial += 1
partial_operators.append(test_name)
else:
......@@ -316,10 +406,10 @@ def print_summary(
print("BENCHMARK SUMMARY:")
print(f" Operators Tested: {cumulative_timing['operators_tested']}")
print(
f" Total PyTorch Time: {cumulative_timing['total_torch_time'] * 1000:.3f} ms"
f" PyTorch Total Time: {cumulative_timing['total_torch_time'] * 1000:12.3f} ms"
)
print(
f" Total InfiniCore Time: {cumulative_timing['total_infinicore_time'] * 1000:.3f} ms"
f" InfiniCore Total Time: {cumulative_timing['total_infinicore_time'] * 1000:12.3f} ms"
)
print(f"{'-'*40}")
......@@ -361,7 +451,7 @@ def print_summary(
success_rate = passed / executed_tests * 100
print(f"\nSuccess rate: {success_rate:.1f}%")
if verbose_mode and total < total_expected_tests:
if verbose and total < total_expected_tests:
print(f"\n💡 Verbose mode: Execution stopped after first failure")
print(f" {total_expected_tests - total} tests were not executed")
......@@ -505,17 +595,18 @@ def main():
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose mode to stop on first error with full traceback (passed to individual tests)",
help="Enable verbose mode to stop on first error with full traceback",
)
parser.add_argument(
"--bench",
action="store_true",
help="Enable bench mode to show performance data",
)
from framework import get_hardware_args_group
if "-h" in sys.argv or "--help" in sys.argv:
get_hardware_args_group(parser)
# Parse known args first, leave the rest for the test scripts
args, unknown_args = parser.parse_known_args()
get_hardware_args_group(parser)
# Handle list command
if args.list:
......@@ -536,10 +627,6 @@ def main():
print(f"Error: Ops directory '{ops_dir}' does not exist.")
sys.exit(1)
# Add verbose flag to extra arguments if specified
if args.verbose and "--verbose" not in unknown_args:
unknown_args.append("--verbose")
# Show what extra arguments will be passed
if unknown_args:
print(f"Passing extra arguments to test scripts: {unknown_args}")
......@@ -584,7 +671,8 @@ def main():
results, cumulative_timing = run_all_op_tests(
ops_dir=ops_dir,
specific_ops=args.ops,
extra_args=unknown_args,
bench=args.bench,
verbose=args.verbose,
)
# Print summary and exit with appropriate code
......@@ -594,7 +682,7 @@ def main():
# Check if there were any tests with missing implementations
has_missing_implementations = any(
returncode in [-2, -3] for _, (_, returncode, _, _) in results.items()
result_data["return_code"] in [-2, -3] for result_data in results.values()
)
if all_passed and has_missing_implementations:
......@@ -607,8 +695,8 @@ def main():
)
failed_ops = [
name
for name, (success, _, _, _) in results.items()
if not success and name in results
for name, result_data in results.items()
if result_data["return_code"] == -1
]
for op in failed_ops[:3]: # Show first 3 failed operators
print(f" python {ops_dir / (op + '.py')} --verbose")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment