utils.cuh 14.5 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
#pragma once
lijian6's avatar
lijian6 committed
2
#include "configs.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
3
4
#include "exception.cuh"

lijian6's avatar
lijian6 committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// Cooperative element-wise copy of `N` elements from `SRC` to `DST`, written as a statement
// macro (expands to a braced block).
//
//   UNROLL_FACTOR  number of elements each lane handles per main-loop iteration; used as an
//                  array bound and in a `constexpr`, so it must be a constant expression
//   LANE_ID        starting element index of the calling lane
//   N              total element count; expanded several times, so pass a side-effect-free
//                  expression
//   DST, SRC       pointer-like expressions supporting `+ int`
//   LD_FUNC        callable/macro: `LD_FUNC(src_ptr)` yields the loaded value
//   ST_FUNC        callable/macro: `ST_FUNC(dst_ptr, value)` stores it
//
// Main loop: covers the largest multiple of `kWarpSize * UNROLL_FACTOR` not exceeding `N`;
// all loads of an iteration are issued before any of its stores.
// Tail: the remaining elements are handled by one more unrolled pass in which every load and
// every store is individually guarded by `index < N`.
//
// Requires `kWarpSize` and `std::remove_reference` to be visible at the expansion site.
#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC)                  \
    {                                                                                              \
        constexpr int kLoopStride = kWarpSize * (UNROLL_FACTOR);                                   \
        typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type                         \
             unrolled_values[(UNROLL_FACTOR)];                                                     \
        auto __src = (SRC);                                                                        \
        auto __dst = (DST);                                                                        \
        for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) {   \
            _Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j)                      \
                unrolled_values[__j] = LD_FUNC(__src + __i + __j * kWarpSize);                     \
            _Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j)                      \
                ST_FUNC(__dst + __i + __j * kWarpSize, unrolled_values[__j]);                      \
        }                                                                                          \
        {                                                                                          \
            int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID);                               \
            _Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) {                    \
                if (__i + __j * kWarpSize < (N)) {                                                 \
                    unrolled_values[__j] = LD_FUNC(__src + __i + __j * kWarpSize);                 \
                }                                                                                  \
            }                                                                                      \
            _Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) {                    \
                if (__i + __j * kWarpSize < (N)) {                                                 \
                    ST_FUNC(__dst + __i + __j * kWarpSize, unrolled_values[__j]);                  \
                }                                                                                  \
            }                                                                                      \
        }                                                                                          \
    }

// Variant of the unrolled copy macro that strides by `kEmulatedWarpSize` instead of
// `kWarpSize`. Parameters have the same meaning: `UNROLL_FACTOR` must be a constant
// expression, `N` is expanded several times, `LD_FUNC(ptr)` loads and `ST_FUNC(ptr, value)`
// stores.
//
// Main loop: covers the largest multiple of `kEmulatedWarpSize * UNROLL_FACTOR` not exceeding
// `N`, issuing all loads of an iteration before its stores.
// Tail: a plain loop that copies one element per step with stride `kEmulatedWarpSize`,
// feeding the result of `LD_FUNC` straight into `ST_FUNC`.
//
// Requires `kEmulatedWarpSize` and `std::remove_reference` to be visible at the expansion site.
#define UNROLLED_WARP_COPY_EMULATED(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC)         \
    {                                                                                              \
        constexpr int kLoopStride = kEmulatedWarpSize * (UNROLL_FACTOR);                           \
        typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type                         \
             unrolled_values[(UNROLL_FACTOR)];                                                     \
        auto __src = (SRC);                                                                        \
        auto __dst = (DST);                                                                        \
        for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) {   \
            _Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j)                      \
                unrolled_values[__j] = LD_FUNC(__src + __i + __j * kEmulatedWarpSize);             \
            _Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j)                      \
                ST_FUNC(__dst + __i + __j * kEmulatedWarpSize, unrolled_values[__j]);              \
        }                                                                                          \
        for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N);                   \
             __i += kEmulatedWarpSize)                                                             \
            ST_FUNC(__dst + __i, LD_FUNC(__src + __i));                                            \
    }
// HELPER FUNCTIONS
// #####################################################################################

// XOR (butterfly) warp shuffle: forwards `val`, `laneMask` and `width` to the mask-less
// `__shfl_xor` intrinsic and returns its result.
//
// `shfl_sync_mask` is accepted so call sites can keep a `*_sync`-style argument list, but it
// is not passed to the intrinsic and has no effect.
template <typename T>
__device__ __forceinline__ T shfl_xor(const T val, int laneMask, int width = kWarpSize,
                                      uint64_t shfl_sync_mask = kFullWarpMask) {
    (void)shfl_sync_mask; // intentionally unused, see above
    return __shfl_xor(val, laneMask, width);
}

