Commit 7542c51d authored by wooway777's avatar wooway777 Committed by MaYuhang
Browse files

issue/573 - distinguishing unimplemented cases

parent 3d139316
......@@ -176,6 +176,9 @@ class TestConfig:
self.num_iterations = num_iterations
# In base.py - update the TestRunner class
class TestRunner:
"""Test runner"""
......@@ -183,8 +186,24 @@ class TestRunner:
self.test_cases = test_cases
self.config = test_config
self.failed_tests = []
self.skipped_tests = [] # Track skipped tests (both operators not implemented)
self.partial_tests = [] # Track partial tests (one operator not implemented)
self.passed_tests = (
[]
) # Track passed tests (both operators implemented and passed)
def run_tests(self, devices, test_func, test_type="Test"):
"""
Run tests on specified devices
Args:
devices: List of devices to test on
test_func: Test function to execute
test_type: Type of test for display purposes
Returns:
bool: True if no tests failed, False otherwise
"""
for device in devices:
print(f"\n{'='*60}")
print(f"Testing {test_type} on {InfiniDeviceNames[device]}")
......@@ -194,26 +213,102 @@ class TestRunner:
try:
print(f"{test_case}")
test_func(device, test_case, self.config)
print(f"\033[92m✓\033[0m Passed")
# Execute test and get result status
success, status = test_func(device, test_case, self.config)
# Handle different test statuses
if status == "passed":
self.passed_tests.append(
f"{test_case} - {InfiniDeviceNames[device]}"
)
print(f"\033[92m✓\033[0m Passed")
elif status == "skipped":
# Test was skipped due to both operators not being implemented
skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented"
self.skipped_tests.append(skip_msg)
print(
f"\033[93m⚠\033[0m Skipped - both operators not implemented"
)
elif status == "partial":
# Test was partially executed (one operator not implemented)
partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented"
self.partial_tests.append(partial_msg)
print(
f"\033[93m⚠\033[0m Partial - one operator not implemented"
)
# Failed tests are handled in the exception handler below
except Exception as e:
error_msg = f"Error: {e}"
error_msg = (
f"{test_case} - {InfiniDeviceNames[device]} - Error: {e}"
)
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
# Return True if no tests failed (skipped/partial tests don't count as failures)
return len(self.failed_tests) == 0
def print_summary(self):
"""
Print test execution summary
Returns:
bool: True if no tests failed, False otherwise
"""
total_tests = len(self.test_cases)
passed_count = len(self.passed_tests)
skipped_count = len(self.skipped_tests)
partial_count = len(self.partial_tests)
failed_count = len(self.failed_tests)
print(f"\n{'='*60}")
print("TEST SUMMARY")
print(f"{'='*60}")
print(f"Total tests: {total_tests}")
print(f"\033[92mPassed: {passed_count}\033[0m")
# Display partial tests (one operator not implemented)
if self.partial_tests:
print(
f"\033[93mPartial (one operator not implemented): {partial_count}\033[0m"
)
for test in self.partial_tests:
print(f" - {test}")
# Display skipped tests (both operators not implemented)
if self.skipped_tests:
print(
f"\033[93mSkipped (both operators not implemented): {skipped_count}\033[0m"
)
for test in self.skipped_tests:
print(f" - {test}")
# Display failed tests
if self.failed_tests:
print(f"\n\033[91m{len(self.failed_tests)} tests failed:\033[0m")
print(f"\033[91mFailed: {failed_count}\033[0m")
for failure in self.failed_tests:
print(f" - {failure}")
# Return False only if there are actual test failures
return False
else:
print("\n\033[92mAll tests passed!\033[0m")
return True
# Calculate success rate based on actual executed tests
executed_tests = passed_count + partial_count + failed_count
if executed_tests > 0:
success_rate = passed_count / executed_tests * 100
print(f"Success rate: {success_rate:.1f}%")
# If there are skipped or partial tests, show appropriate message
if self.skipped_tests or self.partial_tests:
print(
f"\n\033[93mTests completed with some implementations missing\033[0m"
)
return True # Skipped/partial tests don't count as failures
else:
print(f"\n\033[92mAll tests passed!\033[0m")
return True
class BaseOperatorTest(ABC):
......@@ -282,7 +377,19 @@ class BaseOperatorTest(ABC):
return inputs, kwargs
def run_test(self, device, test_case, config):
"""Unified test execution flow"""
"""
Unified test execution flow
Args:
device: Device to test on
test_case: Test case configuration
config: Test configuration
Returns:
tuple: (success, status) where:
success: bool indicating if test passed
status: str describing test status ("passed", "skipped", "partial")
"""
device_str = torch_device_map[device]
# Prepare inputs and kwargs with actual tensors
......@@ -358,8 +465,8 @@ class BaseOperatorTest(ABC):
# Skip if neither operator is implemented
if not torch_implemented and not infini_implemented:
print(f" Both operators not implemented - test skipped")
return
print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped")
return False, "skipped"
# Single operator execution without comparison
if not torch_implemented or not infini_implemented:
......@@ -367,7 +474,7 @@ class BaseOperatorTest(ABC):
"torch_operator" if not torch_implemented else "infinicore_operator"
)
print(
f" {missing_op} not implemented - running single operator without comparison"
f"\033[93m⚠\033[0m {missing_op} not implemented - running single operator without comparison"
)
if config.bench:
......@@ -383,7 +490,7 @@ class BaseOperatorTest(ABC):
test_case.output_count,
comparison_target,
)
return
return False, "partial"
# ==========================================================================
# MULTIPLE OUTPUTS COMPARISON LOGIC
......@@ -443,7 +550,10 @@ class BaseOperatorTest(ABC):
else:
print(f"✅ Output {i} comparison passed")
assert all_valid, f"Multiple outputs comparison failed for {test_case}"
if not all_valid:
raise AssertionError(
f"Multiple outputs comparison failed for {test_case}"
)
# ==========================================================================
# SINGLE OUTPUT COMPARISON LOGIC
......@@ -483,7 +593,8 @@ class BaseOperatorTest(ABC):
)
is_valid = compare_fn(infini_comparison, torch_comparison)
assert is_valid, f"Result comparison failed for {test_case}"
if not is_valid:
raise AssertionError(f"Result comparison failed for {test_case}")
# ==========================================================================
# UNIFIED BENCHMARKING LOGIC
......@@ -502,6 +613,9 @@ class BaseOperatorTest(ABC):
comparison_target,
)
# Test passed successfully
return True, "passed"
def _run_benchmarking(
self,
config,
......
......@@ -18,7 +18,11 @@ class GenericTestRunner:
self.args = get_args()
def run(self):
"""Execute the complete test suite"""
"""Execute the complete test suite
Returns:
bool: True if all tests passed or were skipped/partial, False if any tests failed
"""
config = TestConfig(
debug=self.args.debug,
bench=self.args.bench,
......@@ -29,18 +33,27 @@ class GenericTestRunner:
runner = TestRunner(self.operator_test.test_cases, config)
devices = get_test_devices(self.args)
# Run unified tests
all_passed = runner.run_tests(
# Run unified tests - returns True if no tests failed
# (skipped/partial tests don't count as failures)
has_no_failures = runner.run_tests(
devices, self.operator_test.run_test, self.operator_test.operator_name
)
# Print summary
# Print summary and get final result
# summary_passed returns True if no tests failed (skipped/partial are OK)
summary_passed = runner.print_summary()
all_passed = all_passed and summary_passed
return all_passed
# Both conditions must be True for overall success
# - has_no_failures: no test failures during execution
# - summary_passed: summary confirms no failures
return has_no_failures and summary_passed
def run_and_exit(self):
"""Run tests and exit with appropriate status code"""
"""Run tests and exit with appropriate status code
Exit codes:
0: All tests passed or were skipped/partial (no failures)
1: One or more tests failed
"""
success = self.run()
sys.exit(0 if success else 1)
......@@ -145,16 +145,37 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
result = subprocess.run(
cmd,
cwd=ops_dir,
stdout=None,
stderr=None,
capture_output=True, # Capture output to analyze
text=True,
)
success = result.returncode == 0
# Analyze output to determine test status
stdout_lower = result.stdout.lower()
stderr_lower = result.stderr.lower()
# Check for operator not implemented patterns
if "not implemented" in stdout_lower or "not implemented" in stderr_lower:
if "both operators not implemented" in stdout_lower:
# Both operators not implemented - skipped test
success = True # Not a failure, but skipped
returncode = -2 # Special code for skipped
elif "one operator not implemented" in stdout_lower:
# One operator not implemented - partial test
success = False # Not fully successful
returncode = -3 # Special code for partial
else:
# General not implemented case
success = result.returncode == 0
returncode = result.returncode
else:
success = result.returncode == 0
returncode = result.returncode
results[test_name] = (
success,
result.returncode,
"",
"",
returncode,
result.stdout,
result.stderr,
)
# Print the output from the test script
......@@ -169,9 +190,22 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, extra_args=None):
print("\nSTDERR:")
print(result.stderr.rstrip())
status_icon = "✅" if success else "❌"
# Enhanced status display
if returncode == -2:
status_icon = "⏭️"
status_text = "SKIPPED (operators not implemented)"
elif returncode == -3:
status_icon = "⚠️"
status_text = "PARTIAL (one operator not implemented)"
elif success:
status_icon = "✅"
status_text = "PASSED"
else:
status_icon = "❌"
status_text = "FAILED"
print(
f"{status_icon} {test_name}: {'PASSED' if success else 'FAILED'} (return code: {result.returncode})"
f"{status_icon} {test_name}: {status_text} (return code: {returncode})"
)
except Exception as e:
......@@ -191,36 +225,54 @@ def print_summary(results):
print("No tests were run.")
return False
passed = sum(1 for success, _, _, _ in results.values() if success)
# Count different types of results
passed = 0
failed = 0
skipped = 0
partial = 0
for test_name, (success, returncode, stdout, stderr) in results.items():
if success:
passed += 1
elif returncode == -2: # Special code for skipped tests
skipped += 1
elif returncode == -3: # Special code for partial tests
partial += 1
else:
failed += 1
total = len(results)
failed_tests = [name for name, (success, _, _, _) in results.items() if not success]
print(f"Total tests: {total}")
print(f"Passed: {passed}")
print(f"Failed: {total - passed}")
print(f"Failed: {failed}")
if total > 0:
success_rate = passed / total * 100
print(f"Success rate: {success_rate:.1f}%")
if skipped > 0:
print(f"Skipped (operators not implemented): {skipped}")
if not failed_tests:
print("\n🎉 All tests passed!")
if partial > 0:
print(f"Partial (one operator not implemented): {partial}")
if total > 0:
# Calculate success rate based on executed tests only
executed_tests = passed + failed + partial
if executed_tests > 0:
success_rate = passed / executed_tests * 100
print(f"Success rate: {success_rate:.1f}%")
if failed == 0:
if skipped > 0 or partial > 0:
print(f"\n⚠️ Tests completed with some operators not implemented")
print(f" - {skipped} tests skipped (both operators not implemented)")
print(f" - {partial} tests partial (one operator not implemented)")
else:
print(f"\n🎉 All tests passed!")
return True
else:
print(f"\n{len(failed_tests)} tests failed:")
for test_name in failed_tests:
success, returncode, stdout, stderr = results[test_name]
print(f" - {test_name} (return code: {returncode})")
# Print brief error info for failed tests
if stderr:
error_lines = stderr.strip().split("\n")
if error_lines:
# Take first meaningful error line
for line in error_lines:
if line.strip() and not line.startswith("Warning:"):
print(f" Error: {line.strip()}")
break
print(f"\n{failed} tests failed:")
for test_name, (success, returncode, stdout, stderr) in results.items():
if not success and returncode not in [-2, -3]: # Not skipped or partial
print(f" - {test_name} (return code: {returncode})")
return False
......@@ -406,6 +458,16 @@ def main():
# Print summary and exit with appropriate code
all_passed = print_summary(results)
# Check if there were any tests with missing implementations
has_missing_implementations = any(
returncode in [-2, -3] for _, (_, returncode, _, _) in results.items()
)
if all_passed and has_missing_implementations:
print(f"\n⚠️ Note: Some operators are not fully implemented")
print(f" Run individual tests for details on missing implementations")
sys.exit(0 if all_passed else 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment