Unverified Commit 79b70e58 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Issue/571 DeviceEvent (#583)



* issue/571 - introduced the DeviceEvent feature

---------
Co-authored-by: Jiacheng Huang <huangjiacheng0709@outlook.com>
parent d4738a98
#pragma once #pragma once
#include "infinicore/device_event.hpp"
#include "infinicore/nn.hpp" #include "infinicore/nn.hpp"
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
#include "infinicore/tensor.hpp" #include "infinicore/tensor.hpp"
...@@ -30,6 +30,16 @@ void memcpyD2H(void *dst, const void *src, size_t size); ...@@ -30,6 +30,16 @@ void memcpyD2H(void *dst, const void *src, size_t size);
void memcpyD2D(void *dst, const void *src, size_t size); void memcpyD2D(void *dst, const void *src, size_t size);
void memcpyH2H(void *dst, const void *src, size_t size); void memcpyH2H(void *dst, const void *src, size_t size);
// Timing APIs for performance measurement
// Each function below forwards to the matching method on ContextImpl's
// current runtime.
// Create an event with default settings.
infinirtEvent_t createEvent();
// Create an event using the given infinirtEventFlags_t bit flags.
infinirtEvent_t createEventWithFlags(uint32_t flags);
// Record `event` on `stream`; a null stream selects the runtime's own stream.
void recordEvent(infinirtEvent_t event, infinirtStream_t stream = nullptr);
// Return true when the event reports INFINIRT_EVENT_COMPLETE.
bool queryEvent(infinirtEvent_t event);
// Block until the event has completed.
void synchronizeEvent(infinirtEvent_t event);
// Destroy an event handle.
void destroyEvent(infinirtEvent_t event);
// Return the time from `start` to `end` in milliseconds.
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
// Make `stream` wait for `event`; a null stream selects the runtime's own stream.
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
} // namespace context } // namespace context
} // namespace infinicore } // namespace infinicore
#pragma once
#include "device.hpp"
#include "infinirt.h"
#include <memory>
#include <stdexcept>
namespace infinicore {
/**
 * @brief Move-only owner of an infinirt event handle bound to one device.
 *
 * Modelled on torch.cuda.Event. An instance can:
 * - be recorded on a stream of its device
 * - be waited on from the host (synchronize) or from a stream (wait)
 * - be polled for completion (query)
 * - report the time elapsed until another recorded event (elapsed_time)
 */
class DeviceEvent {
public:
    /// @brief Create an event on the device that is current at construction time.
    DeviceEvent();

    /**
     * @brief Create an event on the current device with creation flags.
     * @param flags Event creation flags (e.g., for timing, blocking sync)
     */
    explicit DeviceEvent(uint32_t flags);

    /**
     * @brief Create an event on a given device.
     * @param device Device the event belongs to
     */
    explicit DeviceEvent(Device device);

    /**
     * @brief Create an event on a given device with creation flags.
     * @param device Device the event belongs to
     * @param flags Event creation flags
     */
    DeviceEvent(Device device, uint32_t flags);

    // The handle has a single owner, so copying is not allowed.
    DeviceEvent(const DeviceEvent &) = delete;
    DeviceEvent &operator=(const DeviceEvent &) = delete;

    /// @brief Take over the handle of @p other; @p other is left holding a null handle.
    DeviceEvent(DeviceEvent &&other) noexcept;

    /// @brief Release the currently held handle, then take over the one from @p other.
    DeviceEvent &operator=(DeviceEvent &&other) noexcept;

    /// @brief Destroy the underlying event handle if one is held.
    ~DeviceEvent();

    /// @brief Record the event without naming a stream (the runtime's own stream is used).
    void record();

    /**
     * @brief Record the event on a specific stream.
     * @param stream Stream to record the event on
     */
    void record(infinirtStream_t stream);

    /// @brief Block the calling thread until the event has completed.
    void synchronize();

    /**
     * @brief Poll the event without blocking.
     * @return true if completed, false otherwise
     */
    bool query() const;

    /**
     * @brief Time from this event (start) to @p other (end), in milliseconds.
     * @param other The event marking the end of the interval
     * @return Elapsed time in milliseconds
     * @throws std::runtime_error if events are on different devices or not recorded
     */
    float elapsed_time(const DeviceEvent &other) const;

    /**
     * @brief Make a stream wait for this event to complete.
     * @param stream Stream that should wait (nullptr lets the runtime pick its own stream)
     */
    void wait(infinirtStream_t stream = nullptr) const;

    /// @brief Device this event was created for.
    Device device() const { return device_; }

    /// @brief Raw event handle; ownership stays with this object.
    infinirtEvent_t get() const { return event_; }

