// local_exchange.cuh
#include "stream_manager.h"
#include "utils/helper_cuda.h"
#include "utils/fmoe_utils.h"

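// Scatter token indices into `pos` so that entries routed to the same expert
// become contiguous. `cum_count[e]` is expected to hold the inclusive prefix
// sum of per-expert counts, i.e. the end offset of expert e's segment in
// `pos`; each thread atomically decrements its expert's counter to claim a
// slot and writes its flat index there, filling the segment from the back.
// A gate index of -1 marks a dropped token and is skipped.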
__global__
void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
        size_t numel, size_t topk) {
    size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < numel) {
        long gate_idx = gate[idx];
        if (gate_idx > -1) {
            // atomicSub returns the pre-decrement value, so p - 1 is the
            // slot this thread claims for its token index.
            int p = atomicSub(cum_count + gate_idx, 1);
            pos[p - 1] = (long)idx;
        }
    }
}

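// Host-side launcher: one thread per (token, top-k slot) entry, i.e.
// batch_size * topk threads in blocks of 256, launched on stream 0 of the
// stream manager and synchronized before returning.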
void fmoe_cuda_assign_pos_impl(
        int* cum_count, const long* gate, long* pos,
        const size_t batch_size, const size_t topk,
        CudaStreamManager* smgr) {
    size_t numel = batch_size * topk;
    assign_pos_kernel
        <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
        (cum_count, gate, pos, numel, topk);
    smgr->sync(1);
}

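// Each block of expert_count_kernel handles a contiguous range of
// PERTHREAD_EXPERTS experts; every thread keeps a private histogram of that
// range.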
#define PERTHREAD_EXPERTS 256

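// AMD (HIP) wavefronts are 64 lanes wide; NVIDIA warps are 32.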
#ifdef FMOE_USE_HIP
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

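// Count how many tokens are routed to each expert. Each block owns the expert
// range [blockIdx.x * PERTHREAD_EXPERTS, expert_max); its threads stride over
// the gate indices and accumulate private histograms, which are then reduced
// across each warp with shuffles before one lane per warp adds the partial
// sum to the global `expert_count`. A gate index of -1 marks a dropped token.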
__global__
void expert_count_kernel(const long* gate_idx, int* expert_count,
        const size_t batch_size, const size_t n_expert) {
    // Per-thread private histogram over this block's slice of experts.
    int res_tmp[PERTHREAD_EXPERTS] = {0};
    long expert_min = blockIdx.x * PERTHREAD_EXPERTS;
    long expert_max = expert_min + PERTHREAD_EXPERTS;
    if (expert_max > n_expert) {
        expert_max = n_expert;
    }
    // Stride over all gate indices with the block's threads, counting only
    // experts owned by this block.
    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
        long idx = gate_idx[i];
        if (idx == -1) {
            continue;
        }
        if (idx < expert_min || idx >= expert_max) {
            continue;
        }
        res_tmp[idx - expert_min] += 1;
    }
    // Warp-level reduction of the per-thread counts: shuffle partial sums
    // down the warp, then let one lane per warp commit them with atomicAdd.
    for (int i = expert_min; i < expert_max; ++i) {
        int x = res_tmp[i - expert_min];
#pragma unroll
        for (int j = 1; j < WARP_SIZE; j <<= 1) {
#ifdef FMOE_USE_HIP
            x = x + __shfl_down(x, j);
#else
            x = x + __shfl_down_sync(-1u, x, j);
#endif
        }
        if (threadIdx.x % WARP_SIZE == 0) {
            atomicAdd(expert_count + i, x);
        }
    }
}

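// Host-side launcher: one block per PERTHREAD_EXPERTS experts, 256 threads
// per block, launched on stream 0 and synchronized before returning.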
void fmoe_cuda_expert_count_impl(
        const long* gate_idx, int* expert_count,
        const size_t batch_size, const size_t n_expert,
        CudaStreamManager* smgr) {
    expert_count_kernel
        <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->stream(0)>>>
        (gate_idx, expert_count, batch_size, n_expert);
    smgr->sync(1);
}