shmem_wrapper.cuh 8.91 KB
Newer Older
1
2
#pragma once
/*
 * Temporary wrapper for platform-specific DUSHMEM and rocSHMEM functions.
 * Once hipify or hipify-torch fully supports this mapping, this file should be
 * removed and the corresponding dushmem* calls restored.
 */
#ifndef DISABLE_ROCSHMEM

#include "configs.cuh"

lijian6's avatar
lijian6 committed
11
#ifndef FORCE_DUSHMEM_API
12
13
14
15
16
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#else
lijian6's avatar
lijian6 committed
17
#include <device_host_transport/dushmem_common_ibgda.h>
18
#include <infiniband/mlx5dv.h>
lijian6's avatar
lijian6 committed
19
20
21
#include <dushmem.h>
#include <dushmemx.h>
#include <non_abi/device/threadgroup/dushmemi_common_device_defines.cuh>
22
23
24
25
26
#endif

namespace deep_ep::internode {

// rocSHMEM wrapper
lijian6's avatar
lijian6 committed
27
#ifndef FORCE_DUSHMEM_API
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Backend-neutral type names, mapped onto their rocSHMEM equivalents.
typedef rocshmem::rocshmem_team_t        shmem_team_t;
typedef rocshmem::rocshmem_team_config_t shmem_team_config_t;
typedef rocshmem::rocshmem_uniqueid_t    shmemx_uniqueid_t;
typedef rocshmem::rocshmem_init_attr_t   shmemx_init_attr_t;

// "No team" sentinel, copied by value from rocSHMEM.
const shmem_team_t EP_SHMEM_TEAM_INVALID = rocshmem::ROCSHMEM_TEAM_INVALID;

// World team, exposed as a reference to rocSHMEM's own object (not a copy).
// NOTE(review): presumably a reference because rocshmem::ROCSHMEM_TEAM_WORLD is
// only assigned during rocSHMEM initialization, so a copy taken at static-init
// time could be stale — confirm.
inline shmem_team_t& EP_SHMEM_TEAM_WORLD = rocshmem::ROCSHMEM_TEAM_WORLD;

// Init flag for unique-ID based bootstrap (presumably passed as `flags` to
// shmemx_init_attr).
constexpr auto EP_SHMEMX_INIT_WITH_UNIQUEID = rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID;

// Host-only: forwards `uid` to rocshmem_get_uniqueid and returns its status.
__host__ inline int shmemx_get_uniqueid(shmemx_uniqueid_t *uid) {
    const int status = rocshmem::rocshmem_get_uniqueid(uid);
    return status;
}

// Host-only: passes rank/nranks/uid/attr unchanged to
// rocshmem_set_attr_uniqueid_args and returns its status.
__host__ inline int shmemx_set_attr_uniqueid_args(
    int rank, int nranks, shmemx_uniqueid_t *uid, shmemx_init_attr_t *attr) {
    const int status =
        rocshmem::rocshmem_set_attr_uniqueid_args(rank, nranks, uid, attr);
    return status;
}

// Host-only: initializes rocSHMEM from `attr` using `flags`
// (e.g. EP_SHMEMX_INIT_WITH_UNIQUEID); returns rocshmem_init_attr's status.
__host__ inline int shmemx_init_attr(unsigned int flags, shmemx_init_attr_t *attr) {
    const int status = rocshmem::rocshmem_init_attr(flags, attr);
    return status;
}

// Host-only: strided team split. Every argument is passed unchanged to
// rocshmem_team_split_strided, and its status is returned.
__host__ inline int shmem_team_split_strided(
    shmem_team_t parent_team, int start, int stride, int size,
    const shmem_team_config_t *config, long config_mask,
    shmem_team_t *new_team) {
    const int status = rocshmem::rocshmem_team_split_strided(
        parent_team, start, stride, size, config, config_mask, new_team);
    return status;
}

// Host-side barrier: forwards to rocshmem::rocshmem_barrier_all().
__host__ inline void shmem_barrier_all() { rocshmem::rocshmem_barrier_all(); }

// Device-side barrier: the same rocshmem::rocshmem_barrier_all() call, made
// from device code.
__device__ inline void shmem_device_barrier_all() { rocshmem::rocshmem_barrier_all(); }

// Device-side barrier over `team`, issued on rocSHMEM's default context
// (rocshmem::ROCSHMEM_CTX_DEFAULT).
__device__ inline void shmem_barrier(shmem_team_t team) {
    rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT,
                                   team);
}