    /// @brief Whether record() has been called on this event.
    bool is_recorded() const { return is_recorded_; }

private:
    infinirtEvent_t event_; // Underlying event handle
    Device device_;         // Device where this event was created
    bool is_recorded_;      // Whether the event has been recorded
};
} // namespace infinicore
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define __INFINIRT_API_H__ #define __INFINIRT_API_H__
#include "infinicore.h" #include "infinicore.h"
#include <stdint.h>
typedef void *infinirtStream_t; typedef void *infinirtStream_t;
typedef void *infinirtEvent_t; typedef void *infinirtEvent_t;
...@@ -27,11 +28,20 @@ typedef enum { ...@@ -27,11 +28,20 @@ typedef enum {
INFINIRT_EVENT_NOT_READY = 1, INFINIRT_EVENT_NOT_READY = 1,
} infinirtEventStatus_t; } infinirtEventStatus_t;
// Bit flags accepted by infinirtEventCreateWithFlags.
typedef enum {
    INFINIRT_EVENT_DEFAULT = 0,             // No special behaviour requested
    INFINIRT_EVENT_DISABLE_TIMING = 1 << 0, // Event will not record timing data
    INFINIRT_EVENT_BLOCKING_SYNC = 1 << 1,  // Event uses blocking synchronization
} infinirtEventFlags_t;
__C __export infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr); __C __export infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr);
__C __export infiniStatus_t infinirtEventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags);
__C __export infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream); __C __export infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream);
__C __export infiniStatus_t infinirtEventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr); __C __export infiniStatus_t infinirtEventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr);
__C __export infiniStatus_t infinirtEventSynchronize(infinirtEvent_t event); __C __export infiniStatus_t infinirtEventSynchronize(infinirtEvent_t event);
__C __export infiniStatus_t infinirtEventDestroy(infinirtEvent_t event); __C __export infiniStatus_t infinirtEventDestroy(infinirtEvent_t event);
__C __export infiniStatus_t infinirtEventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end);
// Memory // Memory
typedef enum { typedef enum {
......
import contextlib import contextlib
import infinicore.nn as nn import infinicore.nn as nn
# Import context functions
from infinicore.context import (
get_device,
get_device_count,
get_stream,
set_device,
sync_device,
sync_stream,
)
from infinicore.device import device from infinicore.device import device
from infinicore.device_event import DeviceEvent
from infinicore.dtype import ( from infinicore.dtype import (
bfloat16, bfloat16,
bool, bool,
...@@ -52,8 +63,16 @@ __all__ = [ ...@@ -52,8 +63,16 @@ __all__ = [
"nn", "nn",
# Classes. # Classes.
"device", "device",
"DeviceEvent",
"dtype", "dtype",
"Tensor", "Tensor",
# Context functions.
"get_device",
"get_device_count",
"get_stream",
"set_device",
"sync_device",
"sync_stream",
# Data Types. # Data Types.
"bfloat16", "bfloat16",
"bool", "bool",
......
import infinicore.device
from infinicore.lib import _infinicore
def get_device():
    """Return the device that is currently active.

    Returns:
        The object produced by ``_infinicore.get_device()``.
    """
    # NOTE(review): the binding's result is returned as-is, whereas
    # set_device() reads ``device._underlying`` - confirm whether callers
    # expect a wrapped ``infinicore.device`` here.
    active = _infinicore.get_device()
    return active
def get_device_count(device_type):
    """Count the devices available for one device type.

    Args:
        device_type (str): Device type name (e.g., "cuda", "cpu", "npu")

    Returns:
        int: Number of devices of that type
    """
    # Build a device object only to translate the name into the binding's type value.
    probe = infinicore.device(device_type)
    type_id = probe._underlying.type
    return _infinicore.get_device_count(type_id)
def set_device(device):
    """Make ``device`` the active device.

    Args:
        device: Device wrapper whose ``_underlying`` object is handed to the binding
    """
    raw_device = device._underlying
    _infinicore.set_device(raw_device)
def sync_stream():
    """Synchronize the current stream.

    Delegates to ``_infinicore.sync_stream()`` and returns nothing.
    """
    _infinicore.sync_stream()
    return None
def sync_device():
    """Synchronize the current device.

    Delegates to ``_infinicore.sync_device()`` and returns nothing.
    """
    _infinicore.sync_device()
    return None
def get_stream():
    """Return the current stream.

    Returns:
        The stream object produced by ``_infinicore.get_stream()``.
    """
    current = _infinicore.get_stream()
    return current
import infinicore.device
from infinicore.lib import _infinicore
class DeviceEvent:
    """A device event for timing operations and synchronization across devices.

    Similar to torch.cuda.Event, this class provides functionality to:
    - Record events on specific device streams
    - Synchronize with events
    - Measure elapsed time between events
    - Query event completion status
    - Make streams wait for events

    Args:
        enable_timing: Whether the event should record timing data. Default: False.
        device: Target device for this event. If None, uses current device.

    Blocking, inter-process and external events cannot be requested through
    this constructor yet; the matching attributes are always False.
    """

    # Bit values mirror infinirtEventFlags_t in infinirt.h.
    _FLAG_DISABLE_TIMING = 0x1
    _FLAG_BLOCKING_SYNC = 0x2

    def __init__(self, enable_timing=False, device=None):
        # Build flags based on parameters.
        flags = 0
        if not enable_timing:
            flags |= self._FLAG_DISABLE_TIMING

        # Store parameters for reference. The options that are not exposed
        # still get a value so the properties and __repr__ never hit a
        # missing attribute.
        self._enable_timing = enable_timing
        self._blocking = False
        self._interprocess = False
        self._external = False

        # The binding offers (), (flags), (device) and (device, flags);
        # pass only the arguments that carry information.
        ctor_args = []
        if device is not None:
            ctor_args.append(device._underlying)
        if flags != 0:
            ctor_args.append(flags)
        self._underlying = _infinicore.DeviceEvent(*ctor_args)

    def record(self, stream=None):
        """Record the event.

        Args:
            stream: Stream to record the event on. If None, uses current stream.
        """
        if stream is None:
            self._underlying.record()
        else:
            self._underlying.record(stream)

    def synchronize(self):
        """Wait for the event to complete (blocking)."""
        self._underlying.synchronize()

    def query(self):
        """Check if the event has been completed.

        Returns:
            bool: True if completed, False otherwise.
        """
        return self._underlying.query()

    def elapsed_time(self, other):
        """Calculate elapsed time between this event and another event.

        Args:
            other: The other DeviceEvent to compare with

        Returns:
            float: Elapsed time in milliseconds between this event and the other event

        Raises:
            RuntimeError: If events are on different devices or not recorded,
                or if timing is disabled on either event
        """
        if not self._enable_timing or not other._enable_timing:
            raise RuntimeError("Cannot measure elapsed time when timing is disabled")
        return self._underlying.elapsed_time(other._underlying)

    def wait(self, stream=None):
        """Make a stream wait for this event to complete.

        Args:
            stream: Stream to make wait for this event. If None, uses current stream.
        """
        self._underlying.wait(stream)

    @property
    def device(self):
        """Get the device where this event was created."""
        return infinicore.device._from_infinicore_device(self._underlying.device)

    @property
    def is_recorded(self):
        """Check if the event has been recorded."""
        return self._underlying.is_recorded

    @property
    def enable_timing(self):
        """Whether this event records timing data."""
        return self._enable_timing

    @property
    def blocking(self):
        """Whether this event uses blocking synchronization."""
        return self._blocking

    @property
    def interprocess(self):
        """Whether this event can be used for inter-process communication."""
        return self._interprocess

    def __repr__(self):
        flags_str = []
        if not self._enable_timing:
            flags_str.append("timing_disabled")
        if self._blocking:
            flags_str.append("blocking")
        if self._interprocess:
            flags_str.append("interprocess")
        if self._external:
            flags_str.append("external")
        if not flags_str:
            flags_str.append("default")
        return f"DeviceEvent(device={self.device}, flags={', '.join(flags_str)}, recorded={self.is_recorded})"
...@@ -58,7 +58,7 @@ ContextImpl &ContextImpl::singleton() { ...@@ -58,7 +58,7 @@ ContextImpl &ContextImpl::singleton() {
} }
ContextImpl::ContextImpl() { ContextImpl::ContextImpl() {
std::vector<int> device_counter(size_t(Device::Type::COUNT)); std::vector<int> device_counter(static_cast<size_t>(Device::Type::COUNT));
INFINICORE_CHECK_ERROR(infinirtGetAllDeviceCount(device_counter.data())); INFINICORE_CHECK_ERROR(infinirtGetAllDeviceCount(device_counter.data()));
// Reserve runtime slot for all devices. // Reserve runtime slot for all devices.
...@@ -145,6 +145,39 @@ void memcpyH2H(void *dst, const void *src, size_t size) { ...@@ -145,6 +145,39 @@ void memcpyH2H(void *dst, const void *src, size_t size) {
return ContextImpl::singleton().getCpuRuntime()->memcpyD2D(dst, src, size); return ContextImpl::singleton().getCpuRuntime()->memcpyD2D(dst, src, size);
} }
// Timing API implementations
// Create an event through the current runtime.
infinirtEvent_t createEvent() {
    auto &&runtime = ContextImpl::singleton().getCurrentRuntime();
    return runtime->createEvent();
}
// Create an event with creation flags through the current runtime.
infinirtEvent_t createEventWithFlags(uint32_t flags) {
    auto &&runtime = ContextImpl::singleton().getCurrentRuntime();
    return runtime->createEventWithFlags(flags);
}
// Record `event` on `stream` through the current runtime.
void recordEvent(infinirtEvent_t event, infinirtStream_t stream) {
    auto &&runtime = ContextImpl::singleton().getCurrentRuntime();
    runtime->recordEvent(event, stream);
}
// Ask the current runtime whether `event` has completed.
bool queryEvent(infinirtEvent_t event) {
    auto &&runtime = ContextImpl::singleton().getCurrentRuntime();
    return runtime->queryEvent(event);
}
// Wait for `event` through the current runtime.
void synchronizeEvent(infinirtEvent_t event) {
    auto &&runtime = ContextImpl::singleton().getCurrentRuntime();
    runtime->synchronizeEvent(event);
}
// Destroy `event` through the current runtime.
void destroyEvent(infinirtEvent_t event) {
    auto &&runtime = ContextImpl::singleton().getCurrentRuntime();
    runtime->destroyEvent(event);
}
// Time from `start` to `end` as reported by the current runtime.
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end) {
    auto &&runtime = ContextImpl::singleton().getCurrentRuntime();
    return runtime->elapsedTime(start, end);
}
// Make `stream` wait for `event` through the current runtime.
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
    auto &&runtime = ContextImpl::singleton().getCurrentRuntime();
    runtime->streamWaitEvent(stream, event);
}
} // namespace context } // namespace context
} // namespace infinicore } // namespace infinicore
...@@ -88,6 +88,54 @@ void Runtime::memcpyD2D(void *dst, const void *src, size_t size) { ...@@ -88,6 +88,54 @@ void Runtime::memcpyD2D(void *dst, const void *src, size_t size) {
INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_D2D, stream_)); INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_D2D, stream_));
} }
// Timing method implementations
// Create a new event via infinirtEventCreate and return its handle.
// The call's status is passed through INFINICORE_CHECK_ERROR.
infinirtEvent_t Runtime::createEvent() {
infinirtEvent_t event;
INFINICORE_CHECK_ERROR(infinirtEventCreate(&event));
return event;
}
// Create a new event via infinirtEventCreateWithFlags and return its handle.
// `flags` is forwarded unchanged; the call's status is passed through
// INFINICORE_CHECK_ERROR.
infinirtEvent_t Runtime::createEventWithFlags(uint32_t flags) {
infinirtEvent_t event;
INFINICORE_CHECK_ERROR(infinirtEventCreateWithFlags(&event, flags));
return event;
}
// Record `event` on `stream`, falling back to this runtime's own stream
// when the caller passes a null stream.
void Runtime::recordEvent(infinirtEvent_t event, infinirtStream_t stream) {
    stream = (stream != nullptr) ? stream : stream_;
    INFINICORE_CHECK_ERROR(infinirtEventRecord(event, stream));
}
// Poll `event` and report whether its status is INFINIRT_EVENT_COMPLETE.
bool Runtime::queryEvent(infinirtEvent_t event) {
    infinirtEventStatus_t status;
    INFINICORE_CHECK_ERROR(infinirtEventQuery(event, &status));
    if (status != INFINIRT_EVENT_COMPLETE) {
        return false;
    }
    return true;
}
// Wait for `event` via infinirtEventSynchronize; the call's status is passed
// through INFINICORE_CHECK_ERROR.
void Runtime::synchronizeEvent(infinirtEvent_t event) {
INFINICORE_CHECK_ERROR(infinirtEventSynchronize(event));
}
// Destroy `event` via infinirtEventDestroy; the call's status is passed
// through INFINICORE_CHECK_ERROR.
void Runtime::destroyEvent(infinirtEvent_t event) {
INFINICORE_CHECK_ERROR(infinirtEventDestroy(event));
}
// Return the value infinirtEventElapsedTime writes for the interval from
// `start` to `end` (milliseconds, per the `ms` output parameter).
float Runtime::elapsedTime(infinirtEvent_t start, infinirtEvent_t end) {
float ms;
INFINICORE_CHECK_ERROR(infinirtEventElapsedTime(&ms, start, end));
return ms;
}
// Make `stream` wait for `event`. A null `stream` means "use this runtime's
// own stream".
void Runtime::streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
    stream = (stream != nullptr) ? stream : stream_;
    INFINICORE_CHECK_ERROR(infinirtStreamWaitEvent(stream, event));
}
std::string Runtime::toString() const { std::string Runtime::toString() const {
return fmt::format("Runtime({})", device_.toString()); return fmt::format("Runtime({})", device_.toString());
} }
......
...@@ -38,6 +38,16 @@ public: ...@@ -38,6 +38,16 @@ public:
void memcpyD2H(void *dst, const void *src, size_t size); void memcpyD2H(void *dst, const void *src, size_t size);
void memcpyD2D(void *dst, const void *src, size_t size); void memcpyD2D(void *dst, const void *src, size_t size);
// Timing methods
// Create an event with default settings.
infinirtEvent_t createEvent();
// Create an event with the given creation flags.
infinirtEvent_t createEventWithFlags(uint32_t flags);
// Record `event` on `stream`; a null stream selects this runtime's stream.
void recordEvent(infinirtEvent_t event, infinirtStream_t stream = nullptr);
// Return true when the event reports INFINIRT_EVENT_COMPLETE.
bool queryEvent(infinirtEvent_t event);
// Block until the event has completed.
void synchronizeEvent(infinirtEvent_t event);
// Destroy an event handle.
void destroyEvent(infinirtEvent_t event);
// Return the time from `start` to `end` in milliseconds.
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
// Make `stream` wait for `event`; a null stream selects this runtime's stream.
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
std::string toString() const; std::string toString() const;
friend class ContextImpl; friend class ContextImpl;
......
#include "infinicore.hpp"
namespace infinicore {
DeviceEvent::DeviceEvent()
: device_(context::getDevice()), is_recorded_(false) {
event_ = context::createEvent();
}
DeviceEvent::DeviceEvent(uint32_t flags)
: device_(context::getDevice()), is_recorded_(false) {
event_ = context::createEventWithFlags(flags);
}
DeviceEvent::DeviceEvent(Device device)
: device_(device), is_recorded_(false) {
// Switch to target device for event creation
Device current_device = context::getDevice();
context::setDevice(device_);
event_ = context::createEvent();
// Restore original device
context::setDevice(current_device);
}
DeviceEvent::DeviceEvent(Device device, uint32_t flags)
: device_(device), is_recorded_(false) {
// Switch to target device for event creation
Device current_device = context::getDevice();
context::setDevice(device_);
event_ = context::createEventWithFlags(flags);
// Restore original device
context::setDevice(current_device);
}
DeviceEvent::DeviceEvent(DeviceEvent &&other) noexcept
: event_(other.event_), device_(other.device_), is_recorded_(other.is_recorded_) {
other.event_ = nullptr;
other.is_recorded_ = false;
}
DeviceEvent &DeviceEvent::operator=(DeviceEvent &&other) noexcept {
if (this != &other) {
// Clean up current resources
if (event_ != nullptr) {
context::destroyEvent(event_);
}
// Transfer ownership
event_ = other.event_;
device_ = other.device_;
is_recorded_ = other.is_recorded_;
// Reset source
other.event_ = nullptr;
other.is_recorded_ = false;
}
return *this;
}
DeviceEvent::~DeviceEvent() {
if (event_ != nullptr) {
context::destroyEvent(event_);
}
}
void DeviceEvent::record() {
Device current_device = context::getDevice();
// Ensure we're on the correct device
if (current_device != device_) {
context::setDevice(device_);
}
context::recordEvent(event_);
is_recorded_ = true;
// Restore original device if we changed it
if (current_device != device_) {
context::setDevice(current_device);
}
}
void DeviceEvent::record(infinirtStream_t stream) {
Device current_device = context::getDevice();
// Ensure we're on the correct device
if (current_device != device_) {
context::setDevice(device_);
}
context::recordEvent(event_, stream);
is_recorded_ = true;
// Restore original device if we changed it
if (current_device != device_) {
context::setDevice(current_device);
}
}
void DeviceEvent::synchronize() {
Device current_device = context::getDevice();
// Ensure we're on the correct device
if (current_device != device_) {
context::setDevice(device_);
}
context::synchronizeEvent(event_);
// Restore original device if we changed it
if (current_device != device_) {
context::setDevice(current_device);
}
}
bool DeviceEvent::query() const {
Device current_device = context::getDevice();
bool result = false;
// Ensure we're on the correct device
if (current_device != device_) {
context::setDevice(device_);
}
result = context::queryEvent(event_);
// Restore original device if we changed it
if (current_device != device_) {
context::setDevice(current_device);
}
return result;
}
float DeviceEvent::elapsed_time(const DeviceEvent &other) const {
// Both events must be on the same device
if (device_ != other.device_) {
throw std::runtime_error("Cannot measure elapsed time between events on different devices");
}
// Both events must be recorded
if (!is_recorded_ || !other.is_recorded_) {
throw std::runtime_error("Both events must be recorded before measuring elapsed time");
}
Device current_device = context::getDevice();
// Switch to the device where events reside
if (current_device != device_) {
context::setDevice(device_);
}
float elapsed_ms = context::elapsedTime(event_, other.event_);
// Restore original device if we changed it
if (current_device != device_) {
context::setDevice(current_device);
}
return elapsed_ms;
}
void DeviceEvent::wait(infinirtStream_t stream) const {
Device current_device = context::getDevice();
// Ensure we're on the correct device
if (current_device != device_) {
context::setDevice(device_);
}
// Make the stream wait for this event
context::streamWaitEvent(stream, event_);
// Restore original device if we changed it
if (current_device != device_) {
context::setDevice(current_device);
}
}
} // namespace infinicore
...@@ -9,8 +9,21 @@ namespace py = pybind11; ...@@ -9,8 +9,21 @@ namespace py = pybind11;
namespace infinicore::context { namespace infinicore::context {
inline void bind(py::module &m) { inline void bind(py::module &m) {
m.def("get_device", &getDevice); // Device management
m.def("get_device_count", &getDeviceCount); m.def("get_device", &getDevice, "Get the current active device");
m.def("get_device_count", &getDeviceCount,
"Get the number of available devices of a specific type",
py::arg("device_type"));
m.def("set_device", &setDevice,
"Set the current active device",
py::arg("device"));
// Stream and handle management
m.def("get_stream", &getStream, "Get the current stream");
// Synchronization
m.def("sync_stream", &syncStream, "Synchronize the current stream");
m.def("sync_device", &syncDevice, "Synchronize the current device");
} }
} // namespace infinicore::context } // namespace infinicore::context
\ No newline at end of file
#pragma once
#include "infinicore.hpp"
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace infinicore::device_event {
// Register the DeviceEvent class with the Python module `m`.
inline void bind(py::module &m) {
    py::class_<DeviceEvent> event_class(m, "DeviceEvent");

    // Constructors
    event_class.def(py::init<>(), "Construct a DeviceEvent on the current device");
    event_class.def(py::init<uint32_t>(), "Construct a DeviceEvent with specific flags", py::arg("flags"));
    event_class.def(py::init<Device>(), "Construct a DeviceEvent on a specific device", py::arg("device"));
    event_class.def(py::init<Device, uint32_t>(), "Construct a DeviceEvent on a specific device with flags",
                    py::arg("device"), py::arg("flags"));

    // Operations
    event_class.def("record", py::overload_cast<>(&DeviceEvent::record),
                    "Record the event on the current stream of its device");
    event_class.def("record", py::overload_cast<infinirtStream_t>(&DeviceEvent::record),
                    "Record the event on a specific stream", py::arg("stream"));
    event_class.def("synchronize", &DeviceEvent::synchronize,
                    "Wait for the event to complete (blocking)");
    event_class.def("query", &DeviceEvent::query,
                    "Check if the event has been completed");
    event_class.def("elapsed_time", &DeviceEvent::elapsed_time,
                    "Calculate elapsed time between this event and another event (in milliseconds)",
                    py::arg("other"));
    event_class.def("wait", &DeviceEvent::wait,
                    "Make a stream wait for this event to complete",
                    py::arg("stream") = nullptr);

    // Read-only properties
    event_class.def_property_readonly("device", &DeviceEvent::device,
                                      "Get the device where this event was created");
    event_class.def_property_readonly("is_recorded", &DeviceEvent::is_recorded,
                                      "Check if the event has been recorded");
}
} // namespace infinicore::device_event
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "context.hpp" #include "context.hpp"
#include "device.hpp" #include "device.hpp"
#include "device_event.hpp"
#include "dtype.hpp" #include "dtype.hpp"
#include "ops.hpp" #include "ops.hpp"
#include "tensor.hpp" #include "tensor.hpp"
...@@ -12,6 +13,7 @@ namespace infinicore { ...@@ -12,6 +13,7 @@ namespace infinicore {
PYBIND11_MODULE(_infinicore, m) { PYBIND11_MODULE(_infinicore, m) {
context::bind(m); context::bind(m);
device::bind(m); device::bind(m);
device_event::bind(m);
dtype::bind(m); dtype::bind(m);
ops::bind(m); ops::bind(m);
tensor::bind(m); tensor::bind(m);
......
...@@ -64,6 +64,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) { ...@@ -64,6 +64,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// Event creation with flags is not provided by this backend; both arguments
// are ignored and the call always reports INFINI_STATUS_NOT_IMPLEMENTED.
infiniStatus_t eventCreateWithFlags(infinirtEvent_t * /*event_ptr*/, uint32_t /*flags*/) {
    return INFINI_STATUS_NOT_IMPLEMENTED;
}
infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) { infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
CHECK_ACLRT(aclrtRecordEvent((aclrtEvent)event, (aclrtStream)stream)); CHECK_ACLRT(aclrtRecordEvent((aclrtEvent)event, (aclrtStream)stream));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
...@@ -90,6 +94,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) { ...@@ -90,6 +94,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// Elapsed-time measurement is not provided by this backend; all arguments
// are ignored and the call always reports INFINI_STATUS_NOT_IMPLEMENTED.
infiniStatus_t eventElapsedTime(float * /*ms_ptr*/, infinirtEvent_t /*start*/, infinirtEvent_t /*end*/) {
    return INFINI_STATUS_NOT_IMPLEMENTED;
}
infiniStatus_t mallocDevice(void **p_ptr, size_t size) { infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
CHECK_ACLRT(aclrtMallocAlign32(p_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); CHECK_ACLRT(aclrtMallocAlign32(p_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -51,6 +51,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) { ...@@ -51,6 +51,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// Event creation with flags is not provided by this backend; both arguments
// are ignored and the call always reports INFINI_STATUS_NOT_IMPLEMENTED.
infiniStatus_t eventCreateWithFlags(infinirtEvent_t * /*event_ptr*/, uint32_t /*flags*/) {
    return INFINI_STATUS_NOT_IMPLEMENTED;
}
infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) { infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
CHECK_BANGRT(cnrtPlaceNotifier((cnrtNotifier_t)event, (cnrtQueue_t)stream)); CHECK_BANGRT(cnrtPlaceNotifier((cnrtNotifier_t)event, (cnrtQueue_t)stream));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
...@@ -78,6 +82,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) { ...@@ -78,6 +82,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// Elapsed-time measurement is not provided by this backend; all arguments
// are ignored and the call always reports INFINI_STATUS_NOT_IMPLEMENTED.
infiniStatus_t eventElapsedTime(float * /*ms_ptr*/, infinirtEvent_t /*start*/, infinirtEvent_t /*end*/) {
    return INFINI_STATUS_NOT_IMPLEMENTED;
}
infiniStatus_t mallocDevice(void **p_ptr, size_t size) { infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
CHECK_BANGRT(cnrtMalloc(p_ptr, size)); CHECK_BANGRT(cnrtMalloc(p_ptr, size));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
#include "infinirt_cpu.h" #include "infinirt_cpu.h"
#include <chrono>
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
...@@ -34,23 +35,50 @@ infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) { ...@@ -34,23 +35,50 @@ infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
} }
infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) { infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
return INFINI_STATUS_NOT_IMPLEMENTED; // For CPU implementation, we use a simple timestamp as event
auto now = std::chrono::steady_clock::now();
auto *timestamp = new std::chrono::steady_clock::time_point(now);
*event_ptr = timestamp;
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) {
// CPU implementation ignores flags for simplicity
return eventCreate(event_ptr);
} }
infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) { infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
return INFINI_STATUS_NOT_IMPLEMENTED; // Update the event timestamp
auto *timestamp = static_cast<std::chrono::steady_clock::time_point *>(event);
*timestamp = std::chrono::steady_clock::now();
return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t eventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr) { infiniStatus_t eventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr) {
return INFINI_STATUS_NOT_IMPLEMENTED; // CPU events are always complete immediately
*status_ptr = INFINIRT_EVENT_COMPLETE;
return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t eventSynchronize(infinirtEvent_t event) { infiniStatus_t eventSynchronize(infinirtEvent_t event) {
return INFINI_STATUS_NOT_IMPLEMENTED; // CPU events are synchronized immediately
return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t eventDestroy(infinirtEvent_t event) { infiniStatus_t eventDestroy(infinirtEvent_t event) {
return INFINI_STATUS_NOT_IMPLEMENTED; auto *timestamp = static_cast<std::chrono::steady_clock::time_point *>(event);
delete timestamp;
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
auto *start_time = static_cast<std::chrono::steady_clock::time_point *>(start);
auto *end_time = static_cast<std::chrono::steady_clock::time_point *>(end);
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(*end_time - *start_time);
*ms_ptr = static_cast<float>(duration.count()) / 1000.0f; // Convert microseconds to milliseconds
return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t mallocDevice(void **p_ptr, size_t size) { infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
......
...@@ -53,6 +53,23 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) { ...@@ -53,6 +53,23 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// Create a CUDA event, translating infinirt event flags into their CUDA
// counterparts. Bits other than DISABLE_TIMING and BLOCKING_SYNC are ignored.
infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) {
    const bool timing_disabled = (flags & INFINIRT_EVENT_DISABLE_TIMING) != 0;
    const bool blocking_sync = (flags & INFINIRT_EVENT_BLOCKING_SYNC) != 0;

    // Start from the default and add one CUDA flag per requested option.
    unsigned int cuda_flags = cudaEventDefault;
    cuda_flags |= timing_disabled ? cudaEventDisableTiming : 0u;
    cuda_flags |= blocking_sync ? cudaEventBlockingSync : 0u;

    cudaEvent_t event;
    CHECK_CUDART(cudaEventCreateWithFlags(&event, cuda_flags));
    *event_ptr = event;
    return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) { infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
CHECK_CUDART(cudaEventRecord((cudaEvent_t)event, (cudaStream_t)stream)); CHECK_CUDART(cudaEventRecord((cudaEvent_t)event, (cudaStream_t)stream));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
...@@ -80,6 +97,11 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) { ...@@ -80,6 +97,11 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// Write the time between `start` and `end` to `*ms_ptr` using
// cudaEventElapsedTime. Both handles are cast back to cudaEvent_t; the CUDA
// call is wrapped in CHECK_CUDART.
infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
CHECK_CUDART(cudaEventElapsedTime(ms_ptr, (cudaEvent_t)start, (cudaEvent_t)end));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t mallocDevice(void **p_ptr, size_t size) { infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
CHECK_CUDART(cudaMalloc(p_ptr, size)); CHECK_CUDART(cudaMalloc(p_ptr, size));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -126,6 +126,10 @@ __C infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr) { ...@@ -126,6 +126,10 @@ __C infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr) {
INFINIRT_CALL_DEVICE_API(eventCreate, (event_ptr)); INFINIRT_CALL_DEVICE_API(eventCreate, (event_ptr));
} }
// Public C entry point for creating an event with flags. Hands
// (event_ptr, flags) to the backend's eventCreateWithFlags through
// INFINIRT_CALL_DEVICE_API (macro defined outside this view).
__C infiniStatus_t infinirtEventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) {
INFINIRT_CALL_DEVICE_API(eventCreateWithFlags, (event_ptr, flags));
}
__C infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream) { __C infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream) {
INFINIRT_CALL_DEVICE_API(eventRecord, (event, stream)); INFINIRT_CALL_DEVICE_API(eventRecord, (event, stream));
} }
...@@ -142,6 +146,10 @@ __C infiniStatus_t infinirtEventDestroy(infinirtEvent_t event) { ...@@ -142,6 +146,10 @@ __C infiniStatus_t infinirtEventDestroy(infinirtEvent_t event) {
INFINIRT_CALL_DEVICE_API(eventDestroy, (event)); INFINIRT_CALL_DEVICE_API(eventDestroy, (event));
} }
// Public C entry point for measuring the time between two events. Hands
// (ms_ptr, start, end) to the backend's eventElapsedTime through
// INFINIRT_CALL_DEVICE_API (macro defined outside this view).
__C infiniStatus_t infinirtEventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
INFINIRT_CALL_DEVICE_API(eventElapsedTime, (ms_ptr, start, end));
}
__C infiniStatus_t infinirtMalloc(void **p_ptr, size_t size) { __C infiniStatus_t infinirtMalloc(void **p_ptr, size_t size) {
INFINIRT_CALL_DEVICE_API(mallocDevice, (p_ptr, size)); INFINIRT_CALL_DEVICE_API(mallocDevice, (p_ptr, size));
} }
......
#ifndef __INFINIRT_IMPL_H__ #ifndef __INFINIRT_IMPL_H__
#define __INFINIRT_IMPL_H__ #define __INFINIRT_IMPL_H__
#include "infinirt.h" #include "infinirt.h"
#include <stdint.h>
#define INFINIRT_DEVICE_API(IMPL, COUNT) \ #define INFINIRT_DEVICE_API(IMPL, COUNT) \
infiniStatus_t getDeviceCount(int *count) COUNT; \ infiniStatus_t getDeviceCount(int *count) COUNT; \
...@@ -13,10 +14,12 @@ ...@@ -13,10 +14,12 @@
infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) IMPL; \ infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) IMPL; \
\ \
infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) IMPL; \ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) IMPL; \
infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) IMPL; \
infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) IMPL; \ infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) IMPL; \
infiniStatus_t eventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr) IMPL; \ infiniStatus_t eventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr) IMPL; \
infiniStatus_t eventSynchronize(infinirtEvent_t event) IMPL; \ infiniStatus_t eventSynchronize(infinirtEvent_t event) IMPL; \
infiniStatus_t eventDestroy(infinirtEvent_t event) IMPL; \ infiniStatus_t eventDestroy(infinirtEvent_t event) IMPL; \
infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) IMPL; \
\ \
infiniStatus_t mallocDevice(void **p_ptr, size_t size) IMPL; \ infiniStatus_t mallocDevice(void **p_ptr, size_t size) IMPL; \
infiniStatus_t mallocHost(void **p_ptr, size_t size) IMPL; \ infiniStatus_t mallocHost(void **p_ptr, size_t size) IMPL; \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment