local_exchange.cuh 2.2 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
#include "stream_manager.h"
#include "utils/helper_cuda.h"
3
#include "utils/fmoe_utils.h"
Rick Ho's avatar
Rick Ho committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

// Row-wise indexed copy: row b of `oubuf` <- row pos[b] of `inbuf`.
//
// Launch layout: 1-D grid with one block per output row (blockIdx.x == b).
// The threads of a block stride across the `wid` elements of that row, so
// any blockDim.x is valid. No shared memory is used.
//   wid   - elements per row; both buffers use `wid` as their row stride
//   pos   - device-accessible array with (at least) gridDim.x entries; pos[b]
//           must be a valid row index into `inbuf` (not bounds-checked here)
//   inbuf - device-accessible source rows
//   oubuf - device-accessible destination rows (gridDim.x rows are written)
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * pos[blockIdx.x];
    oubuf += wid * blockIdx.x;
    // Use size_t to match `wid`. This avoids a signed/unsigned comparison
    // and int overflow for rows longer than INT_MAX elements.
    for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}

// Host-side launcher for batch_scatter_kernel. It produces
//   input_buf[b, :] = input[d_pos[b], :]   for b in [0, batch_size),
// where every row holds `in_feat` elements.
//   input      - device-accessible source rows (row stride in_feat)
//   d_pos      - device-accessible array of batch_size row indices into input
//   input_buf  - device-accessible destination for batch_size rows
//   batch_size - number of rows to copy; one 256-thread block per row
//   in_feat    - elements per row
//   smgr       - stream manager; the copy is enqueued on smgr->stream(0)
// smgr->sync(1) is always called before returning.
// NOTE(review): presumably sync(1) waits for stream 0 to drain; confirm in
// stream_manager.
template <typename scalar_t>
void fmoe_cuda_local_scatter_impl(
        const scalar_t* input,
        const long* d_pos,
        scalar_t* input_buf,
        const long batch_size,
        const long in_feat,
        CudaStreamManager* smgr) {
    // A zero-sized grid is an invalid launch configuration. It would leave
    // cudaErrorInvalidConfiguration pending for the next cudaGetLastError().
    // With no rows there is nothing to copy, so only the launch is skipped.
    if (batch_size > 0) {
        batch_scatter_kernel<scalar_t>
            <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
                    input_buf);
    }
    smgr->sync(1);
}

// Row-wise indexed copy: row pos[b] of `oubuf` <- row b of `inbuf`.
//
// Launch layout: 1-D grid with one block per input row (blockIdx.x == b).
// The threads of a block stride across the `wid` elements of that row, so
// any blockDim.x is valid. No shared memory is used.
//   wid   - elements per row; both buffers use `wid` as their row stride
//   pos   - device-accessible array with (at least) gridDim.x entries; pos[b]
//           must be a valid row index into `oubuf` (not bounds-checked here).
//           If two entries are equal, the rows race; the last writer wins.
//   inbuf - device-accessible source rows (gridDim.x rows are read)
//   oubuf - device-accessible destination rows
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * blockIdx.x;
    oubuf += wid * pos[blockIdx.x];
    // Use size_t to match `wid`. This avoids a signed/unsigned comparison
    // and int overflow for rows longer than INT_MAX elements.
    for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}

// Host-side launcher for batch_gather_kernel. It produces
//   output[d_pos[b], :] = output_buf[b, :]   for b in [0, batch_size),
// where every row holds `out_feat` elements.
//   output_buf - device-accessible source rows (row stride out_feat)
//   d_pos      - device-accessible array of batch_size row indices into output
//   output     - device-accessible destination rows
//   batch_size - number of rows to copy; one 256-thread block per row
//   out_feat   - elements per row
//   smgr       - stream manager; the copy is enqueued on smgr->stream(0)
// smgr->sync(1) is always called before returning.
// NOTE(review): presumably sync(1) waits for stream 0 to drain; confirm in
// stream_manager.
template <typename scalar_t>
void fmoe_cuda_local_gather_impl(
        const scalar_t* output_buf,
        const long* d_pos,
        scalar_t* output,
        const size_t batch_size,
        const size_t out_feat,
        CudaStreamManager* smgr) {
    // A zero-sized grid is an invalid launch configuration. It would leave
    // cudaErrorInvalidConfiguration pending for the next cudaGetLastError().
    // With no rows there is nothing to copy, so only the launch is skipped.
    if (batch_size > 0) {
        batch_gather_kernel<scalar_t>
            <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos,
                    output_buf, output);
    }
    smgr->sync(1);
}
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

// Groups routed slots by expert. For every slot idx in [0, numel) with a
// non-negative gate[idx] (negative gates are skipped), idx is written into
// the segment of `pos` owned by expert gate[idx].
//
// cum_count contract: cum_count[e] is expected to hold the exclusive END
// offset of expert e's segment in `pos`, i.e. an inclusive prefix sum of the
// per-expert counts.
// NOTE(review): confirm this against the caller.
// Each matching slot atomically decrements cum_count[e] and fills the slot
// just below the previous value. The segment is therefore filled from its
// end toward its start. On return, cum_count[e] has been decremented once
// per slot routed to e. Order within a segment is unspecified (atomics).
//
// Launch layout: 1-D grid, one thread per slot; surplus threads exit via
// the idx < numel guard. `topk` is accepted but not used by the kernel.
__global__
void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
        size_t numel, size_t topk) {
    // Widen before multiplying. blockIdx.x * blockDim.x is otherwise
    // evaluated in 32-bit unsigned arithmetic and can wrap.
    size_t idx = threadIdx.x + (size_t)blockIdx.x * blockDim.x;
    if (idx < numel) {
        long gate_idx = gate[idx];
        if (gate_idx > -1) {
            // atomicSub returns the value *before* the decrement: one past
            // the last still-free slot of this expert's segment. Step back
            // one so writes stay inside [end - count, end).
            int p = atomicSub(cum_count + gate_idx, 1);
            pos[p - 1] = (long)idx;
        }
    }
}

// Host-side launcher for assign_pos_kernel over numel = batch_size * topk
// gate entries, using 256-thread blocks on smgr->stream(0).
//   cum_count - device-accessible per-expert counters, updated in place by the
//               kernel (see assign_pos_kernel for the expected contents)
//   gate      - device-accessible array of batch_size * topk expert ids
//   pos       - device-accessible output of slot indices grouped by expert
//   smgr      - stream manager; smgr->sync(1) is always called before
//               returning. NOTE(review): presumably this waits for stream 0;
//               confirm in stream_manager.
void fmoe_cuda_assign_pos_impl(
        int* cum_count, const long* gate, long* pos,
        const size_t batch_size, const size_t topk,
        CudaStreamManager* smgr) {
    size_t numel = batch_size * topk;
    // With no gate entries, the grid would have zero blocks. That is an
    // invalid launch configuration that leaves a pending CUDA error, so
    // only the launch is skipped.
    if (numel > 0) {
        assign_pos_kernel
            <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>(cum_count, gate,
                    pos, numel, topk);
    }
    smgr->sync(1);
}