Commit 7542c51d authored by wooway777's avatar wooway777 Committed by MaYuhang
Browse files

issue/573 - distinguishing unimplemented cases

parent 3d139316
...@@ -176,6 +176,9 @@ class TestConfig: ...@@ -176,6 +176,9 @@ class TestConfig:
self.num_iterations = num_iterations self.num_iterations = num_iterations
# In base.py - update the TestRunner class
class TestRunner: class TestRunner:
"""Test runner""" """Test runner"""
...@@ -183,8 +186,24 @@ class TestRunner: ...@@ -183,8 +186,24 @@ class TestRunner:
self.test_cases = test_cases self.test_cases = test_cases
self.config = test_config self.config = test_config
self.failed_tests = [] self.failed_tests = []
self.skipped_tests = [] # Track skipped tests (both operators not implemented)
self.partial_tests = [] # Track partial tests (one operator not implemented)
self.passed_tests = (
[]
) # Track passed tests (both operators implemented and passed)
def run_tests(self, devices, test_func, test_type="Test"): def run_tests(self, devices, test_func, test_type="Test"):
"""
Run tests on specified devices
Args:
devices: List of devices to test on
test_func: Test function to execute
test_type: Type of test for display purposes
Returns:
bool: True if no tests failed, False otherwise
"""
for device in devices: for device in devices:
print(f"\n{'='*60}") print(f"\n{'='*60}")
print(f"Testing {test_type} on {InfiniDeviceNames[device]}") print(f"Testing {test_type} on {InfiniDeviceNames[device]}")
...@@ -194,26 +213,102 @@ class TestRunner: ...@@ -194,26 +213,102 @@ class TestRunner:
try: try:
print(f"{test_case}") print(f"{test_case}")
test_func(device, test_case, self.config) # Execute test and get result status
print(f"\033[92m✓\033[0m Passed") success, status = test_func(device, test_case, self.config)
# Handle different test statuses
if status == "passed":
self.passed_tests.append(
f"{test_case} - {InfiniDeviceNames[device]}"
)
print(f"\033[92m✓\033[0m Passed")
elif status == "skipped":
# Test was skipped due to both operators not being implemented
skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented"
self.skipped_tests.append(skip_msg)
print(
f"\033[93m⚠\033[0m Skipped - both operators not implemented"
)
elif status == "partial":
# Test was partially executed (one operator not implemented)
partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented"
self.partial_tests.append(partial_msg)
print(
f"\033[93m⚠\033[0m Partial - one operator not implemented"
)
# Failed tests are handled in the exception handler below
except Exception as e: except Exception as e:
error_msg = f"Error: {e}" error_msg = (
f"{test_case} - {InfiniDeviceNames[device]} - Error: {e}"
)
print(f"\033[91m✗\033[0m {error_msg}") print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg) self.failed_tests.append(error_msg)
if self.config.debug: if self.config.debug:
raise raise
# Return True if no tests failed (skipped/partial tests don't count as failures)
return len(self.failed_tests) == 0 return len(self.failed_tests) == 0
def print_summary(self): def print_summary(self):
"""
Print test execution summary
Returns:
bool: True if no tests failed, False otherwise
"""
total_tests = len(self.test_cases)
passed_count = len(self.passed_tests)
skipped_count = len(self.skipped_tests)
partial_count = len(self.partial_tests)
failed_count = len(self.failed_tests)
print(f"\n{'='*60}")
print("TEST SUMMARY")
print(f"{'='*60}")
print(f"Total tests: {total_tests}")
print(f"\033[92mPassed: {passed_count}\033[0m")
# Display partial tests (one operator not implemented)
if self.partial_tests:
print(
f"\033[93mPartial (one operator not implemented): {partial_count}\033[0m"
)
for test in self.partial_tests:
print(f" - {test}")
# Display skipped tests (both operators not implemented)
if self.skipped_tests:
print(
f"\033[93mSkipped (both operators not implemented): {skipped_count}\033[0m"
)
for test in self.skipped_tests:
print(f" - {test}")
# Display failed tests
if self.failed_tests: if self.failed_tests:
print(f"\n\033[91m{len(self.failed_tests)} tests failed:\033[0m") print(f"\033[91mFailed: {failed_count}\033[0m")
for failure in self.failed_tests: for failure in self.failed_tests:
print(f" - {failure}") print(f" - {failure}")
# Return False only if there are actual test failures
return False return False
else: else:
print("\n\033[92mAll tests passed!\033[0m") # Calculate success rate based on actual executed tests
return True executed_tests = passed_count + partial_count + failed_count
if executed_tests > 0:
success_rate = passed_count / executed_tests * 100
print(f"Success rate: {success_rate:.1f}%")
# If there are skipped or partial tests, show appropriate message
if self.skipped_tests or self.partial_tests:
print(
f"\n\033[93mTests completed with some implementations missing\033[0m"
)
return True # Skipped/partial tests don't count as failures
else:
print(f"\n\033[92mAll tests passed!\033[0m")
return True
class BaseOperatorTest(ABC): class BaseOperatorTest(ABC):
...@@ -282,7 +377,19 @@ class BaseOperatorTest(ABC): ...@@ -282,7 +377,19 @@ class BaseOperatorTest(ABC):
return inputs, kwargs return inputs, kwargs
def run_test(self, device, test_case, config): def run_test(self, device, test_case, config):
"""Unified test execution flow""" """
Unified test execution flow
Args:
device: Device to test on
test_case: Test case configuration
config: Test configuration
Returns:
tuple: (success, status) where:
success: bool indicating if test passed
status: str describing test status ("passed", "skipped", "partial")
"""
device_str = torch_device_map[device] device_str = torch_device_map[device]
# Prepare inputs and kwargs with actual tensors # Prepare inputs and kwargs with actual tensors
...@@ -358,8 +465,8 @@ class BaseOperatorTest(ABC): ...@@ -358,8 +465,8 @@ class BaseOperatorTest(ABC):
# Skip if neither operator is implemented # Skip if neither operator is implemented
if not torch_implemented and not infini_implemented: if not torch_implemented and not infini_implemented:
print(f" Both operators not implemented - test skipped") print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped")
return return False, "skipped"
# Single operator execution without comparison # Single operator execution without comparison
if not torch_implemented or not infini_implemented: if not torch_implemented or not infini_implemented:
...@@ -367,7 +474,7 @@ class BaseOperatorTest(ABC): ...@@ -367,7 +474,7 @@ class BaseOperatorTest(ABC):
"torch_operator" if not torch_implemented else "infinicore_operator" "torch_operator" if not torch_implemented else "infinicore_operator"
) )
print( print(
f" {missing_op} not implemented - running single operator without comparison" f"\033[93m⚠\033[0m {missing_op} not implemented - running single operator without comparison"
) )
if config.bench: if config.bench:
...@@ -383,7 +490,7 @@ class BaseOperatorTest(ABC): ...@@ -383,7 +490,7 @@ class BaseOperatorTest(ABC):
test_case.output_count, test_case.output_count,
comparison_target, comparison_target,
) )
return return False, "partial"
# ========================================================================== # ==========================================================================
# MULTIPLE OUTPUTS COMPARISON LOGIC # MULTIPLE OUTPUTS COMPARISON LOGIC
...@@ -443,7 +550,10 @@ class BaseOperatorTest(ABC): ...@@ -443,7 +550,10 @@ class BaseOperatorTest(ABC):
else: else:
print(f"✅ Output {i} comparison passed") print(f"✅ Output {i} comparison passed")
assert all_valid, f"Multiple outputs comparison failed for {test_case}" if not all_valid:
raise AssertionError(
f"Multiple outputs comparison failed for {test_case}"
)
# ========================================================================== # ==========================================================================
# SINGLE OUTPUT COMPARISON LOGIC # SINGLE OUTPUT COMPARISON LOGIC
...@@ -483,7 +593,8 @@ class BaseOperatorTest(ABC): ...@@ -483,7 +593,8 @@ class BaseOperatorTest(ABC):
) )
is_valid = compare_fn(infini_comparison, torch_comparison) is_valid = compare_fn(infini_comparison, torch_comparison)
assert is_valid, f"Result comparison failed for {test_case}" if not is_valid:
raise AssertionError(f"Result comparison failed for {test_case}")
# ========================================================================== # ==========================================================================
# UNIFIED BENCHMARKING LOGIC # UNIFIED BENCHMARKING LOGIC
...@@ -502,6 +613,9 @@ class BaseOperatorTest(ABC): ...@@ -502,6 +613,9 @@ class BaseOperatorTest(ABC):
comparison_target, comparison_target,
) )
# Test passed successfully
return True, "passed"
def _run_benchmarking( def _run_benchmarking(
self, self,
config, config,
......
...@@ -18,7 +18,11 @@ class GenericTestRunner: ...@@ -18,7 +18,11 @@ class GenericTestRunner:
self.args = get_args() self.args = get_args()
def run(self): def run(self):
"""Execute the complete test suite""" """Execute the complete test suite
Returns:
bool: True if all tests passed or were skipped/partial, False if any tests failed
"""
config = TestConfig( config = TestConfig(
debug=self.args.debug, debug=self.args.debug,
bench=self.args.bench, bench=self.args.bench,
...@@ -29,18 +33,27 @@ class GenericTestRunner: ...@@ -29,18 +33,27 @@ class GenericTestRunner:
runner = TestRunner(self.operator_test.test_cases, config) runner = TestRunner(self.operator_test.test_cases, config)
devices = get_test_devices(self.args) devices = get_test_devices(self.args)
# Run unified tests # Run unified tests - returns True if no tests failed
all_passed = runner.run_tests( # (skipped/partial tests don't count as failures)
has_no_failures = runner.run_tests(
devices, self.operator_test.run_test, self.operator_test.operator_name devices, self.operator_test.run_test, self.operator_test.operator_name
) )
# Print summary # Print summary and get final result
# summary_passed returns True if no tests failed (skipped/partial are OK)
summary_passed = runner.print_summary() summary_passed = runner.print_summary()
all_passed = all_passed and summary_passed
return all_passed # Both conditions must be True for overall success
# - has_no_failures: no test failures during execution
# - summary_passed: summary confirms no failures
return has_no_failures and summary_passed
def run_and_exit(self): def run_and_exit(self):
"""Run tests and exit with appropriate status code""" """Run tests and exit with appropriate status code
Exit codes:
0: All tests passed or were skipped/partial (no failures)
1: One or more tests failed
"""
success = self.run() success = self.run()
sys.exit(0 if success else 1) sys.exit(0 if success else 1)
...@@ -145,16 +145,37 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None): ...@@ -145,16 +145,37 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
result = subprocess.run( result = subprocess.run(
cmd, cmd,
cwd=ops_dir, cwd=ops_dir,
stdout=None, capture_output=True, # Capture output to analyze
stderr=None, text=True,
) )
success = result.returncode == 0 # Analyze output to determine test status
stdout_lower = result.stdout.lower()
stderr_lower = result.stderr.lower()
# Check for operator not implemented patterns
if "not implemented" in stdout_lower or "not implemented" in stderr_lower:
if "both operators not implemented" in stdout_lower:
# Both operators not implemented - skipped test
success = True # Not a failure, but skipped
returncode = -2 # Special code for skipped
elif "one operator not implemented" in stdout_lower:
# One operator not implemented - partial test
success = False # Not fully successful
returncode = -3 # Special code for partial
else:
# General not implemented case
success = result.returncode == 0
returncode = result.returncode
else:
success = result.returncode == 0
returncode = result.returncode
results[test_name] = ( results[test_name] = (
success, success,
result.returncode, returncode,
"", result.stdout,
"", result.stderr,
) )
# Print the output from the test script # Print the output from the test script
...@@ -169,9 +190,22 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None): ...@@ -169,9 +190,22 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
print("\nSTDERR:") print("\nSTDERR:")
print(result.stderr.rstrip()) print(result.stderr.rstrip())
status_icon = "✅" if success else "❌" # Enhanced status display
if returncode == -2:
status_icon = "⏭️"
status_text = "SKIPPED (operators not implemented)"
elif returncode == -3:
status_icon = "⚠️"
status_text = "PARTIAL (one operator not implemented)"
elif success:
status_icon = "✅"
status_text = "PASSED"
else:
status_icon = "❌"
status_text = "FAILED"
print( print(
f"{status_icon} {test_name}: {'PASSED' if success else 'FAILED'} (return code: {result.returncode})" f"{status_icon} {test_name}: {status_text} (return code: {returncode})"
) )
except Exception as e: except Exception as e:
...@@ -191,36 +225,54 @@ def print_summary(results): ...@@ -191,36 +225,54 @@ def print_summary(results):
print("No tests were run.") print("No tests were run.")
return False return False
passed = sum(1 for success, _, _, _ in results.values() if success) # Count different types of results
passed = 0
failed = 0
skipped = 0
partial = 0
for test_name, (success, returncode, stdout, stderr) in results.items():
if success:
passed += 1
elif returncode == -2: # Special code for skipped tests
skipped += 1
elif returncode == -3: # Special code for partial tests
partial += 1
else:
failed += 1
total = len(results) total = len(results)
failed_tests = [name for name, (success, _, _, _) in results.items() if not success]
print(f"Total tests: {total}") print(f"Total tests: {total}")
print(f"Passed: {passed}") print(f"Passed: {passed}")
print(f"Failed: {total - passed}") print(f"Failed: {failed}")
if total > 0: if skipped > 0:
success_rate = passed / total * 100 print(f"Skipped (operators not implemented): {skipped}")
print(f"Success rate: {success_rate:.1f}%")
if not failed_tests: if partial > 0:
print("\n🎉 All tests passed!") print(f"Partial (one operator not implemented): {partial}")
if total > 0:
# Calculate success rate based on executed tests only
executed_tests = passed + failed + partial
if executed_tests > 0:
success_rate = passed / executed_tests * 100
print(f"Success rate: {success_rate:.1f}%")
if failed == 0:
if skipped > 0 or partial > 0:
print(f"\n⚠️ Tests completed with some operators not implemented")
print(f" - {skipped} tests skipped (both operators not implemented)")
print(f" - {partial} tests partial (one operator not implemented)")
else:
print(f"\n🎉 All tests passed!")
return True return True
else: else:
print(f"\n{len(failed_tests)} tests failed:") print(f"\n{failed} tests failed:")
for test_name in failed_tests: for test_name, (success, returncode, stdout, stderr) in results.items():
success, returncode, stdout, stderr = results[test_name] if not success and returncode not in [-2, -3]: # Not skipped or partial
print(f" - {test_name} (return code: {returncode})") print(f" - {test_name} (return code: {returncode})")
# Print brief error info for failed tests
if stderr:
error_lines = stderr.strip().split("\n")
if error_lines:
# Take first meaningful error line
for line in error_lines:
if line.strip() and not line.startswith("Warning:"):
print(f" Error: {line.strip()}")
break
return False return False
...@@ -406,6 +458,16 @@ def main(): ...@@ -406,6 +458,16 @@ def main():
# Print summary and exit with appropriate code # Print summary and exit with appropriate code
all_passed = print_summary(results) all_passed = print_summary(results)
# Check if there were any tests with missing implementations
has_missing_implementations = any(
returncode in [-2, -3] for _, (_, returncode, _, _) in results.items()
)
if all_passed and has_missing_implementations:
print(f"\n⚠️ Note: Some operators are not fully implemented")
print(f" Run individual tests for details on missing implementations")
sys.exit(0 if all_passed else 1) sys.exit(0 if all_passed else 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment