shmem_wrapper.cuh 8.23 KB
Newer Older
1
2
#pragma once
/*
 * Temporary wrapper for platform-specific DUSHMEM and rocSHMEM functions.
 * Once hipify or hipify-torch fully supports this mapping, this file has to be
 * removed and the corresponding dushmem* functions restored.
 */
#ifndef DISABLE_ROCSHMEM

#include "configs.cuh"

lijian6's avatar
lijian6 committed
11
#ifndef FORCE_DUSHMEM_API
12
13
14
15
16
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#else
lijian6's avatar
lijian6 committed
17
#include <device_host_transport/dushmem_common_ibgda.h>
18
#include <infiniband/mlx5dv.h>
lijian6's avatar
lijian6 committed
19
20
21
#include <dushmem.h>
#include <dushmemx.h>
#include <non_abi/device/threadgroup/dushmemi_common_device_defines.cuh>
22
23
24
25
26
#endif

namespace deep_ep::internode {

// rocSHMEM wrapper
lijian6's avatar
lijian6 committed
27
#ifndef FORCE_DUSHMEM_API
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Team handle and team-configuration types, aliased from rocSHMEM.
using shmem_team_t = rocshmem::rocshmem_team_t;
using shmem_team_config_t = rocshmem::rocshmem_team_config_t;
// Const copy of rocSHMEM's invalid-team value.
const shmem_team_t EP_SHMEM_TEAM_INVALID = rocshmem::ROCSHMEM_TEAM_INVALID;
// Reference to (not a copy of) rocSHMEM's world-team object.
inline shmem_team_t& EP_SHMEM_TEAM_WORLD = rocshmem::ROCSHMEM_TEAM_WORLD;

// Unique-id and init-attribute types, aliased from rocSHMEM.
using shmemx_uniqueid_t = rocshmem::rocshmem_uniqueid_t;
using shmemx_init_attr_t = rocshmem::rocshmem_init_attr_t;

// Init flag constant, taken from rocSHMEM's ROCSHMEM_INIT_WITH_UNIQUEID.
constexpr auto EP_SHMEMX_INIT_WITH_UNIQUEID = rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID;

// Host-side shim: hands `uid` to rocshmem::rocshmem_get_uniqueid and returns
// that call's result unchanged.
__host__ inline int shmemx_get_uniqueid(shmemx_uniqueid_t *uid) {
    const int status = rocshmem::rocshmem_get_uniqueid(uid);
    return status;
}

// Host-side shim: forwards (rank, nranks, uid, attr) in the same order to
// rocshmem::rocshmem_set_attr_uniqueid_args and returns its result.
__host__ inline int shmemx_set_attr_uniqueid_args(
    int rank,
    int nranks,
    shmemx_uniqueid_t *uid,
    shmemx_init_attr_t *attr) {
    const int status = rocshmem::rocshmem_set_attr_uniqueid_args(rank, nranks, uid, attr);
    return status;
}

// Host-side shim: forwards `flags` and `attr` to rocshmem::rocshmem_init_attr
// and returns its result.
__host__ inline int shmemx_init_attr(unsigned int flags, shmemx_init_attr_t *attr) {
    const int status = rocshmem::rocshmem_init_attr(flags, attr);
    return status;
}

// Host-side shim: forwards every argument, in order, to
// rocshmem::rocshmem_team_split_strided and returns its result.
// `new_team` is passed through as the callee's output pointer.
__host__ inline int shmem_team_split_strided(
    shmem_team_t parent_team,
    int start,
    int stride,
    int size,
    const shmem_team_config_t *config,
    long config_mask,
    shmem_team_t *new_team) {
    return rocshmem::rocshmem_team_split_strided(
        parent_team, start, stride, size, config, config_mask, new_team);
}

// Host-side shim: calls rocshmem::rocshmem_barrier_all with no arguments.
__host__ inline void shmem_barrier_all()
{
    rocshmem::rocshmem_barrier_all();
}

// Device-side counterpart of shmem_barrier_all: calls
// rocshmem::rocshmem_barrier_all from device code.
__device__ inline void shmem_device_barrier_all()
{
    rocshmem::rocshmem_barrier_all();
}

// Device-side team barrier: calls rocshmem::rocshmem_ctx_barrier on
// rocshmem::ROCSHMEM_CTX_DEFAULT for the given `team`.
__device__ inline void shmem_barrier(shmem_team_t team) {
    rocshmem::rocshmem_ctx_barrier(
        rocshmem::ROCSHMEM_CTX_DEFAULT,
        team);
}

// Host-side shim: returns whatever rocshmem::rocshmem_my_pe returns.
__host__ inline int shmem_my_pe() {
    const int pe = rocshmem::rocshmem_my_pe();
    return pe;
}

// Host-side shim: passes `ptr` to rocshmem::rocshmem_free.
__host__ inline void shmem_free(void *ptr)
{
    rocshmem::rocshmem_free(ptr);
}

// Host-side allocation shim for rocSHMEM.
//
// The requested `size` is passed through ALIGN(size, alignment) (ALIGN is
// defined outside this file) and the result is handed to
// rocshmem::rocshmem_malloc, whose return value is returned as-is.
// NOTE(review): `alignment` only influences the size passed to the allocator
// here; the alignment of the returned pointer is whatever rocshmem_malloc
// provides — confirm that is sufficient for callers.
__host__ inline void* shmem_align(const size_t alignment, const size_t size) {
    const auto padded_size = ALIGN(size, alignment);
    void *buffer = rocshmem::rocshmem_malloc(padded_size);
    return buffer;
}

// Host-side shim: calls rocshmem::rocshmem_finalize with no arguments.
__host__ inline void shmem_finalize()
{
    rocshmem::rocshmem_finalize();
}

// Host-side shim: passes `team` to rocshmem::rocshmem_team_destroy.
__host__ inline void shmem_team_destroy(shmem_team_t team)
{
    rocshmem::rocshmem_team_destroy(team);
}

// Device-side shim: calls rocshmem::rocshmem_fence with no arguments.
__device__ inline void shmem_fence()
{
    rocshmem::rocshmem_fence();
}

// Device-side shim: forwards (dest, source, nelems, pe) unchanged to
// rocshmem::rocshmem_int_put_nbi.
__device__ inline void shmem_int_put_nbi(int *dest,
                                         const int *source,
                                         size_t nelems,
                                         int pe) {
    rocshmem::rocshmem_int_put_nbi(dest, source, nelems, pe);
}

// Device-side shim: the "_warp" name maps onto rocSHMEM's "_wave" entry point;
// (dest, source, nelems, pe) are forwarded unchanged to
// rocshmem::rocshmem_int_put_nbi_wave.
__device__ inline void shmemx_int_put_nbi_warp(int *dest,
                                               const int *source,
                                               size_t nelems,
                                               int pe) {
    rocshmem::rocshmem_int_put_nbi_wave(dest, source, nelems, pe);
}

// Device-side shim: the int8 "_warp" put maps onto rocSHMEM's signed-char
// "_wave" put; (dest, source, nelems, pe) are forwarded unchanged to
// rocshmem::rocshmem_schar_put_nbi_wave.
__device__ inline void shmemx_int8_put_nbi_warp(signed char *dest,
                                                const signed char *source,
                                                size_t nelems,
                                                int pe) {
    rocshmem::rocshmem_schar_put_nbi_wave(dest, source, nelems, pe);
}

lijian6's avatar
lijian6 committed
110
// Device-side shim: in the rocSHMEM build the "signal add" operation is
// implemented by calling rocshmem::rocshmem_ulong_atomic_add with
// (dest, value, pe) forwarded unchanged.
__device__ inline void shmem_signal_op_add(uint64_t *dest, uint64_t value, int pe) {
    rocshmem::rocshmem_ulong_atomic_add(dest, value, pe);
}

114
115
116
117
118
// Device-side shim: forwards (dest, value, pe) unchanged to
// rocshmem::rocshmem_long_atomic_add.
__device__ inline void shmem_long_atomic_add(long *dest, long value, int pe) {
    rocshmem::rocshmem_long_atomic_add(dest, value, pe);
}

119
#if defined(ROCM_USE_MULTIQP)
120
121
122
123
// Device-side shim (only compiled when ROCM_USE_MULTIQP is defined):
// passes `idx_qp` to rocshmem::rocshmem_quiet_dp.
__device__ inline void shmem_qp_quiet(int idx_qp)
{
    rocshmem::rocshmem_quiet_dp(idx_qp);
}

124
125
126
127
128
129
130
131
132
133
134
// Device-side shim (only compiled when ROCM_USE_MULTIQP is defined):
// forwards (dest, source, nelems, qp_idx, pe) unchanged to
// rocshmem::rocshmem_schar_put_nbi_wave_dp.
__device__ inline void shmemx_int8_put_nbi_warp_dp(signed char *dest,
                                                   const signed char *source,
                                                   size_t nelems,
                                                   int qp_idx,
                                                   int pe) {
    rocshmem::rocshmem_schar_put_nbi_wave_dp(dest, source, nelems, qp_idx, pe);
}

// Device-side shim (only compiled when ROCM_USE_MULTIQP is defined):
// forwards (dest, value, qp_idx, pe) unchanged to
// rocshmem::rocshmem_long_atomic_add_dp.
__device__ inline void shmem_long_atomic_add_dp(long *dest,
                                                long value,
                                                int qp_idx,
                                                int pe) {
    rocshmem::rocshmem_long_atomic_add_dp(dest, value, qp_idx, pe);
}
#endif

135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#if !defined(ROCM_DISABLE_CTX)
// Communication-context handle type, aliased from rocSHMEM. Only declared
// when ROCM_DISABLE_CTX is not defined.
using shmem_ctx_t = rocshmem::rocshmem_ctx_t;

// Device-side shim: calls rocshmem::rocshmem_wg_ctx_create with a literal 0
// as the first argument and `ctx` as the second, returning the callee's result.
// NOTE(review): the 0 is presumably the "options" argument — confirm against
// the rocSHMEM API.
__device__ inline int shmem_wg_ctx_create(shmem_ctx_t *ctx) {
    const int status = rocshmem::rocshmem_wg_ctx_create(0, ctx);
    return status;
}

// Device-side shim: passes `ctx` to rocshmem::rocshmem_wg_ctx_destroy.
__device__ inline void shmem_wg_ctx_destroy(shmem_ctx_t *ctx)
{
    rocshmem::rocshmem_wg_ctx_destroy(ctx);
}

// Device-side shim: passes `ctx` to rocshmem::rocshmem_ctx_quiet.
__device__ inline void shmem_ctx_quiet(shmem_ctx_t ctx)
{
    rocshmem::rocshmem_ctx_quiet(ctx);
}

// Device-side shim: forwards (ctx, dest, value, pe) unchanged to
// rocshmem::rocshmem_ctx_ulong_atomic_add.
__device__ inline void shmem_ctx_ulong_atomic_add(shmem_ctx_t ctx,
                                                  uint64_t *dest,
                                                  uint64_t value,
                                                  int pe) {
    rocshmem::rocshmem_ctx_ulong_atomic_add(ctx, dest, value, pe);
}

// Device-side shim: forwards (ctx, dest, value, pe) unchanged to
// rocshmem::rocshmem_ctx_long_atomic_add.
__device__ inline void shmem_ctx_long_atomic_add(shmem_ctx_t ctx,
                                                 long *dest,
                                                 long value,
                                                 int pe) {
    rocshmem::rocshmem_ctx_long_atomic_add(ctx, dest, value, pe);
}

// Device-side shim: the "_warp" name maps onto rocSHMEM's "_wave" entry point;
// (ctx, dest, source, nelems, pe) are forwarded unchanged to
// rocshmem::rocshmem_ctx_schar_put_nbi_wave.
__device__ inline void shmem_ctx_schar_put_nbi_warp(shmem_ctx_t ctx,
                                                    signed char *dest,
                                                    const signed char *source,
                                                    size_t nelems,
                                                    int pe) {
    rocshmem::rocshmem_ctx_schar_put_nbi_wave(ctx, dest, source, nelems, pe);
}

// Device-side shim: the "_warp" name maps onto rocSHMEM's "_wave" entry point;
// (ctx, dest, source, nelems, pe) are forwarded unchanged to
// rocshmem::rocshmem_ctx_int_put_nbi_wave.
__device__ inline void shmem_ctx_int_put_nbi_warp(shmem_ctx_t ctx,
                                                  int *dest,
                                                  const int *source,
                                                  size_t nelems,
                                                  int pe) {
    rocshmem::rocshmem_ctx_int_put_nbi_wave(ctx, dest, source, nelems, pe);
}

#endif

#else

lijian6's avatar
lijian6 committed
174
// DUSHMEM wrapper
175
176
177
178
#ifndef ROCM_DISABLE_CTX
#define ROCM_DISABLE_CTX
#endif

lijian6's avatar
lijian6 committed
179
180
181
182
183
184
185
// Team, team-configuration, unique-id and init-attribute types, aliased from
// the global-namespace DUSHMEM types.
using shmem_team_t = dushmem_team_t;
using shmem_team_config_t = dushmem_team_config_t;
using shmemx_uniqueid_t = dushmemx_uniqueid_t;
using shmemx_init_attr_t = dushmemx_init_attr_t;
// Const copies of DUSHMEM's invalid-team and world-team values. (In the
// rocSHMEM branch of this file EP_SHMEM_TEAM_WORLD is a reference instead.)
const shmem_team_t EP_SHMEM_TEAM_INVALID = DUSHMEM_TEAM_INVALID;
const shmem_team_t EP_SHMEM_TEAM_WORLD = DUSHMEM_TEAM_WORLD;
// Init flag constant, taken from DUSHMEMX_INIT_WITH_UNIQUEID.
constexpr auto EP_SHMEMX_INIT_WITH_UNIQUEID = DUSHMEMX_INIT_WITH_UNIQUEID;
186
187
188


// Host-side shim: forwards `uid` to dushmemx_get_uniqueid and returns that
// call's result unchanged.
__host__ inline int shmemx_get_uniqueid(shmemx_uniqueid_t *uid) {
    return dushmemx_get_uniqueid(uid);
}

// Host-side shim: forwards (rank, nranks, uid, attr) in the same order to
// dushmemx_set_attr_uniqueid_args and returns its result.
__host__ inline int shmemx_set_attr_uniqueid_args(int rank, int nranks,
    shmemx_uniqueid_t *uid, shmemx_init_attr_t *attr) {
    return dushmemx_set_attr_uniqueid_args(rank, nranks, uid, attr);
}


// Host-side shim: forwards `flags` and `attr` to dushmemx_init_attr and
// returns its result.
__host__ inline int shmemx_init_attr(unsigned int flags, shmemx_init_attr_t *attr) {
    return dushmemx_init_attr(flags, attr);
}

// Host-side shim: forwards every argument, in order, to
// dushmem_team_split_strided and returns its result. `new_team` is passed
// through as the callee's output pointer.
__host__ inline int shmem_team_split_strided(shmem_team_t parent_team,
    int start, int stride, int size,
    const shmem_team_config_t *config,
    long config_mask, shmem_team_t *new_team) {
    return dushmem_team_split_strided(parent_team, start, stride, size, config, config_mask, new_team);
}

// Host-side shim: calls dushmem_barrier_all with no arguments.
__host__ inline void shmem_barrier_all() {
    dushmem_barrier_all();
}

// Device-side counterpart of shmem_barrier_all: calls dushmem_barrier_all
// from device code.
__device__ inline void shmem_device_barrier_all() {
    dushmem_barrier_all();
}

// Device-side team barrier: calls dushmem_barrier(team) and deliberately
// discards whatever it returns, since this wrapper is declared void.
__device__ inline void shmem_barrier(shmem_team_t team) {
    // Explicit cast to void makes the "ignore the result" intent unambiguous.
    static_cast<void>(dushmem_barrier(team));
}

// Host-side shim: returns whatever dushmem_my_pe returns.
__host__ inline int shmem_my_pe() {
    return dushmem_my_pe();
}

// Host-side shim: passes `ptr` to dushmem_free.
__host__ inline void shmem_free(void *ptr) {
    dushmem_free(ptr);
}

// Host-side allocation shim for DUSHMEM: returns the result of dushmem_align.
//
// The arguments are forwarded as (alignment, size), which is the parameter
// order of OpenSHMEM's shmem_align and NVSHMEM's nvshmem_align. Both
// parameters are size_t, so the compiler cannot detect a swapped call.
// NOTE(review): the DUSHMEM header is not visible from this file; this assumes
// dushmem_align uses the same (alignment, size) order — confirm in dushmem.h.
__host__ inline void* shmem_align(const size_t alignment, const size_t size) {
    return dushmem_align(alignment, size);
}

// Host-side shim: calls dushmem_finalize with no arguments.
__host__ inline void shmem_finalize() {
    dushmem_finalize();
}

// Host-side shim: passes `team` to dushmem_team_destroy.
__host__ inline void shmem_team_destroy(shmem_team_t team) {
    dushmem_team_destroy(team);
}

// Device-side shim: calls dushmem_fence with no arguments.
__device__ inline void shmem_fence() {
    dushmem_fence();
}

// Device-side shim: forwards (dest, source, nelems, pe) unchanged to
// dushmem_int_put_nbi.
__device__ inline void shmem_int_put_nbi(
    int *dest, const int *source, size_t nelems, int pe) {
    dushmem_int_put_nbi(dest, source, nelems, pe);
}

// Device-side shim: forwards (dest, source, nelems, pe) unchanged to
// dushmemx_int_put_nbi_warp.
__device__ inline void shmemx_int_put_nbi_warp(
    int *dest, const int *source, size_t nelems, int pe) {
    dushmemx_int_put_nbi_warp(dest, source, nelems, pe);
}

// Device-side shim: forwards (dest, source, nelems, pe) unchanged to
// dushmemx_int8_put_nbi_warp.
__device__ inline void shmemx_int8_put_nbi_warp(
    signed char *dest, const signed char *source, size_t nelems, int pe) {
    dushmemx_int8_put_nbi_warp(dest, source, nelems, pe);
}

// Device-side shim: calls dushmemx_signal_op with the operation fixed to
// DUSHMEM_SIGNAL_ADD; `dest`, `value` and `pe` are forwarded unchanged.
__device__ inline void shmem_signal_op_add(
    uint64_t *dest, uint64_t value, int pe) {
    dushmemx_signal_op(dest, value, DUSHMEM_SIGNAL_ADD, pe);
}

// Device-side shim: forwards (dest, value, pe) unchanged to
// dushmem_ulong_atomic_add.
__device__ inline void shmem_ulong_atomic_add(
    uint64_t *dest, uint64_t value, int pe) {
    dushmem_ulong_atomic_add(dest, value, pe);
}

// Device-side shim: forwards (dest, value, pe) unchanged to
// dushmem_long_atomic_add.
__device__ inline void shmem_long_atomic_add(
    long *dest, long value, int pe) {
    dushmem_long_atomic_add(dest, value, pe);
}

#endif

} // namespace deep_ep::internode

#endif