// infinirt_bang.cc — infinirt runtime backend implemented on top of the CNRT (BANG) API.
#include "../../utils.h"
#include "infinirt_bang.h"
#include "cnrt.h"

#define CHECK_BANGRT(RT_API) CHECK_INTERNAL(RT_API, cnrtSuccess)

namespace infinirt::bang {
// Queries CNRT for the number of devices and writes it to *count.
// Any result other than cnrtSuccess is handled by CHECK_BANGRT (built on
// CHECK_INTERNAL from utils.h — presumably it returns an error status early).
infiniStatus_t getDeviceCount(int *count) {
    CHECK_BANGRT(cnrtGetDeviceCount(count));
    return INFINI_STATUS_SUCCESS;
}

// Selects `device_id` as the current CNRT device via cnrtSetDevice.
// Returns INFINI_STATUS_SUCCESS; CNRT failures are handled by CHECK_BANGRT.
infiniStatus_t setDevice(int device_id) {
    CHECK_BANGRT(cnrtSetDevice(device_id));
    return INFINI_STATUS_SUCCESS;
}

// Synchronizes with the current device through cnrtSyncDevice
// (presumably blocks until all previously issued work has finished — confirm
// against the CNRT docs). CNRT failures are handled by CHECK_BANGRT.
infiniStatus_t deviceSynchronize() {
    CHECK_BANGRT(cnrtSyncDevice());
    return INFINI_STATUS_SUCCESS;
}

// Creates a new CNRT queue and hands it back through *stream_ptr as an
// infinirtStream_t. The out-parameter is assigned only after the CNRT call
// has passed CHECK_BANGRT (which presumably returns early on failure).
infiniStatus_t streamCreate(infinirtStream_t *stream_ptr) {
    // Create into the local handle that is published below; there is no
    // `stream` variable in this scope.
    cnrtQueue_t queue = nullptr;
    CHECK_BANGRT(cnrtQueueCreate(&queue));
    *stream_ptr = queue;
    return INFINI_STATUS_SUCCESS;
}

// Destroys the CNRT queue wrapped by `stream` (an infinirtStream_t that is
// reinterpreted as a cnrtQueue_t). CNRT failures are handled by CHECK_BANGRT.
infiniStatus_t streamDestroy(infinirtStream_t stream) {
    CHECK_BANGRT(cnrtQueueDestroy((cnrtQueue_t)stream));
    return INFINI_STATUS_SUCCESS;
}

// Waits on the CNRT queue behind `stream` via cnrtQueueSync
// (presumably until all work enqueued on it has completed).
// CNRT failures are handled by CHECK_BANGRT.
infiniStatus_t streamSynchronize(infinirtStream_t stream) {
    CHECK_BANGRT(cnrtQueueSync((cnrtQueue_t)stream));
    return INFINI_STATUS_SUCCESS;
}

// Makes the queue behind `stream` wait on the notifier behind `event`.
// Note the CNRT argument order: notifier first, then queue.
// The trailing 0 is CNRT's flag argument (NOTE(review): assumed to mean
// "no special wait flags" — confirm in the CNRT docs).
infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
    CHECK_BANGRT(cnrtQueueWaitNotifier((cnrtNotifier_t)event, (cnrtQueue_t)stream, 0));
    return INFINI_STATUS_SUCCESS;
}

// Creates a CNRT notifier and returns it to the caller through *event_ptr,
// converted to the opaque infinirtEvent_t handle type. *event_ptr is touched
// only once the creation call has passed CHECK_BANGRT.
infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
    cnrtNotifier_t notifier{};
    CHECK_BANGRT(cnrtNotifierCreate(&notifier));
    *event_ptr = static_cast<infinirtEvent_t>(notifier);
    return INFINI_STATUS_SUCCESS;
}

// Records `event` on `stream` by placing the CNRT notifier behind `event`
// into the CNRT queue behind `stream`. CNRT failures are handled by CHECK_BANGRT.
infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
    CHECK_BANGRT(cnrtPlaceNotifier((cnrtNotifier_t)event, (cnrtQueue_t)stream));
    return INFINI_STATUS_SUCCESS;
}

// Reports the state of `event` through *status_ptr:
//   - INFINIRT_EVENT_COMPLETE  when cnrtQueryNotifier returns cnrtSuccess;
//   - INFINIRT_EVENT_NOT_READY when it returns cnrtErrorBusy (still pending).
// Any other CNRT result is treated as an error via CHECK_BANGRT, and in that
// case *status_ptr is not written.
infiniStatus_t eventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr) {
    // Query the notifier that backs `event`; this function has no stream.
    auto status = cnrtQueryNotifier((cnrtNotifier_t)event);
    if (status == cnrtSuccess) {
        *status_ptr = INFINIRT_EVENT_COMPLETE;
    } else if (status == cnrtErrorBusy) {
        *status_ptr = INFINIRT_EVENT_NOT_READY;
    } else {
        // Genuine failure: let CHECK_BANGRT report it and propagate the error.
        CHECK_BANGRT(status);
    }
    return INFINI_STATUS_SUCCESS;
}