// Host-only: returns the calling process's PE index from rocshmem_my_pe().
__host__ inline int shmem_my_pe() {
    const int my_pe = rocshmem::rocshmem_my_pe();
    return my_pe;
}

// Host-only: releases symmetric memory (e.g. from shmem_align) through
// rocshmem_free.
__host__ inline void shmem_free(void *ptr) { rocshmem::rocshmem_free(ptr); }

// Host-only: symmetric allocation for the rocSHMEM path. No aligned allocator
// is called here; the byte count is rounded up with ALIGN(size, alignment)
// and handed to rocshmem_malloc. Release the result with shmem_free.
// NOTE(review): only the *size* is rounded. Whether the returned pointer is
// itself `alignment`-aligned depends on rocshmem_malloc's allocation
// granularity — confirm for alignments above that granularity.
__host__ inline void* shmem_align(const size_t alignment, const size_t size) {
    const auto padded_bytes = ALIGN(size, alignment);
    return rocshmem::rocshmem_malloc(padded_bytes);
}

// Host-only: shuts rocSHMEM down via rocshmem_finalize.
__host__ inline void shmem_finalize() { rocshmem::rocshmem_finalize(); }

// Host-only: destroys a team (e.g. one made by shmem_team_split_strided)
// via rocshmem_team_destroy.
__host__ inline void shmem_team_destroy(shmem_team_t team) { rocshmem::rocshmem_team_destroy(team); }

// Device-side memory-ordering fence: forwards to rocshmem_fence.
__device__ inline void shmem_fence() { rocshmem::rocshmem_fence(); }

// Device: non-blocking ("nbi") put of `nelems` ints from the local `source`
// to `dest` on PE `pe`; forwards to rocshmem_int_put_nbi.
__device__ inline void shmem_int_put_nbi(int *dest, const int *source,
                                         size_t nelems, int pe) {
    rocshmem::rocshmem_int_put_nbi(dest, source,
                                   nelems, pe);
}

// Device: warp-level non-blocking int put. rocSHMEM names the equivalent
// with a `_wave` suffix (AMD wavefront), so this maps `_warp` onto
// rocshmem_int_put_nbi_wave.
__device__ inline void shmemx_int_put_nbi_warp(int *dest, const int *source,
                                               size_t nelems, int pe) {
    rocshmem::rocshmem_int_put_nbi_wave(dest, source,
                                        nelems, pe);
}

// Device: warp-level non-blocking byte put. The int8 element type maps to
// rocSHMEM's `schar` variant, and `_warp` maps to `_wave`.
__device__ inline void shmemx_int8_put_nbi_warp(signed char *dest, const signed char *source,
                                                size_t nelems, int pe) {
    rocshmem::rocshmem_schar_put_nbi_wave(dest, source,
                                          nelems, pe);
}

lijian6's avatar
lijian6 committed
110
// Device: add `value` to the 64-bit word at `dest` on PE `pe`. On the
// rocSHMEM path the signal add is implemented with rocshmem_ulong_atomic_add.
// NOTE(review): a uint64_t* is passed to the `ulong` API, which assumes
// uint64_t is unsigned long on the target ABI (true on LP64 Linux) — confirm.
__device__ inline void shmem_signal_op_add(uint64_t *dest, uint64_t value, int pe) {
    rocshmem::rocshmem_ulong_atomic_add(dest, value, pe);
}

114
115
116
117
118
// Device: atomic add of `value` to the long at `dest` on PE `pe`
// (rocshmem_long_atomic_add).
__device__ inline void shmem_long_atomic_add(long *dest, long value, int pe) {
    rocshmem::rocshmem_long_atomic_add(dest,
                                       value,
                                       pe);
}

