// event.hpp (blame-view header: line 1 committed by lijian6)
#pragma once

#include <memory>

#include <ATen/hip/HIPContext.h>

#include "kernels/exception.cuh"

namespace deep_ep {

struct EventHandle {
    std::shared_ptr<torch::Event> event;

    EventHandle() {
        event = std::make_shared<torch::Event>(torch::kCUDA);
lijian6's avatar
lijian6 committed
13
        event->record(at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
Chenggang Zhao's avatar
Chenggang Zhao committed
14
15
    }

lijian6's avatar
lijian6 committed
16
    explicit EventHandle(const at::hip::HIPStreamMasqueradingAsCUDA &stream) {
Chenggang Zhao's avatar
Chenggang Zhao committed
17
18
19
20
        event = std::make_shared<torch::Event>(torch::kCUDA);
        event->record(stream);
    }

lijian6's avatar
lijian6 committed
21
    EventHandle(const EventHandle &other) = default;
Chenggang Zhao's avatar
Chenggang Zhao committed
22
23

    void current_stream_wait() const {
lijian6's avatar
lijian6 committed
24
        at::hip::getCurrentHIPStreamMasqueradingAsCUDA().unwrap().wait(*event);
Chenggang Zhao's avatar
Chenggang Zhao committed
25
26
27
    }
};

lijian6's avatar
lijian6 committed
28
inline torch::Event create_event(const at::hip::HIPStreamMasqueradingAsCUDA &s) {
Chenggang Zhao's avatar
Chenggang Zhao committed
29
30
31
32
33
    auto event = torch::Event(torch::kCUDA);
    event.record(s);
    return event;
}

lijian6's avatar
lijian6 committed
34
35
inline void stream_wait(const at::hip::HIPStreamMasqueradingAsCUDA &s_0,
                        const at::hip::HIPStreamMasqueradingAsCUDA &s_1) {
Chenggang Zhao's avatar
Chenggang Zhao committed
36
37
38
39
    EP_HOST_ASSERT(s_0.id() != s_1.id());
    s_0.unwrap().wait(create_event(s_1));
}

lijian6's avatar
lijian6 committed
40
inline void stream_wait(const at::hip::HIPStreamMasqueradingAsCUDA &s, const EventHandle &event) {
Chenggang Zhao's avatar
Chenggang Zhao committed
41
42
43
44
    s.unwrap().wait(*event.event);
}

} // namespace deep_ep