lijian6's avatar
lijian6 committed
59
60
61
62
63
64
65
66
67
68
// Indexed warp shuffle for `int` values: returns what the mask-less `__shfl` intrinsic yields
// for (`val`, `srcLane`, `width`). `shfl_sync_mask` is part of the signature but is not
// forwarded to the intrinsic.
__device__ __forceinline__ int shfl_sync(const int val, int srcLane = 0, int width = kWarpSize,
                                         uint64_t shfl_sync_mask = kFullWarpMask) {
    const int shuffled = __shfl(val, srcLane, width);
    return shuffled;
}

// Returns 1 when at least one lane selected by `mask` voted true in `__ballot(predicate)`,
// otherwise 0.
__device__ __forceinline__ int __any_sync(uint64_t mask, int predicate) {
    const uint64_t voted_lanes = __ballot(predicate);
    const uint64_t selected    = voted_lanes & mask;
    return selected != 0;
}
Chenggang Zhao's avatar
Chenggang Zhao committed
69

lijian6's avatar
lijian6 committed
70
71
72
73
// Returns 1 when every lane selected by `mask` voted true in `__ballot(predicate)`, i.e. no
// bit of `mask` is missing from the ballot result; otherwise 0.
__device__ __forceinline__ int __all_sync(uint64_t mask, int predicate) {
    const uint64_t voted_lanes = __ballot(predicate);
    return (voted_lanes & mask) == mask;
}
Chenggang Zhao's avatar
Chenggang Zhao committed
74

lijian6's avatar
lijian6 committed
75
76
77
78
79
80
// Wavefront-level synchronization point built from three AMDGCN builtins, in this order:
// a release fence, a wave barrier, and an acquire fence, each fence at "wavefront" scope.
// NOTE(review): presumably the stand-in for CUDA's `__syncwarp()` in this port — confirm.
__device__ __forceinline__ void syncwarp() {
    // Release fence ahead of the barrier (orders this lane's earlier memory operations).
    __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
    __builtin_amdgcn_wave_barrier();
    // Acquire fence after the barrier (orders this lane's later memory operations).
    __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
}
// ######################################################################################################
81

lijian6's avatar
lijian6 committed
82
namespace deep_ep {
83

lijian6's avatar
lijian6 committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// Maps a byte width to an integer (or integer-vector) type of that width via `vec_t`.
// Only 1, 2, 4, 8 and 16 bytes are specialized; the primary template deliberately has no
// `vec_t`, so any other width fails to compile at the point of use.
template <int kBytes> struct VecInt {};
template <> struct VecInt<1> {
    using vec_t = int8_t;
};
template <> struct VecInt<2> {
    using vec_t = int16_t;
};
template <> struct VecInt<4> {
    using vec_t = int;
};
template <> struct VecInt<8> {
    using vec_t = int64_t;
};
template <> struct VecInt<16> {
    // Compiler extended-vector type holding four `int` components.
    using native_int4 = int __attribute__((ext_vector_type(4)));
    using vec_t       = native_int4;
};

Chenggang Zhao's avatar
Chenggang Zhao committed
102
// Unrecoverable-error exit for device code: calls `abort()`.
__device__ __forceinline__ void trap() {
    abort();
}

// Widest of the three fence helpers: issues `__threadfence_system()`.
__device__ __forceinline__ void memory_fence() {
    __threadfence_system();
}

// Device-wide fence helper: issues `__threadfence()`.
__device__ __forceinline__ void memory_fence_gpu() {
    __threadfence();
}

// Block-level fence helper: issues `__threadfence_block()`.
__device__ __forceinline__ void memory_fence_cta() {
    __threadfence_block();
}

lijian6's avatar
lijian6 committed
119
120
// Writes `val` to `*ptr` with the compiler's non-temporal store builtin.
// NOTE(review): unlike the other `*_sys_global` helpers this is not a `__hip_atomic_*`
// operation, so no explicit memory order or scope is requested here — confirm that is intended.
__device__ __forceinline__ void st_relaxed_sys_global(int *ptr, int val) {
    __builtin_nontemporal_store(val, ptr);
}

lijian6's avatar
lijian6 committed
123
124
// Atomic store of `val` to `*ptr` with release ordering at system scope.
// `ptr` is declared pointer-to-const; the constness is cast away before the store.
__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
    __hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
}