lishen's avatar
lishen committed
119
120
121
122
123
// Device: asks rocSHMEM (rocshmem_get_p2p_ptr) for a peer-addressable form of
// `dest` with respect to `rank`/`dst_rank`; the result is returned as uint64_t.
// NOTE(review): the meaning of the return value (e.g. 0 when no P2P mapping
// exists) is defined by rocSHMEM — confirm before relying on it.
__device__ inline uint64_t shmem_get_p2p_ptr(void *dest, int rank, int dst_rank) {
    return rocshmem::rocshmem_get_p2p_ptr(dest,
                                          rank,
                                          dst_rank);
}

#if !defined(ROCM_DISABLE_MULTIQP)
124
125
126
127
// Device: quiet restricted to queue pair `idx_qp` (rocshmem_quiet_dp).
// Only compiled when ROCM_DISABLE_MULTIQP is not defined.
__device__ inline void shmem_qp_quiet(int idx_qp) { rocshmem::rocshmem_quiet_dp(idx_qp); }

128
129
130
131
132
133
134
135
136
137
138
// Device: multi-QP variant of shmemx_int8_put_nbi_warp. `qp_idx` selects the
// queue pair and is forwarded, together with the other arguments, to
// rocshmem_schar_put_nbi_wave_dp.
__device__ inline void shmemx_int8_put_nbi_warp_dp(signed char *dest, const signed char *source,
                                                   size_t nelems, int qp_idx, int pe) {
    rocshmem::rocshmem_schar_put_nbi_wave_dp(dest, source, nelems,
                                             qp_idx, pe);
}

// Device: multi-QP variant of shmem_long_atomic_add; `qp_idx` selects the
// queue pair (rocshmem_long_atomic_add_dp).
__device__ inline void shmem_long_atomic_add_dp(long *dest, long value,
                                                int qp_idx, int pe) {
    rocshmem::rocshmem_long_atomic_add_dp(dest, value,
                                          qp_idx, pe);
}
#endif

139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#if !defined(ROCM_DISABLE_CTX)
// Backend-neutral name for a rocSHMEM communication context.
typedef rocshmem::rocshmem_ctx_t shmem_ctx_t;

// Device: creates a work-group context in *ctx via rocshmem_wg_ctx_create and
// returns its status. The literal 0 is that call's first argument (presumably
// an options word — confirm against the rocSHMEM docs).
__device__ inline int shmem_wg_ctx_create(shmem_ctx_t *ctx) {
    const int status = rocshmem::rocshmem_wg_ctx_create(0, ctx);
    return status;
}

// Device: destroys a context made by shmem_wg_ctx_create
// (rocshmem_wg_ctx_destroy).
__device__ inline void shmem_wg_ctx_destroy(shmem_ctx_t *ctx) { rocshmem::rocshmem_wg_ctx_destroy(ctx); }

// Device: quiet on a specific context (rocshmem_ctx_quiet).
__device__ inline void shmem_ctx_quiet(shmem_ctx_t ctx) { rocshmem::rocshmem_ctx_quiet(ctx); }

// Device: context-scoped atomic add of `value` to the 64-bit word at `dest`
// on PE `pe` (rocshmem_ctx_ulong_atomic_add).
__device__ inline void shmem_ctx_ulong_atomic_add(shmem_ctx_t ctx, uint64_t *dest,
                                                  uint64_t value, int pe) {
    rocshmem::rocshmem_ctx_ulong_atomic_add(ctx, dest,
                                            value, pe);
}

// Device: context-scoped atomic add of `value` to the long at `dest` on PE
// `pe` (rocshmem_ctx_long_atomic_add).
__device__ inline void shmem_ctx_long_atomic_add(shmem_ctx_t ctx, long *dest,
                                                 long value, int pe) {
    rocshmem::rocshmem_ctx_long_atomic_add(ctx, dest,
                                           value, pe);
}

// Device: context-scoped warp-level non-blocking byte put; `_warp` maps to
// rocSHMEM's `_wave` (rocshmem_ctx_schar_put_nbi_wave).
__device__ inline void shmem_ctx_schar_put_nbi_warp(shmem_ctx_t ctx, signed char *dest,
                                                    const signed char *source,
                                                    size_t nelems, int pe) {
    rocshmem::rocshmem_ctx_schar_put_nbi_wave(ctx, dest, source,
                                              nelems, pe);
}