// Waits on the CNRT notifier behind `event` via cnrtWaitNotifier
// (presumably blocking the host until the notifier completes).
// CNRT failures are handled by CHECK_BANGRT.
infiniStatus_t eventSynchronize(infinirtEvent_t event) {
    CHECK_BANGRT(cnrtWaitNotifier((cnrtNotifier_t)event));
    return INFINI_STATUS_SUCCESS;
}

// Destroys the CNRT notifier behind `event`. CNRT failures are handled by CHECK_BANGRT.
infiniStatus_t eventDestroy(infinirtEvent_t event) {
    CHECK_BANGRT(cnrtNotifierDestroy((cnrtNotifier_t)event));
    return INFINI_STATUS_SUCCESS;
}

// Allocates `size` bytes of device memory with cnrtMalloc and stores the
// pointer in *p_ptr. Pair with freeDevice. CNRT failures are handled by CHECK_BANGRT.
infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
    CHECK_BANGRT(cnrtMalloc(p_ptr, size));
    return INFINI_STATUS_SUCCESS;
}

// Allocates `size` bytes of host memory through CNRT's host allocator
// (cnrtHostMalloc) and stores the pointer in *p_ptr. Pair with freeHost.
// NOTE(review): presumably page-locked memory — confirm in the CNRT docs.
infiniStatus_t mallocHost(void **p_ptr, size_t size) {
    CHECK_BANGRT(cnrtHostMalloc(p_ptr, size));
    return INFINI_STATUS_SUCCESS;
}

// Releases device memory obtained from mallocDevice (cnrtMalloc) via cnrtFree.
infiniStatus_t freeDevice(void *ptr) {
    CHECK_BANGRT(cnrtFree(ptr));
    return INFINI_STATUS_SUCCESS;
}

// Releases host memory obtained from mallocHost (cnrtHostMalloc) via cnrtFreeHost.
infiniStatus_t freeHost(void *ptr) {
    CHECK_BANGRT(cnrtFreeHost(ptr));
    return INFINI_STATUS_SUCCESS;
}

// Translates an infinirt copy direction into the matching CNRT transfer
// direction. Kinds without a mapping fall through to cnrtMemcpyNoDirection;
// the CNRT copy call that receives it presumably rejects it with an error.
cnrtMemTransDir_t toBangMemcpyKind(infinirtMemcpyKind_t kind) {
    switch (kind) {
    case INFINIRT_MEMCPY_H2D:
        return cnrtMemcpyHostToDev;
    case INFINIRT_MEMCPY_D2H:
        return cnrtMemcpyDevToHost;
    case INFINIRT_MEMCPY_D2D:
        return cnrtMemcpyDevToDev;
    case INFINIRT_MEMCPY_H2H:
        return cnrtMemcpyHostToHost;
    default:
        return cnrtMemcpyNoDirection;
    }
}

// Copies `size` bytes from `src` to `dst` in the direction given by `kind`,
// using the non-queued cnrtMemcpy (the counterpart of memcpyAsync).
infiniStatus_t memcpy(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind) {
    // cnrtMemcpy declares its source parameter as non-const `void *`
    // (NOTE(review): per cnrt.h — confirm for the CNRT version in use), so a
    // `const void *` does not convert implicitly in C++. The source of a copy
    // is only read, so casting away const here is safe.
    CHECK_BANGRT(cnrtMemcpy(dst, const_cast<void *>(src), size, toBangMemcpyKind(kind)));
    return INFINI_STATUS_SUCCESS;
}

// Enqueues a copy of `size` bytes from `src` to `dst` (direction from `kind`)
// on the CNRT queue behind `stream`, via cnrtMemcpyAsync_V2.
infiniStatus_t memcpyAsync(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind, infinirtStream_t stream) {
    // cnrtMemcpyAsync_V2 declares its source parameter as non-const `void *`
    // (NOTE(review): per cnrt.h — confirm for the CNRT version in use), so a
    // `const void *` does not convert implicitly in C++. The source of a copy
    // is only read, so casting away const here is safe.
    CHECK_BANGRT(cnrtMemcpyAsync_V2(dst, const_cast<void *>(src), size, (cnrtQueue_t)stream, toBangMemcpyKind(kind)));
    return INFINI_STATUS_SUCCESS;
}

124
// Does not support async malloc. Use blocking-style malloc instead
125
infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
126
127
    CHECK_BANGRT(cnrtMalloc(p_ptr, size));
    return INFINI_STATUS_SUCCESS;
128
129
}

130
// Does not support async free. Use blocking-style free instead
131
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
132
133
    CHECK_BANGRT(cnrtFree(ptr));
    return INFINI_STATUS_SUCCESS;
134
135
}
} // namespace infinirt::bang