lijian6's avatar
lijian6 committed
127
128
129
130
131
132
133
134
135
136
137
138
// Atomic store of `val` to `*ptr` with release ordering at workgroup scope.
// The pointer is declared const in the signature and made writable here.
__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
    int *writable = const_cast<int *>(ptr);
    __hip_atomic_store(writable, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_WORKGROUP);
}

// Reads `*ptr` with the compiler's non-temporal load builtin and returns the value.
__device__ __forceinline__ int ld_relaxed_sys_global(const int *ptr) {
    return __builtin_nontemporal_load(ptr);
}

// Relaxed atomic load of a 64-bit value at system scope.
// NOTE(review): the declared return type is `int`, so the loaded 64-bit value is narrowed and
// its upper 32 bits are discarded. The return type is left as-is to avoid changing the
// interface for existing callers; the cast below only makes the narrowing explicit. Confirm
// whether this overload should return `uint64_t` like `ld_acquire_sys_global` does.
__device__ __forceinline__ int ld_relaxed_sys_global(const uint64_t *ptr) {
    const uint64_t loaded = __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
    return static_cast<int>(loaded);
}

// Atomic load of `*ptr` with acquire ordering at system scope (32-bit overload).
__device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_SYSTEM);
}

// Atomic load of `*ptr` with acquire ordering at system scope (64-bit overload).
__device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_SYSTEM);
}

// Atomic load of `*ptr` with acquire ordering at agent scope.
__device__ __forceinline__ int ld_acquire_global(const int *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
}

lijian6's avatar
lijian6 committed
159
// Atomically adds `value` to `*ptr` and returns the result of `atomicAdd` (the value held
// before the addition). `ptr` is declared pointer-to-const and is made writable here.
//
// NOTE(review): despite the name, this calls plain `atomicAdd`, which does not request
// release ordering. The explicitly release-ordered form was disabled in the original code and
// is kept below for reference; confirm whether callers rely on release semantics:
//   __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE,
//                          __HIP_MEMORY_SCOPE_AGENT);
__device__ __forceinline__ int atomic_add_release_global(const int *ptr, int value) {
    return atomicAdd(const_cast<int *>(ptr), value);
}

// Atomic load of `*ptr` with acquire ordering at workgroup scope.
__device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP);
}

lijian6's avatar
lijian6 committed
173
// `ld_volatile_global` overload set: each overload performs a relaxed atomic load of `*ptr`
// at system scope and returns the loaded value.

// 32-bit signed integer overload.
__device__ __forceinline__ int ld_volatile_global(const volatile int *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
}

// Single-precision float overload.
__device__ __forceinline__ float ld_volatile_global(const volatile float *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
}

// 64-bit signed integer overload.
__device__ __forceinline__ int64_t ld_volatile_global(const volatile int64_t *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
}

// 64-bit unsigned overload. The value is loaded as `uint64_t` but returned as `int64_t`, so
// values above INT64_MAX come back negative; the cast makes that conversion explicit.
__device__ __forceinline__ int64_t ld_volatile_global(const volatile uint64_t *ptr) {
    const uint64_t loaded = __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
    return static_cast<int64_t>(loaded);
}

lijian6's avatar
lijian6 committed
197
198
199
200
201
// Non-temporal load of one `dtype_t` object.
//
// The object is read through a pointer to the same-sized `VecInt<sizeof(dtype_t)>::vec_t`
// using the compiler's non-temporal load builtin, and the loaded bits are then reinterpreted
// as `dtype_t`. `sizeof(dtype_t)` must therefore be a width `VecInt` specializes, otherwise
// the function fails to compile.
template <typename dtype_t>
__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
    using T  = typename VecInt<sizeof(dtype_t)>::vec_t;
    auto ret = __builtin_nontemporal_load(reinterpret_cast<const T *>(ptr));
    // Reinterpret the raw bits back into the caller's type.
    return *reinterpret_cast<dtype_t *>(&ret);
}

lijian6's avatar
lijian6 committed
204
////////////////// used in ibgda
Chenggang Zhao's avatar
Chenggang Zhao committed
205
// `st_na_relaxed` overload set (the file marks this section as used by the IBGDA code).
// The scalar overloads perform a relaxed atomic store at agent scope. All overloads take a
// pointer-to-const and cast the constness away before writing.

// 8-bit overload.
__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
    __hip_atomic_store(const_cast<uint8_t *>(ptr), val, __ATOMIC_RELAXED,
                       __HIP_MEMORY_SCOPE_AGENT);
}