// Device: context-scoped warp-level non-blocking int put; `_warp` maps to
// rocSHMEM's `_wave` (rocshmem_ctx_int_put_nbi_wave).
__device__ inline void shmem_ctx_int_put_nbi_warp(shmem_ctx_t ctx, int *dest,
                                                  const int *source,
                                                  size_t nelems, int pe) {
    rocshmem::rocshmem_ctx_int_put_nbi_wave(ctx, dest, source,
                                            nelems, pe);
}

#endif

#else

lijian6's avatar
lijian6 committed
178
// DUSHMEM wrapper
179
180
181
182
#ifndef ROCM_DISABLE_CTX
#define ROCM_DISABLE_CTX
#endif

lijian6's avatar
lijian6 committed
183
184
185
186
187
188
189
// Backend-neutral type names, mapped onto their DUSHMEM equivalents.
typedef dushmem_team_t        shmem_team_t;
typedef dushmem_team_config_t shmem_team_config_t;
typedef dushmemx_uniqueid_t   shmemx_uniqueid_t;
typedef dushmemx_init_attr_t  shmemx_init_attr_t;

// Predefined team handles, copied by value from DUSHMEM.
const shmem_team_t EP_SHMEM_TEAM_INVALID = DUSHMEM_TEAM_INVALID;
const shmem_team_t EP_SHMEM_TEAM_WORLD = DUSHMEM_TEAM_WORLD;

// Init flag for unique-ID based bootstrap (presumably passed as `flags` to
// shmemx_init_attr).
constexpr auto EP_SHMEMX_INIT_WITH_UNIQUEID = DUSHMEMX_INIT_WITH_UNIQUEID;
190
191
192


// Host-only: forwards `uid` to dushmemx_get_uniqueid and returns its status.
__host__ inline int shmemx_get_uniqueid(shmemx_uniqueid_t *uid) {
    return dushmemx_get_uniqueid(uid);
}

// Host-only: passes rank/nranks/uid/attr unchanged to
// dushmemx_set_attr_uniqueid_args and returns its status.
__host__ inline int shmemx_set_attr_uniqueid_args(int rank, int nranks,
    shmemx_uniqueid_t *uid, shmemx_init_attr_t *attr) {
    return dushmemx_set_attr_uniqueid_args(rank, nranks, uid, attr);
}


// Host-only: initializes DUSHMEM from `attr` using `flags`
// (e.g. EP_SHMEMX_INIT_WITH_UNIQUEID); returns dushmemx_init_attr's status.
__host__ inline int shmemx_init_attr(unsigned int flags, shmemx_init_attr_t *attr) {
    return dushmemx_init_attr(flags, attr);
}

// Host-only: strided team split. Every argument is passed unchanged to
// dushmem_team_split_strided, and its status is returned.
__host__ inline int shmem_team_split_strided(shmem_team_t parent_team,
    int start, int stride, int size,
    const shmem_team_config_t *config,
    long config_mask, shmem_team_t *new_team) {
    return dushmem_team_split_strided(parent_team, start, stride, size, config, config_mask, new_team);
}

// Host-side barrier: forwards to dushmem_barrier_all().
__host__ inline void shmem_barrier_all() {
    dushmem_barrier_all();
}

// Device-side barrier: the same dushmem_barrier_all() call, made from device
// code.
__device__ inline void shmem_device_barrier_all() {
    dushmem_barrier_all();
}

// Device-side barrier over `team`. dushmem_barrier's return value is
// deliberately discarded, so this wrapper keeps the void signature used by
// the rocSHMEM path.
__device__ inline void shmem_barrier(shmem_team_t team) {
    (void)dushmem_barrier(team);
}

// Host-only: returns the calling process's PE index from dushmem_my_pe().
__host__ inline int shmem_my_pe() {
    return dushmem_my_pe();
}

// Host-only: releases symmetric memory (e.g. from shmem_align) through
// dushmem_free.
__host__ inline void shmem_free(void *ptr) {
    dushmem_free(ptr);
}

