Unverified Commit 79b70e58 authored by thatPepe, committed by GitHub

Issue/571 DeviceEvent (#583)



* issue/571 - introduced the DeviceEvent feature

---------
Co-authored-by: Jiacheng Huang <huangjiacheng0709@outlook.com>
parent d4738a98
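
The commit adds a torch.cuda.Event-style DeviceEvent to the Python API, backed by new runtime entry points (eventCreateWithFlags, eventElapsedTime), and reworks the test framework's benchmarking around it. A minimal usage sketch, inferred from the tests added in this commit:

import infinicore

start = infinicore.DeviceEvent(enable_timing=True)
end = infinicore.DeviceEvent(enable_timing=True)

start.record()                   # enqueue a start marker on the current stream
# ... device work ...
end.record()                     # enqueue an end marker
end.synchronize()                # block until the end marker completes
print(start.elapsed_time(end))   # elapsed device time in milliseconds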
@@ -55,6 +55,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
     return INFINI_STATUS_SUCCESS;
 }
 
+infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) {
+    return INFINI_STATUS_NOT_IMPLEMENTED;
+}
+
 infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
     CHECK_KUNLUNRT(xpu_event_record((kunlunEvent_t)event, (kunlunStream_t)stream));
     return INFINI_STATUS_SUCCESS;
@@ -75,6 +79,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) {
     return INFINI_STATUS_SUCCESS;
 }
 
+infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
+    return INFINI_STATUS_NOT_IMPLEMENTED;
+}
+
 infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
     CHECK_KUNLUNRT(xpu_malloc(p_ptr, static_cast<uint64_t>(size)));
     return INFINI_STATUS_SUCCESS;
...
@@ -50,6 +50,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
     return INFINI_STATUS_SUCCESS;
 }
 
+infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) {
+    return INFINI_STATUS_NOT_IMPLEMENTED;
+}
+
 infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
     CHECK_MACART(hcEventRecord((hcEvent_t)event, (hcStream_t)stream));
     return INFINI_STATUS_SUCCESS;
@@ -70,6 +74,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) {
     return INFINI_STATUS_SUCCESS;
 }
 
+infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
+    return INFINI_STATUS_NOT_IMPLEMENTED;
+}
+
 infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
     CHECK_MACART(hcMalloc(p_ptr, size));
     return INFINI_STATUS_SUCCESS;
...
@@ -50,6 +50,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
     return INFINI_STATUS_SUCCESS;
 }
 
+infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) {
+    return INFINI_STATUS_NOT_IMPLEMENTED;
+}
+
 infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
     CHECK_MUSART(musaEventRecord((musaEvent_t)event, (musaStream_t)stream));
     return INFINI_STATUS_SUCCESS;
@@ -77,6 +81,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) {
     return INFINI_STATUS_SUCCESS;
 }
 
+infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
+    return INFINI_STATUS_NOT_IMPLEMENTED;
+}
+
 infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
     CHECK_MUSART(musaMalloc(p_ptr, size));
     return INFINI_STATUS_SUCCESS;
import infinicore
import torch
def test_device_event_timing():
"""Test DeviceEvent for timing operations - using instance method API"""
print("\nTesting DeviceEvent timing...")
# Create events
start_event = infinicore.DeviceEvent(enable_timing=True)
end_event = infinicore.DeviceEvent(enable_timing=True)
# Create test tensors
shape = [1000, 1000]
device = infinicore.device("cuda", 0)
# Time tensor creation and operations
start_event.record()
# Perform some operations
t1 = infinicore.ones(shape, dtype=infinicore.float32, device=device)
t2 = infinicore.zeros(shape, dtype=infinicore.float32, device=device)
# Simulate some computation by multiple operations
for _ in range(10):
t1 = t1.permute([1, 0])
t2 = t2.permute([1, 0])
end_event.record()
# Wait for operations to complete
end_event.synchronize()
# Calculate elapsed time - USING INSTANCE METHOD (torch-compatible)
elapsed_time = start_event.elapsed_time(end_event)
print(f"✓ DeviceEvent timing test passed - Elapsed time: {elapsed_time:.3f} ms")
assert elapsed_time >= 0, "Elapsed time should be non-negative"
return elapsed_time
def test_device_event_query():
"""Test DeviceEvent query functionality"""
print("\nTesting DeviceEvent query...")
event = infinicore.DeviceEvent(enable_timing=True)
# Event should not be completed before recording
assert not event.is_recorded, "Event should not be recorded initially"
# Record the event
event.record()
assert event.is_recorded, "Event should be recorded after record()"
# Query completion (might be immediate for simple cases)
completed = event.query()
print(f"✓ DeviceEvent query test passed - Event completed: {completed}")
# Ensure synchronization works
event.synchronize()
assert event.query(), "Event should be completed after synchronize()"
def test_multiple_devices():
"""Test operations across multiple devices"""
print("\nTesting multiple devices...")
cuda_count = infinicore.get_device_count("cuda")
if cuda_count > 1:
# Test operations on different devices
shape = [100, 100]
# Create events for timing
event0_start = infinicore.DeviceEvent(
device=infinicore.device("cuda", 0), enable_timing=True
)
event0_end = infinicore.DeviceEvent(
device=infinicore.device("cuda", 0), enable_timing=True
)
event1_start = infinicore.DeviceEvent(
device=infinicore.device("cuda", 1), enable_timing=True
)
event1_end = infinicore.DeviceEvent(
device=infinicore.device("cuda", 1), enable_timing=True
)
# Create tensors on different devices
event0_start.record()
t_device0 = infinicore.ones(
shape, dtype=infinicore.float32, device=infinicore.device("cuda", 0)
)
event0_end.record()
event1_start.record()
t_device1 = infinicore.zeros(
shape, dtype=infinicore.float32, device=infinicore.device("cuda", 1)
)
event1_end.record()
# Synchronize both devices
event0_end.synchronize()
event1_end.synchronize()
# Calculate elapsed times
time_device0 = event0_start.elapsed_time(event0_end)
time_device1 = event1_start.elapsed_time(event1_end)
print(f"✓ Multiple devices test passed")
print(f" Device 0 tensor creation time: {time_device0:.3f} ms")
print(f" Device 1 tensor creation time: {time_device1:.3f} ms")
# Test operations timing
event0_start.record()
for _ in range(20):
t_device0 = t_device0.permute([1, 0])
event0_end.record()
event1_start.record()
for _ in range(20):
t_device1 = t_device1.permute([1, 0])
event1_end.record()
# Synchronize again
event0_end.synchronize()
event1_end.synchronize()
# Calculate operation times
op_time_device0 = event0_start.elapsed_time(event0_end)
op_time_device1 = event1_start.elapsed_time(event1_end)
print(f" Device 0 operations time: {op_time_device0:.3f} ms")
print(f" Device 1 operations time: {op_time_device1:.3f} ms")
# Test cross-device operations if supported
try:
# Try to create an event that measures cross-device operations
cross_start = infinicore.DeviceEvent(device=infinicore.device("cuda", 0))
cross_end = infinicore.DeviceEvent(device=infinicore.device("cuda", 0))
cross_start.record()
# Perform operations on both devices
for _ in range(10):
t_device0 = t_device0.permute([1, 0])
# Note: Actual cross-device operations would require explicit synchronization
cross_end.record()
cross_end.synchronize()
cross_time = cross_start.elapsed_time(cross_end)
print(f" Cross-device operations time: {cross_time:.3f} ms")
except Exception as e:
print(f" Cross-device timing skipped: {e}")
else:
print("⚠ Skipping multiple devices test (only 1 CUDA device available)")
def test_event_stream():
"""Test DeviceEvent with different streams"""
print("\nTesting DeviceEvent with streams...")
try:
# Get default stream
default_stream = None
if hasattr(infinicore, "get_stream"):
default_stream = infinicore.get_stream()
else:
print("⚠ infinicore.get_stream() not available, using default stream")
# Create event and record
event = infinicore.DeviceEvent(enable_timing=True)
if default_stream is not None:
event.record(stream=default_stream)
else:
event.record()
event.synchronize()
print("✓ DeviceEvent stream test passed")
except Exception as e:
print(f"⚠ DeviceEvent stream test skipped: {e}")
def test_concurrent_events():
"""Test multiple concurrent events"""
print("\nTesting concurrent events...")
# Create multiple events
events = []
for i in range(5):
events.append(infinicore.DeviceEvent(enable_timing=True))
    # Record events, interleaved with small operations on the device
for i, event in enumerate(events):
event.record()
# Small operation
temp = infinicore.ones(
[10, 10], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
)
temp = temp.permute([1, 0])
# Synchronize all events
for event in events:
event.synchronize()
assert event.query(), "All events should be completed"
print("✓ Concurrent events test passed")
def test_torch_style_usage():
"""Test that our API matches torch.cuda.Event usage pattern"""
print("\nTesting torch.cuda.Event style usage...")
# This should work exactly like torch.cuda.Event
start = infinicore.DeviceEvent(enable_timing=True)
end = infinicore.DeviceEvent(enable_timing=True)
# Record events
start.record()
# Some operations
tensor = infinicore.ones(
[100, 100], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
)
for _ in range(5):
tensor = tensor.permute([1, 0])
end.record()
end.synchronize()
# This is the torch-compatible API
time_taken = start.elapsed_time(end)
print(f"✓ Torch-style usage test passed - Time: {time_taken:.3f} ms")
def test_event_synchronization():
"""Test event synchronization behavior"""
print("\nTesting event synchronization...")
event1 = infinicore.DeviceEvent(enable_timing=True)
event2 = infinicore.DeviceEvent(enable_timing=True)
# Record events in sequence
event1.record()
# Some work
temp = infinicore.zeros(
[50, 50], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
)
event2.record()
# event2 should complete after event1
event2.synchronize()
assert event2.query(), "event2 should be completed"
assert event1.query(), "event1 should also be completed after event2 sync"
print("✓ Event synchronization test passed")
def test_event_wait_functionality():
"""Test the wait functionality of DeviceEvent"""
print("\nTesting DeviceEvent wait functionality...")
# Create events
event1 = infinicore.DeviceEvent(enable_timing=True)
event2 = infinicore.DeviceEvent(enable_timing=True)
# Record first event
event1.record()
# Perform some work
tensor1 = infinicore.ones(
[500, 500], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
)
for _ in range(10):
tensor1 = tensor1.permute([1, 0])
# Record second event
event2.record()
# Make event2 wait for event1 using wait() method
event2.wait()
# Both events should be completed now
assert event1.query(), "event1 should be completed"
assert event2.query(), "event2 should be completed after waiting"
print("✓ Event wait functionality test passed")
def test_stream_wait_event():
"""Test stream waiting for events"""
print("\nTesting stream wait event functionality...")
try:
# Get the current stream
current_stream = infinicore.get_stream()
# Create events
dependency_event = infinicore.DeviceEvent(enable_timing=True)
dependent_event = infinicore.DeviceEvent(enable_timing=True)
# Record dependency event
dependency_event.record()
# Perform some work that creates a dependency
tensor = infinicore.ones(
[300, 300], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
)
for _ in range(5):
tensor = tensor.permute([1, 0])
# Make the stream wait for the dependency event before recording dependent event
dependency_event.wait(current_stream)
# Record dependent event after the wait
dependent_event.record()
# Synchronize and verify
dependent_event.synchronize()
assert dependency_event.query(), "Dependency event should be completed"
assert dependent_event.query(), "Dependent event should be completed"
print("✓ Stream wait event test passed")
except Exception as e:
print(f"⚠ Stream wait event test skipped: {e}")
def test_multiple_stream_synchronization():
"""Test event-based synchronization between multiple streams"""
print("\nTesting multiple stream synchronization...")
try:
# This test simulates a producer-consumer pattern using events
producer_event = infinicore.DeviceEvent(enable_timing=True)
consumer_event = infinicore.DeviceEvent(enable_timing=True)
# Producer work
producer_event.record()
# Simulate producer work (data generation)
data_tensor = infinicore.ones(
[200, 200], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
)
for _ in range(8):
data_tensor = data_tensor.permute([1, 0])
# Make consumer wait for producer to finish
producer_event.wait() # Wait on current stream
# Consumer work (depends on producer's output)
processed_tensor = data_tensor.permute([1, 0]) # Consumer operation
consumer_event.record()
# Verify the synchronization worked
consumer_event.synchronize()
assert producer_event.query(), "Producer event should be completed"
assert consumer_event.query(), "Consumer event should be completed"
print("✓ Multiple stream synchronization test passed")
except Exception as e:
print(f"⚠ Multiple stream synchronization test skipped: {e}")
def test_event_wait_with_specific_stream():
"""Test waiting on specific streams"""
print("\nTesting event wait with specific streams...")
try:
# Get current stream
main_stream = infinicore.get_stream()
# Create events
compute_event = infinicore.DeviceEvent(enable_timing=True)
transfer_event = infinicore.DeviceEvent(enable_timing=True)
# Record compute event after some computation
compute_event.record()
# Simulate computation
compute_tensor = infinicore.ones(
[150, 150], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
)
for _ in range(6):
compute_tensor = compute_tensor.permute([1, 0])
# Make data transfer wait for computation to complete
compute_event.wait(main_stream)
# Record transfer event
transfer_event.record()
# Verify synchronization
transfer_event.synchronize()
assert compute_event.query(), "Compute event should be completed"
assert transfer_event.query(), "Transfer event should be completed"
print("✓ Event wait with specific stream test passed")
except Exception as e:
print(f"⚠ Event wait with specific stream test skipped: {e}")
def test_complex_dependency_chain():
"""Test complex dependency chains using events"""
print("\nTesting complex dependency chains...")
try:
# Create multiple events for a dependency chain
event_a = infinicore.DeviceEvent(enable_timing=True)
event_b = infinicore.DeviceEvent(enable_timing=True)
event_c = infinicore.DeviceEvent(enable_timing=True)
event_d = infinicore.DeviceEvent(enable_timing=True)
# Stage A
event_a.record()
tensor_a = infinicore.ones(
[100, 100], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
)
for _ in range(3):
tensor_a = tensor_a.permute([1, 0])
# Stage B depends on A
event_a.wait()
event_b.record()
tensor_b = tensor_a.permute([1, 0]) # Depends on tensor_a
for _ in range(3):
tensor_b = tensor_b.permute([1, 0])
# Stage C depends on B
event_b.wait()
event_c.record()
tensor_c = tensor_b.permute([1, 0]) # Depends on tensor_b
for _ in range(3):
tensor_c = tensor_c.permute([1, 0])
# Stage D depends on C
event_c.wait()
event_d.record()
tensor_d = tensor_c.permute([1, 0]) # Depends on tensor_c
# Final synchronization
event_d.synchronize()
# Verify all events completed in order
assert event_a.query(), "Event A should be completed"
assert event_b.query(), "Event B should be completed"
assert event_c.query(), "Event C should be completed"
assert event_d.query(), "Event D should be completed"
print("✓ Complex dependency chain test passed")
except Exception as e:
print(f"⚠ Complex dependency chain test skipped: {e}")
def test_wait_before_record():
"""Test waiting for an event that hasn't been recorded yet"""
print("\nTesting wait before record behavior...")
try:
event = infinicore.DeviceEvent(enable_timing=True)
        # This should not crash, but the exact behavior depends on the underlying
        # runtime. CUDA, for example, treats synchronizing or waiting on an event
        # that has never been recorded as an immediate no-op; other runtimes may
        # differ. We are testing that the API handles this case gracefully.
event.wait()
print(
"✓ Wait before record test completed (behavior may vary by implementation)"
)
except Exception as e:
print(f"⚠ Wait before record test encountered expected behavior: {e}")
def run_all_tests():
"""Run all device-related tests"""
print("Starting DeviceEvent and device tests...")
print("=" * 50)
try:
# Basic functionality tests
test_device_event_timing()
test_device_event_query()
test_torch_style_usage()
test_event_synchronization()
test_concurrent_events()
# Wait functionality tests (new)
test_event_wait_functionality()
test_stream_wait_event()
test_multiple_stream_synchronization()
test_event_wait_with_specific_stream()
test_complex_dependency_chain()
test_wait_before_record()
# Optional tests (may depend on system capabilities)
test_multiple_devices()
test_event_flags()
test_event_stream()
print("\n" + "=" * 50)
print("🎉 All tests passed successfully!")
print("DeviceEvent wait functionality is working correctly!")
except Exception as e:
print(f"\n❌ Test failed: {e}")
import traceback
traceback.print_exc()
raise
if __name__ == "__main__":
run_all_tests()
 from .base import TestConfig, TestRunner, TestCase, BaseOperatorTest
+from .benchmark import BenchmarkUtils, BenchmarkResult
+from .config import (
+    get_args,
+    get_hardware_args_group,
+    get_test_devices,
+)
+from .datatypes import to_torch_dtype, to_infinicore_dtype
+from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
+from .runner import GenericTestRunner
 from .tensor import TensorSpec, TensorInitializer
 from .utils import (
     compare_results,
@@ -6,21 +15,12 @@ from .utils import (
     debug,
     get_tolerance,
     infinicore_tensor_from_torch,
-    profile_operation,
     rearrange_tensor,
     convert_infinicore_to_torch,
     is_integer_dtype,
     is_complex_dtype,
     is_floating_dtype,
 )
-from .config import (
-    get_args,
-    get_hardware_args_group,
-    get_test_devices,
-)
-from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
-from .datatypes import to_torch_dtype, to_infinicore_dtype
-from .runner import GenericTestRunner
 
 __all__ = [
     # Core types and classes
@@ -43,7 +43,6 @@ __all__ = [
     "get_test_devices",
     "get_tolerance",
     "infinicore_tensor_from_torch",
-    "profile_operation",
     "rearrange_tensor",
     # Utility functions
     "to_infinicore_dtype",
@@ -53,4 +52,7 @@ __all__ = [
     "is_integer_dtype",
     "is_complex_dtype",
     "is_floating_dtype",
+    # Benchmarking utilities
+    "BenchmarkUtils",
+    "BenchmarkResult",
 ]
@@ -11,8 +11,8 @@ from .tensor import TensorSpec, TensorInitializer
 from .utils import (
     create_test_comparator,
     infinicore_tensor_from_torch,
-    profile_operation,
 )
+from .benchmark import BenchmarkUtils
 
 
 @dataclass
@@ -21,8 +21,10 @@ class TestResult:
     success: bool
     return_code: int  # 0: success, -1: failure, -2: skipped, -3: partial
-    torch_time: float = 0.0
-    infini_time: float = 0.0
+    torch_host_time: float = 0.0
+    torch_device_time: float = 0.0
+    infini_host_time: float = 0.0
+    infini_device_time: float = 0.0
     error_message: str = ""
     test_case: Any = None
     device: Any = None
@@ -202,8 +204,10 @@
         )  # Track passed tests (both operators implemented and passed)
         # Add benchmark timing statistics
         self.benchmark_times = {
-            "torch_total": 0.0,
-            "infinicore_total": 0.0,
+            "torch_host_total": 0.0,
+            "torch_device_total": 0.0,
+            "infinicore_host_total": 0.0,
+            "infinicore_device_total": 0.0,
             "per_test_case": {},  # Store timing per test case
         }
         # Store test results
@@ -329,8 +333,10 @@
         # Print benchmark summary if benchmarking was enabled
         if self.config.bench and (
-            self.benchmark_times["torch_total"] > 0
-            or self.benchmark_times["infinicore_total"] > 0
+            self.benchmark_times["torch_host_total"] > 0
+            or self.benchmark_times["torch_device_total"] > 0
+            or self.benchmark_times["infinicore_host_total"] > 0
+            or self.benchmark_times["infinicore_device_total"] > 0
         ):
             self._print_benchmark_summary()
@@ -342,19 +348,28 @@
         print(f"{'-'*60}")
         print("BENCHMARK SUMMARY")
-        torch_total = self.benchmark_times["torch_total"]
-        infinicore_total = self.benchmark_times["infinicore_total"]
-
-        if torch_total > 0:
-            print(f"PyTorch Total Time: {torch_total * 1000:.3f} ms")
-        if infinicore_total > 0:
-            print(f"InfiniCore Total Time: {infinicore_total * 1000:.3f} ms")
-
-        if torch_total > 0 and infinicore_total > 0:
-            speedup = (
-                torch_total / infinicore_total if infinicore_total > 0 else float("inf")
-            )
-            print(f"Speedup (PyTorch/InfiniCore): {speedup:.2f}x")
+        torch_host_total = self.benchmark_times["torch_host_total"]
+        torch_device_total = self.benchmark_times["torch_device_total"]
+        infinicore_host_total = self.benchmark_times["infinicore_host_total"]
+        infinicore_device_total = self.benchmark_times["infinicore_device_total"]
+
+        if torch_host_total > 0:
+            print(f"PyTorch Host Total Time: {torch_host_total:.3f} ms")
+        if torch_device_total > 0:
+            print(f"PyTorch Device Total Time: {torch_device_total:.3f} ms")
+        if infinicore_host_total > 0:
+            print(f"InfiniCore Host Total Time: {infinicore_host_total:.3f} ms")
+        if infinicore_device_total > 0:
+            print(f"InfiniCore Device Total Time: {infinicore_device_total:.3f} ms")
+
+        # Calculate speedups
+        if torch_host_total > 0 and infinicore_host_total > 0:
+            host_speedup = torch_host_total / infinicore_host_total
+            print(f"Host Speedup (PyTorch/InfiniCore): {host_speedup:.2f}x")
+        if torch_device_total > 0 and infinicore_device_total > 0:
+            device_speedup = torch_device_total / infinicore_device_total
+            print(f"Device Speedup (PyTorch/InfiniCore): {device_speedup:.2f}x")
 
     def get_test_results(self):
         """Get all test results"""
@@ -593,20 +608,27 @@ class BaseOperatorTest(ABC):
                 test_result.return_code = -3  # Partial
                 # Run benchmarking for partial tests if enabled
                 if config.bench:
-                    torch_time, infini_time = self._run_benchmarking(
-                        config,
-                        device_str,
-                        torch_implemented,
-                        infini_implemented,
-                        inputs,
-                        kwargs,
-                        infini_inputs,
-                        infini_kwargs,
-                        test_case.output_count,
-                        comparison_target,
+                    torch_host, torch_device, infini_host, infini_device = (
+                        BenchmarkUtils.run_benchmarking(
+                            config,
+                            device_str,
+                            torch_implemented,
+                            infini_implemented,
+                            self.torch_operator,
+                            self.infinicore_operator,
+                            inputs,
+                            kwargs,
+                            infini_inputs,
+                            infini_kwargs,
+                            test_case.output_count,
+                            comparison_target,
+                            bench_mode=config.bench,
+                        )
                     )
-                    test_result.torch_time = torch_time
-                    test_result.infini_time = infini_time
+                    test_result.torch_host_time = torch_host
+                    test_result.torch_device_time = torch_device
+                    test_result.infini_host_time = infini_host
+                    test_result.infini_device_time = infini_device
                 return test_result
 
     # ==========================================================================
     # MULTIPLE OUTPUTS COMPARISON LOGIC
@@ -716,109 +738,43 @@ class BaseOperatorTest(ABC):
         # UNIFIED BENCHMARKING LOGIC
         # ==========================================================================
         if config.bench:
-            torch_time, infini_time = self._run_benchmarking(
-                config,
-                device_str,
-                True,
-                True,
-                inputs,
-                kwargs,
-                infini_inputs,
-                infini_kwargs,
-                test_case.output_count,
-                comparison_target,
+            torch_host, torch_device, infini_host, infini_device = (
+                BenchmarkUtils.run_benchmarking(
+                    config,
+                    device_str,
+                    True,
+                    True,
+                    self.torch_operator,
+                    self.infinicore_operator,
+                    inputs,
+                    kwargs,
+                    infini_inputs,
+                    infini_kwargs,
+                    test_case.output_count,
+                    comparison_target,
+                    bench_mode=config.bench,
+                )
             )
-            test_result.torch_time = torch_time
-            test_result.infini_time = infini_time
+            test_result.torch_host_time = torch_host
+            test_result.torch_device_time = torch_device
+            test_result.infini_host_time = infini_host
+            test_result.infini_device_time = infini_device
+
+            # Store timing information in the test runner
+            if hasattr(config, "_test_runner") and config._test_runner:
+                # Accumulate total times
+                config._test_runner.benchmark_times["torch_host_total"] += torch_host
+                config._test_runner.benchmark_times[
+                    "torch_device_total"
+                ] += torch_device
+                config._test_runner.benchmark_times[
                    "infinicore_host_total"
+                ] += infini_host
+                config._test_runner.benchmark_times[
                    "infinicore_device_total"
+                ] += infini_device
 
         # Test passed successfully
         test_result.success = True
         test_result.return_code = 0
         return test_result
-
-    def _run_benchmarking(
-        self,
-        config,
-        device_str,
-        torch_implemented,
-        infini_implemented,
-        inputs,
-        kwargs,
-        infini_inputs,
-        infini_kwargs,
-        output_count,
-        comparison_target,
-    ):
-        """
-        Unified benchmarking logic with timing accumulation
-
-        Returns:
-            tuple: (torch_time, infini_time) timing results
-        """
-        # Initialize timing variables
-        torch_time = 0.0
-        infini_time = 0.0
-
-        if torch_implemented:
-            if output_count > 1:
-                # For multiple outputs, just call the operator
-                def torch_op():
-                    return self.torch_operator(*inputs, **kwargs)
-
-            else:
-                if comparison_target is None:
-                    # Out-of-place benchmarking
-                    def torch_op():
-                        return self.torch_operator(*inputs, **kwargs)
-
-                else:
-                    # In-place benchmarking
-                    def torch_op():
-                        self.torch_operator(*inputs, **kwargs)
-                        return (
-                            kwargs.get("out")
-                            if "out" in kwargs
-                            else inputs[comparison_target]
-                        )
-
-            torch_time = profile_operation(
-                "PyTorch ",
-                torch_op,
-                device_str,
-                config.num_prerun,
-                config.num_iterations,
-                total=True,
-            )
-
-        if infini_implemented:
-            if comparison_target is None:
-                # Out-of-place benchmarking
-                def infini_op():
-                    return self.infinicore_operator(*infini_inputs, **infini_kwargs)
-
-            else:
-                # In-place benchmarking
-                def infini_op():
-                    self.infinicore_operator(*infini_inputs, **infini_kwargs)
-                    return (
-                        infini_kwargs.get("out")
-                        if "out" in infini_kwargs
-                        else infini_inputs[comparison_target]
-                    )
-
-            infini_time = profile_operation(
-                "InfiniCore",
-                infini_op,
-                device_str,
-                config.num_prerun,
-                config.num_iterations,
-                total=True,
-            )
-
-        # Store timing information in the test runner
-        if hasattr(config, "_test_runner") and config._test_runner:
-            # Accumulate total times
-            config._test_runner.benchmark_times["torch_total"] += torch_time
-            config._test_runner.benchmark_times["infinicore_total"] += infini_time
-
-        return torch_time, infini_time
"""
Benchmarking utilities for the InfiniCore testing framework
"""
import time
import torch
import infinicore
from .utils import synchronize_device
class BenchmarkUtils:
"""Utility class for benchmarking operations"""
@staticmethod
def profile_operation(
desc,
func,
torch_device,
num_prerun,
num_iterations,
host_time=True,
device_time=True,
total=False,
):
"""
Performance profiling workflow with both host and device timing
Args:
desc: Operation description for display
func: Function to profile
torch_device: Torch device string
num_prerun: Number of warm-up runs
num_iterations: Number of iterations for timing
host_time: Whether to measure host (CPU) time
device_time: Whether to measure device time
total: Whether to return total time instead of per-iteration time
Returns:
tuple: (host_time, device_time) timing results
"""
# Timed execution
host_elapsed = 0.0
device_elapsed = 0.0
if host_time:
# Warm-up runs
for _ in range(num_prerun):
func()
host_elapsed = BenchmarkUtils.timed_op_host(
func, num_iterations, torch_device
)
if device_time:
# Warm-up runs
for _ in range(num_prerun):
func()
device_elapsed = BenchmarkUtils.timed_op_device(
func, num_iterations, torch_device
)
# Print results
if host_time and device_time:
print(
f" {desc} time - Host: {host_elapsed / num_iterations :6f} ms, "
f"Device: {device_elapsed / num_iterations :6f} ms"
)
elif host_time:
print(f" {desc} time - Host: {host_elapsed / num_iterations :6f} ms")
elif device_time:
print(f" {desc} time - Device: {device_elapsed / num_iterations :6f} ms")
if total:
return host_elapsed, device_elapsed
else:
return host_elapsed / num_iterations, device_elapsed / num_iterations
@staticmethod
def timed_op_host(func, num_iterations, device):
"""
Execute function multiple times and measure total host execution time
Args:
func: Function to execute
num_iterations: Number of iterations
device: Torch device string for synchronization
Returns:
            float: Total host execution time in milliseconds
"""
synchronize_device(device)
start = time.time()
for _ in range(num_iterations):
func()
synchronize_device(device)
return (time.time() - start) * 1000.0
@staticmethod
def timed_op_device(func, num_iterations, device):
"""
Execute function multiple times and measure device execution time using DeviceEvent pairs
Args:
func: Function to execute
num_iterations: Number of iterations
device: Torch device string for synchronization
Returns:
float: Total device execution time in milliseconds
"""
# Only use DeviceEvent for GPU devices
        if device == "cpu":
return 0.0
def _clear_cache():
pass
if infinicore.use_ntops:
import triton
cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
def _clear_cache():
triton.runtime.driver.active.clear_cache(cache)
# Create pairs of DeviceEvents for each iteration
start_events = [
infinicore.DeviceEvent(enable_timing=True) for _ in range(num_iterations)
]
end_events = [
infinicore.DeviceEvent(enable_timing=True) for _ in range(num_iterations)
]
# Execute the function multiple times with individual timing
for i in range(num_iterations):
_clear_cache()
start_events[i].record()
func()
end_events[i].record()
# Synchronize all end events
for event in end_events:
event.synchronize()
# Calculate total elapsed time by summing individual iteration times
total_device_time = 0.0
for i in range(num_iterations):
total_device_time += start_events[i].elapsed_time(end_events[i])
return total_device_time
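
    # Design note: each iteration above gets its own event pair rather than one
    # pair around the whole loop, so cache clearing and Python loop overhead
    # between iterations are excluded from the summed device time.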
@staticmethod
def run_benchmarking(
config,
device_str,
torch_implemented,
infini_implemented,
torch_operator,
infini_operator,
inputs,
kwargs,
infini_inputs,
infini_kwargs,
output_count,
comparison_target,
bench_mode="both",
):
"""
Unified benchmarking logic with timing accumulation
Args:
config: Test configuration
device_str: Torch device string
torch_implemented: Whether PyTorch operator is implemented
infini_implemented: Whether InfiniCore operator is implemented
torch_operator: PyTorch operator function
infini_operator: InfiniCore operator function
inputs: PyTorch operator inputs
kwargs: PyTorch operator keyword arguments
infini_inputs: InfiniCore operator inputs
infini_kwargs: InfiniCore operator keyword arguments
output_count: Number of outputs
comparison_target: Comparison target specification
bench_mode: Benchmark mode - "host", "device", or "both"
Returns:
tuple: (torch_host_time, torch_device_time, infini_host_time, infini_device_time)
"""
# Determine what to time based on bench_mode
host_time = bench_mode in ["host", "both"]
device_time = bench_mode in ["device", "both"]
# Initialize timing variables
torch_host_time = 0.0
torch_device_time = 0.0
infini_host_time = 0.0
infini_device_time = 0.0
if torch_implemented:
if output_count > 1:
# For multiple outputs, just call the operator
def torch_op():
return torch_operator(*inputs, **kwargs)
else:
if comparison_target is None:
# Out-of-place benchmarking
def torch_op():
return torch_operator(*inputs, **kwargs)
else:
# In-place benchmarking
def torch_op():
torch_operator(*inputs, **kwargs)
return (
kwargs.get("out")
if "out" in kwargs
else inputs[comparison_target]
)
torch_host, torch_device = BenchmarkUtils.profile_operation(
"PyTorch ",
torch_op,
device_str,
config.num_prerun,
config.num_iterations,
host_time=host_time,
device_time=device_time,
total=True,
)
torch_host_time = torch_host
torch_device_time = torch_device
if infini_implemented:
if comparison_target is None:
# Out-of-place benchmarking
def infini_op():
return infini_operator(*infini_inputs, **infini_kwargs)
else:
# In-place benchmarking
def infini_op():
infini_operator(*infini_inputs, **infini_kwargs)
return (
infini_kwargs.get("out")
if "out" in infini_kwargs
else infini_inputs[comparison_target]
)
infini_host, infini_device = BenchmarkUtils.profile_operation(
"InfiniCore",
infini_op,
device_str,
config.num_prerun,
config.num_iterations,
host_time=host_time,
device_time=device_time,
total=True,
)
infini_host_time = infini_host
infini_device_time = infini_device
return torch_host_time, torch_device_time, infini_host_time, infini_device_time
class BenchmarkResult:
"""Container for benchmark results"""
def __init__(self):
self.torch_host_total = 0.0
self.torch_device_total = 0.0
self.infinicore_host_total = 0.0
self.infinicore_device_total = 0.0
self.per_test_case = {}
def add_timing(
self, test_case_name, torch_host, torch_device, infini_host, infini_device
):
"""Add timing for a specific test case"""
self.per_test_case[test_case_name] = {
"torch_host_time": torch_host,
"torch_device_time": torch_device,
"infini_host_time": infini_host,
"infini_device_time": infini_device,
}
self.torch_host_total += torch_host
self.torch_device_total += torch_device
self.infinicore_host_total += infini_host
self.infinicore_device_total += infini_device
def get_host_speedup(self):
"""Calculate host speedup ratio"""
if self.infinicore_host_total > 0:
return self.torch_host_total / self.infinicore_host_total
return float("inf")
def get_device_speedup(self):
"""Calculate device speedup ratio"""
if self.infinicore_device_total > 0:
return self.torch_device_total / self.infinicore_device_total
return float("inf")
@@ -54,9 +54,15 @@ Examples:
     # Run all tests on CPU only
     python test_operator.py --cpu
 
-    # Run with benchmarking on NVIDIA GPU
+    # Run with benchmarking on NVIDIA GPU (both host and device timing)
     python test_operator.py --nvidia --bench
 
+    # Run with benchmarking - host timing only
+    python test_operator.py --nvidia --bench host
+
+    # Run with benchmarking - device timing only
+    python test_operator.py --nvidia --bench device
+
     # Run with debug mode on multiple devices
     python test_operator.py --cpu --nvidia --debug
@@ -72,8 +78,11 @@
     # Core testing options
     parser.add_argument(
         "--bench",
-        action="store_true",
-        help="Enable performance benchmarking mode",
+        nargs="?",
+        const="both",
+        choices=["host", "device", "both"],
+        help="Enable performance benchmarking mode. "
+        "Options: host (CPU time only), device (GPU time only), both (default)",
     )
     parser.add_argument(
         "--num_prerun",
...
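
For reference, with nargs="?" and const="both" the flag resolves per standard argparse semantics; a minimal illustration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--bench", nargs="?", const="both", choices=["host", "device", "both"]
)

print(parser.parse_args([]).bench)                   # None: flag absent
print(parser.parse_args(["--bench"]).bench)          # 'both': bare flag uses const
print(parser.parse_args(["--bench", "host"]).bench)  # 'host'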
@@ -15,35 +15,6 @@ def synchronize_device(torch_device):
         torch.mlu.synchronize()
 
 
-def timed_op(func, num_iterations, device):
-    """Timed operation"""
-    synchronize_device(device)
-    start = time.time()
-    for _ in range(num_iterations):
-        func()
-    synchronize_device(device)
-    return time.time() - start
-
-
-def profile_operation(
-    desc, func, torch_device, num_prerun, num_iterations, total=False
-):
-    """
-    Performance profiling workflow
-    """
-    # Warm-up runs
-    for _ in range(num_prerun):
-        func()
-
-    # Timed execution
-    elapsed = timed_op(lambda: func(), num_iterations, torch_device)
-
-    print(f"  {desc} time: {elapsed / num_iterations * 1000:6f} ms")
-
-    if total:
-        return elapsed
-    else:
-        return elapsed / num_iterations
-
-
 def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
     """
     Debug function to compare two tensors and print differences
...
@@ -109,14 +109,18 @@ def import_operator_test(test_file_path):
         return False, f"Error importing {test_file_path}: {str(e)}"
 
 
-def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False):
+def run_all_op_tests(
+    ops_dir=None, specific_ops=None, bench=False, bench_mode="both", verbose=False
+):
     """
     Run all operator test scripts in the ops directory using direct import.
 
     Args:
         ops_dir (str, optional): Path to the ops directory. If None, uses auto-detection.
         specific_ops (list, optional): List of specific operator names to test.
-        extra_args (list, optional): Extra command line arguments to pass to test scripts.
+        bench (bool): Whether benchmarking is enabled
+        bench_mode (str): Benchmark mode - "host", "device", or "both"
+        verbose (bool): Whether verbose mode is enabled
 
     Returns:
         dict: Results dictionary with test names as keys and (success, test_runner, stdout, stderr) as values.
@@ -174,8 +178,10 @@
     results = {}
     cumulative_timing = {
-        "total_torch_time": 0.0,
-        "total_infinicore_time": 0.0,
+        "total_torch_host_time": 0.0,
+        "total_torch_device_time": 0.0,
+        "total_infinicore_host_time": 0.0,
+        "total_infinicore_device_time": 0.0,
         "operators_tested": 0,
     }
@@ -191,8 +197,10 @@
             results[test_name] = {
                 "success": False,
                 "return_code": -1,
-                "torch_time": 0.0,
-                "infini_time": 0.0,
+                "torch_host_time": 0.0,
+                "torch_device_time": 0.0,
+                "infini_host_time": 0.0,
+                "infini_device_time": 0.0,
                 "error_message": test_instance_or_error,
                 "test_runner": None,
                 "stdout": "",
@@ -207,8 +215,10 @@
             results[test_name] = {
                 "success": False,
                 "return_code": -1,
-                "torch_time": 0.0,
-                "infini_time": 0.0,
+                "torch_host_time": 0.0,
+                "torch_device_time": 0.0,
+                "infini_host_time": 0.0,
+                "infini_device_time": 0.0,
                 "error_message": "No GenericTestRunner found",
                 "test_runner": None,
                 "stdout": "",
@@ -287,15 +297,25 @@
                 status_icon = "❌"
                 status_text = "FAILED"
 
-            # Calculate timing
-            torch_time = sum(result.torch_time for result in test_results)
-            infini_time = sum(result.infini_time for result in test_results)
+            # Calculate timing for all four metrics
+            torch_host_time = sum(result.torch_host_time for result in test_results)
+            torch_device_time = sum(
+                result.torch_device_time for result in test_results
+            )
+            infini_host_time = sum(
+                result.infini_host_time for result in test_results
+            )
+            infini_device_time = sum(
+                result.infini_device_time for result in test_results
+            )
 
             results[test_name] = {
                 "success": test_success,
                 "return_code": return_code,
-                "torch_time": torch_time,
-                "infini_time": infini_time,
+                "torch_host_time": torch_host_time,
+                "torch_device_time": torch_device_time,
+                "infini_host_time": infini_host_time,
+                "infini_device_time": infini_device_time,
                 "error_message": "",
                 "test_runner": test_runner,
                 "stdout": stdout_output,
@@ -308,8 +328,12 @@
             # Extract benchmark timing if in bench mode
             if bench and test_success and return_code == 0:
-                cumulative_timing["total_torch_time"] += torch_time
-                cumulative_timing["total_infinicore_time"] += infini_time
+                cumulative_timing["total_torch_host_time"] += torch_host_time
+                cumulative_timing["total_torch_device_time"] += torch_device_time
+                cumulative_timing["total_infinicore_host_time"] += infini_host_time
+                cumulative_timing[
+                    "total_infinicore_device_time"
+                ] += infini_device_time
                 cumulative_timing["operators_tested"] += 1
 
         except Exception as e:
@@ -327,8 +351,10 @@
             results[test_name] = {
                 "success": False,
                 "return_code": -1,
-                "torch_time": 0.0,
-                "infini_time": 0.0,
+                "torch_host_time": 0.0,
+                "torch_device_time": 0.0,
+                "infini_host_time": 0.0,
+                "infini_device_time": 0.0,
                 "error_message": str(e),
                 "test_runner": None,
                 "stdout": "",
@@ -348,7 +374,11 @@
 def print_summary(
-    results, verbose=False, total_expected_tests=0, cumulative_timing=None
+    results,
+    verbose=False,
+    total_expected_tests=0,
+    cumulative_timing=None,
+    bench_mode="both",
 ):
     """Print a comprehensive summary of test results including benchmark data."""
     print(f"\n{'='*80}")
@@ -405,12 +435,24 @@
         print(f"{'-'*40}")
         print("BENCHMARK SUMMARY:")
         print(f"  Operators Tested: {cumulative_timing['operators_tested']}")
-        print(
-            f"  PyTorch Total Time: {cumulative_timing['total_torch_time'] * 1000:12.3f} ms"
-        )
-        print(
-            f"  InfiniCore Total Time: {cumulative_timing['total_infinicore_time'] * 1000:12.3f} ms"
-        )
+
+        # Display timing based on bench_mode
+        if bench_mode in ["host", "both"]:
+            print(
+                f"  PyTorch Host Total Time: {cumulative_timing['total_torch_host_time']:12.3f} ms"
+            )
+            print(
+                f"  InfiniCore Host Total Time: {cumulative_timing['total_infinicore_host_time']:12.3f} ms"
+            )
+        if bench_mode in ["device", "both"]:
+            print(
+                f"  PyTorch Device Total Time: {cumulative_timing['total_torch_device_time']:12.3f} ms"
+            )
+            print(
+                f"  InfiniCore Device Total Time: {cumulative_timing['total_infinicore_device_time']:12.3f} ms"
+            )
         print(f"{'-'*40}")
 
     # Display passed operators
@@ -528,9 +570,15 @@ def generate_help_epilog(ops_dir):
     )
     epilog_parts.append("  python run.py --cpu --nvidia --verbose")
     epilog_parts.append("")
-    epilog_parts.append("  # Run with benchmarking to get cumulative timing")
+    epilog_parts.append("  # Run with benchmarking (both host and device timing)")
     epilog_parts.append("  python run.py --cpu --bench")
     epilog_parts.append("")
+    epilog_parts.append("  # Run with host timing only")
+    epilog_parts.append("  python run.py --nvidia --bench host")
+    epilog_parts.append("")
+    epilog_parts.append("  # Run with device timing only")
+    epilog_parts.append("  python run.py --nvidia --bench device")
+    epilog_parts.append("")
     epilog_parts.append("  # List available tests without running")
     epilog_parts.append("  python run.py --list")
     epilog_parts.append("")
@@ -560,10 +608,10 @@ def generate_help_epilog(ops_dir):
         "  - --bench mode now shows cumulative timing across all operators"
     )
     epilog_parts.append(
-        "  - --verbose mode stops execution on first error and shows full traceback"
+        "  - --bench host/device/both controls host/device timing measurement"
     )
     epilog_parts.append(
-        "  - In verbose mode, subsequent tests are skipped after first failure"
+        "  - --verbose mode stops execution on first error and shows full traceback"
     )
 
     return "\n".join(epilog_parts)
@@ -599,8 +647,11 @@ def main():
     )
     parser.add_argument(
         "--bench",
-        action="store_true",
-        help="Enable bench mode to show performance data",
+        nargs="?",
+        const="both",
+        choices=["host", "device", "both"],
+        help="Enable performance benchmarking mode. "
+        "Options: host (CPU time only), device (GPU time only), both (default)",
     )
 
     get_hardware_args_group(parser)
@@ -641,6 +692,10 @@
     if args.verbose:
         print(f"Verbose mode: ENABLED (will stop on first error with full traceback)")
 
+    if args.bench:
+        bench_mode = args.bench if args.bench != "both" else "both"
+        print(f"Benchmark mode: {bench_mode.upper()} timing")
+
     if args.ops:
         # Validate requested operators
         valid_ops = []
@@ -671,13 +726,18 @@
     results, cumulative_timing = run_all_op_tests(
         ops_dir=ops_dir,
         specific_ops=args.ops,
-        bench=args.bench,
+        bench=bool(args.bench),
+        bench_mode=args.bench if args.bench else "both",
         verbose=args.verbose,
     )
 
     # Print summary and exit with appropriate code
     all_passed = print_summary(
-        results, args.verbose, total_expected_tests, cumulative_timing
+        results,
+        args.verbose,
+        total_expected_tests,
+        cumulative_timing,
+        bench_mode=args.bench if args.bench else "both",
     )
 
     # Check if there were any tests with missing implementations
...