// 16-bit overload.
__device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) {
    __hip_atomic_store(const_cast<uint16_t *>(ptr), val, __ATOMIC_RELAXED,
                       __HIP_MEMORY_SCOPE_AGENT);
}

// 32-bit unsigned overload.
__device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) {
    __hip_atomic_store(const_cast<uint32_t *>(ptr), val, __ATOMIC_RELAXED,
                       __HIP_MEMORY_SCOPE_AGENT);
}

// 32-bit signed overload.
__device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
    __hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}

// `int4` overload: written as four ordinary component stores (x, y, z, w), not as an atomic
// operation, so the vector as a whole is not stored atomically.
__device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
    int4 *non_const_ptr = const_cast<int4 *>(ptr);
    non_const_ptr->x    = val.x;
    non_const_ptr->y    = val.y;
    non_const_ptr->z    = val.z;
    non_const_ptr->w    = val.w;
}

// `st_na_release` overload set: atomic store of `val` to `*ptr` with release ordering at
// agent scope. All overloads take a pointer-to-const and cast the constness away before
// writing.

// 32-bit signed overload.
// Fixed: this overload previously stored with __ATOMIC_RELAXED, contradicting both its name
// and the uint32_t/uint64_t overloads below.
__device__ __forceinline__ void st_na_release(const int *ptr, int val) {
    __hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
}

// 32-bit unsigned overload.
__device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) {
    __hip_atomic_store(const_cast<uint32_t *>(ptr), val, __ATOMIC_RELEASE,
                       __HIP_MEMORY_SCOPE_AGENT);
}

// 64-bit unsigned overload.
__device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) {
    __hip_atomic_store(const_cast<uint64_t *>(ptr), val, __ATOMIC_RELEASE,
                       __HIP_MEMORY_SCOPE_AGENT);
}

lijian6's avatar
lijian6 committed
248
// TODO: apply an equivalent of "st.global.L1::no_allocate" on ROCm.
//
// Generic store of one `dtype_t` object: reinterprets both the destination and the value as
// the same-sized `VecInt<sizeof(dtype_t)>::vec_t` and forwards to the `st_na_global` overload
// for that type. The explicit specializations below perform the actual (plain) stores.
//
// NOTE(review): only `int`, `int64_t`, `float` and `int4` are specialized here. For a
// `dtype_t` whose `vec_t` has no specialization (1- and 2-byte types, or a 16-byte type other
// than `int4`), the forwarded call appears to resolve back to this primary template, which
// would recurse without end. Confirm such types are never passed.
template <typename dtype_t>
__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t &value) {
    using vec_t = typename VecInt<sizeof(dtype_t)>::vec_t;
    st_na_global(reinterpret_cast<const vec_t *>(ptr), *reinterpret_cast<const vec_t *>(&value));
}

// 32-bit integer store; constness of `ptr` is cast away before writing.
template <> __device__ __forceinline__ void st_na_global(const int *ptr, const int &value) {
    int *non_const_ptr = const_cast<int *>(ptr);
    *non_const_ptr     = value;
}

// 64-bit integer store; constness of `ptr` is cast away before writing.
template <> __device__ __forceinline__ void st_na_global(const int64_t *ptr, const int64_t &value) {
    int64_t *non_const_ptr = const_cast<int64_t *>(ptr);
    *non_const_ptr         = value;
}

// Single-precision float store; constness of `ptr` is cast away before writing.
template <> __device__ __forceinline__ void st_na_global(const float *ptr, const float &value) {
    float *non_const_ptr = const_cast<float *>(ptr);
    *non_const_ptr       = value;
}

// `int4` store (whole-object assignment); constness of `ptr` is cast away before writing.
template <> __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4 &value) {
    int4 *non_const_ptr = const_cast<int4 *>(ptr);
    *non_const_ptr      = value;
}

Chenggang Zhao's avatar
Chenggang Zhao committed
275
// Computes the half-open token range [token_start_idx, token_end_idx) assigned to `sm_id`
// when `num_tokens` tokens are split across `num_sms` workers.
//
// Each worker gets DIVUP(num_tokens, num_sms) tokens; both bounds are clamped to
// `num_tokens`, so trailing workers may receive a shorter or empty range.
//
//   num_tokens       total number of tokens
//   num_sms          number of workers; used as a divisor, so it must be non-zero
//   sm_id            index of the worker whose range is requested
//   token_start_idx  [out] first token index of the range
//   token_end_idx    [out] one past the last token index of the range
__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
                                                       int &token_start_idx, int &token_end_idx) {
    int num_tokens_per_sm = DIVUP(num_tokens, num_sms);
    token_start_idx       = min(num_tokens_per_sm * sm_id, num_tokens);
    token_end_idx         = min(token_start_idx + num_tokens_per_sm, num_tokens);
}