// Host-only: symmetric allocation of `size` bytes with the requested
// `alignment`, delegated to dushmem_align. Release with shmem_free.
// NOTE(review): OpenSHMEM and NVSHMEM declare shmem_align(alignment, size),
// but the call below passes (size, alignment). It is left unchanged pending
// confirmation of dushmem_align's parameter order; if dushmem follows
// NVSHMEM, the two arguments are swapped.
__host__ inline void* shmem_align(const size_t alignment, const size_t size) {
    return dushmem_align(size, alignment);
}

// Host-only: shuts DUSHMEM down via dushmem_finalize.
__host__ inline void shmem_finalize() {
    dushmem_finalize();
}

// Host-only: destroys a team (e.g. one made by shmem_team_split_strided)
// via dushmem_team_destroy.
__host__ inline void shmem_team_destroy(shmem_team_t team) {
    dushmem_team_destroy(team);
}

// Device-side memory-ordering fence: forwards to dushmem_fence.
__device__ inline void shmem_fence() {
    dushmem_fence();
}

// Device: non-blocking ("nbi") put of `nelems` ints from the local `source`
// to `dest` on PE `pe`; forwards to dushmem_int_put_nbi.
__device__ inline void shmem_int_put_nbi(
    int *dest, const int *source, size_t nelems, int pe) {
    dushmem_int_put_nbi(dest, source, nelems, pe);
}

// Device: warp-level non-blocking int put (dushmemx_int_put_nbi_warp).
__device__ inline void shmemx_int_put_nbi_warp(
    int *dest, const int *source, size_t nelems, int pe) {
    dushmemx_int_put_nbi_warp(dest, source, nelems, pe);
}

// Device: warp-level non-blocking byte put (dushmemx_int8_put_nbi_warp).
__device__ inline void shmemx_int8_put_nbi_warp(
    signed char *dest, const signed char *source, size_t nelems, int pe) {
    dushmemx_int8_put_nbi_warp(dest, source, nelems, pe);
}

// Device: signal operation that adds `value` to the 64-bit signal word at
// `dest` on PE `pe` (dushmemx_signal_op with DUSHMEM_SIGNAL_ADD).
__device__ inline void shmem_signal_op_add(
    uint64_t *dest, uint64_t value, int pe) {
    dushmemx_signal_op(dest, value, DUSHMEM_SIGNAL_ADD, pe);
}

// Device: atomic add of `value` to the 64-bit word at `dest` on PE `pe`
// (dushmem_ulong_atomic_add).
__device__ inline void shmem_ulong_atomic_add(
    uint64_t *dest, uint64_t value, int pe) {
    dushmem_ulong_atomic_add(dest, value, pe);
}

// Device: atomic add of `value` to the long at `dest` on PE `pe`
// (dushmem_long_atomic_add).
__device__ inline void shmem_long_atomic_add(
    long *dest, long value, int pe) {
    dushmem_long_atomic_add(dest, value, pe);
}

lishen's avatar
lishen committed
280
281
282
283
284
285
286
287
288
289
290
291
292
293
// Device: translate the symmetric-heap address `dest` into an address that
// can be dereferenced directly for PE `dst_rank`.
//   - rank == dst_rank       -> `dest` itself (no mapping needed)
//   - peer base entry is 0   -> 0 (no direct peer mapping; presumably the peer
//                               is reached over RDMA instead)
//   - otherwise              -> peer_base + (dest - local heap_base)
// `rank` is presumably the caller's own PE index — confirm with callers.
__device__ __forceinline__ uint64_t shmem_get_p2p_ptr(void *dest, int rank, int dst_rank) {
    const uint64_t dest_addr = reinterpret_cast<uint64_t>(dest);
    if (rank == dst_rank) {
        return dest_addr;
    }

    // Per-peer base address of the mapped heap, read through the read-only cache.
    const uint64_t peer_base =
        __ldg(reinterpret_cast<uint64_t*>(dushmemi_device_state_d.peer_heap_base_p2p) + dst_rank);
    if (peer_base == 0) {
        return 0;
    }

    // Keep the same offset within the heap, rebased onto the peer's mapping.
    const uint64_t heap_base = reinterpret_cast<uint64_t>(dushmemi_device_state_d.heap_base);
    return peer_base + (dest_addr - heap_base);
}

294
295
296
297
298
#endif

} // namespace deep_ep::internode

#endif