// Returns lane `src_lane_idx`'s copy of `ptr` to the calling lane.
//
// The object is viewed as an array of `int` words and each word is fetched from the source
// lane with `shfl_sync`; the received words are then reinterpreted as a `dtype_t`.
// `sizeof(dtype_t)` must be a multiple of `sizeof(int)` (checked at compile time).
template <typename dtype_t>
__device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
    EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0,
                     "broadcast requires sizeof(dtype_t) to be a multiple of sizeof(int)");
    // Word count as a signed constant so the loop below compares int with int.
    constexpr int kNumIntWords = static_cast<int>(sizeof(dtype_t) / sizeof(int));

    auto send_int_values = reinterpret_cast<int *>(&ptr);
    int  recv_int_values[kNumIntWords];
#pragma unroll
    for (int i = 0; i < kNumIntWords; ++i)
        recv_int_values[i] = shfl_sync(send_int_values[i], src_lane_idx);
    return *reinterpret_cast<dtype_t *>(recv_int_values);
}

// Butterfly sum reduction built on `shfl_xor`: each step adds the value obtained by shuffling
// with XOR offsets 16, 8, 4, 2, 1 (preceded by 32 when `kWarpSize` is 64) and the accumulated
// sum is returned.
__forceinline__ __device__ int warp_reduce_sum(int value) {
    // The offset ladder below is only complete for these two warp sizes.
    static_assert(kWarpSize == 32 || kWarpSize == 64,
                  "warp_reduce_sum supports warp sizes 32 and 64 only");
    if constexpr (kWarpSize == 64)
        value += shfl_xor<int>(value, 32);
    value += shfl_xor<int>(value, 16);
    value += shfl_xor<int>(value, 8);
    value += shfl_xor<int>(value, 4);
    value += shfl_xor<int>(value, 2);
    value += shfl_xor<int>(value, 1);
    return value;
}

lijian6's avatar
lijian6 committed
304
305
306
// Returns `threadIdx.x % kWarpSize`. Only the x component of the thread index is used.
__forceinline__ __device__ int get_lane_id() {
    return static_cast<int>(threadIdx.x % kWarpSize);
}

309
// Block-level barrier across `kNumRanks` ranks using per-rank signal arrays.
//
//   kNumRanks            number of participating ranks; the first `kNumRanks` threads of the
//                        block each handle one peer, so the block must have at least that many
//                        threads
//   kSyncOnly            when false (default), a system-wide fence plus `__syncthreads()` is
//                        executed before signalling
//   barrier_signal_ptrs  `barrier_signal_ptrs[r]` is rank r's signal array, indexed by peer rank
//   rank                 index of the calling rank
//
// Protocol, as implemented below: thread t (< kNumRanks) adds FINISHED_SUM_TAG to this rank's
// slot for peer t and subtracts FINISHED_SUM_TAG from peer t's slot for this rank. The block
// then polls this rank's own slots until every polled value is <= 0. If the wait exceeds
// NUM_TIMEOUT_CYCLES, the offending threads print a diagnostic and call `trap()`.
template <int kNumRanks, bool kSyncOnly = false>
__forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs, int rank) {
    auto thread_id = static_cast<int>(threadIdx.x);

    // Precondition for the per-peer threads below; checked before any signal is sent.
    EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);

    // For non-sync-only cases, the memory operations by other threads in the block must be visible
    // to the `sys` scope
    if constexpr (not kSyncOnly) {
        memory_fence();
        __syncthreads();
    }

    // Add self-ranks, sub other ranks
    if (thread_id < kNumRanks) {
        atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
        atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
    }

    // Check timeout
    auto start_time = clock64();
    while (true) {
        // Threads without a peer slot contribute 0, which always satisfies the exit test.
        auto value =
            thread_id < kNumRanks ? ld_volatile_global(barrier_signal_ptrs[rank] + thread_id) : 0;
        if (__all_sync(kFullWarpMask, value <= 0))
            break;

        if (clock64() - start_time > NUM_TIMEOUT_CYCLES and thread_id < kNumRanks) {
            printf("DeepEP timeout check failed: rank = %d, thread = %d, value = %d)\n", rank,
                   thread_id, value);
            trap();
        }
    }
    __syncthreads();
}
} // namespace deep_ep