local_exchange.cuh 2.79 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <cstring>
#include <vector>

#include "stream_manager.h"
#include "utils/helper_cuda.h"

/*
 * For every idx in [0, n), write ptrs[idx] = base + stride * offset[idx].
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per output pointer; the
 * launch must supply at least n threads in total. Threads with idx >= n do
 * nothing. No shared memory is used.
 *
 * NOTE(review): presumably used to build a pointer array for a batched
 * routine; no caller is visible in this file — confirm.
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
        const long* offset, const scalar_t** ptrs) {
    // Widen before multiplying: blockDim.x * blockIdx.x is otherwise computed
    // in 32-bit unsigned arithmetic and wraps once the grid exceeds 2^32
    // threads.
    const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x
            + threadIdx.x;
    if (idx < n) {
        ptrs[idx] = base + stride * offset[idx];
    }
}

/*
 * Indexed row read: block b copies row pos[b] of inbuf into row b of oubuf,
 * where every row is wid elements wide:
 *     oubuf[b * wid + i] = inbuf[pos[b] * wid + i],  0 <= i < wid.
 *
 * Launch layout: one block per output row (gridDim.x = number of rows). The
 * threads of a block step through the row in increments of blockDim.x, so
 * any block size is valid. No shared memory is used.
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * pos[blockIdx.x];
    oubuf += wid * blockIdx.x;
    // size_t index matches wid's type; an int index would be converted to
    // unsigned in the comparison and overflow for rows wider than INT_MAX.
    for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}

/*
 * Host-side bucketing of a gate assignment (a stable counting sort).
 *
 * Copies batch_size gate ids from the device buffer d_gate to the host and
 * computes:
 *   - expert_count[e]: number of samples i with gate[i] == e. expert_count
 *     is a host array of num_expert ints and is fully overwritten.
 *   - pos[i]: the index sample i gets when samples are grouped by expert in
 *     ascending expert order, keeping their original order within an expert.
 *     pos is copied to the device buffer d_pos (batch_size ints).
 *
 * Gate ids are not validated: each must lie in [0, num_expert), otherwise
 * expert_count and the prefix table are indexed out of bounds.
 *
 * NOTE(review): d_pos holds int here, whereas the scatter/gather kernels in
 * this file read positions as long — callers must convert; confirm.
 */
void fmoe_cuda_expert_count_impl(
        const int* d_gate,
        int* expert_count,
        int* d_pos,
        const size_t num_expert,
        const size_t batch_size) {
    memset(expert_count, 0, sizeof(int) * num_expert);
    if (batch_size == 0) {
        // Nothing to bucket; counts are already zero.
        return;
    }

    // std::vector releases the host buffers on every exit path.
    std::vector<int> gate(batch_size);
    checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
                cudaMemcpyDeviceToHost));

    for (size_t i = 0; i < batch_size; ++i) {
        ++expert_count[gate[i]];
    }

    // Exclusive prefix sum: expert_ptr[e] is the first slot of expert e.
    // Value-initialized so num_expert == 0 needs no special case.
    std::vector<int> expert_ptr(num_expert, 0);
    for (size_t e = 1; e < num_expert; ++e) {
        expert_ptr[e] = expert_ptr[e - 1] + expert_count[e - 1];
    }

    std::vector<int> pos(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        // Post-increment hands out consecutive slots within each expert,
        // preserving input order (stability).
        pos[i] = expert_ptr[gate[i]]++;
    }

    checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
                cudaMemcpyHostToDevice));
}

/*
 * Copy rows of `input` into `input_buf` in the order given by d_pos:
 *     input_buf row b = input row d_pos[b],  0 <= b < batch_size,
 * with every row in_feat elements wide. All pointers must be
 * device-accessible (they are dereferenced inside the kernel).
 *
 * The copy is enqueued on smgr->stream(0) with one 256-thread block per
 * row; smgr->sync(1) is then called.
 * NOTE(review): the exact semantics of CudaStreamManager::sync(1) are
 * defined in stream_manager.h — presumably it waits on the first stream;
 * confirm.
 * NOTE(review): launch errors are not checked here.
 */
template <typename scalar_t>
void fmoe_cuda_local_scatter_impl(
        const scalar_t* input,
        const long* d_pos,
        scalar_t* input_buf,
        const long batch_size,
        const long in_feat,
        CudaStreamManager* smgr) {
    // A zero-block grid is an invalid launch configuration; it would leave
    // cudaErrorInvalidConfiguration pending for some unrelated later check.
    if (batch_size > 0) {
        batch_scatter_kernel<scalar_t>
            <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
                    input_buf);
    }
    smgr->sync(1);
}

/*
 * Indexed row write: block b copies row b of inbuf into row pos[b] of oubuf,
 * where every row is wid elements wide:
 *     oubuf[pos[b] * wid + i] = inbuf[b * wid + i],  0 <= i < wid.
 *
 * Launch layout: one block per input row (gridDim.x = number of rows). The
 * threads of a block step through the row in increments of blockDim.x, so
 * any block size is valid. No shared memory is used.
 *
 * If pos contains duplicate entries, several blocks write the same output
 * row concurrently and the result is unspecified.
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * blockIdx.x;
    oubuf += wid * pos[blockIdx.x];
    // size_t index matches wid's type; an int index would be converted to
    // unsigned in the comparison and overflow for rows wider than INT_MAX.
    for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}

/*
 * Copy rows of `output_buf` back into `output` at the positions given by
 * d_pos:
 *     output row d_pos[b] = output_buf row b,  0 <= b < batch_size,
 * with every row out_feat elements wide. All pointers must be
 * device-accessible (they are dereferenced inside the kernel).
 *
 * The copy is enqueued on smgr->stream(0) with one 256-thread block per
 * row; smgr->sync(1) is then called.
 * NOTE(review): the exact semantics of CudaStreamManager::sync(1) are
 * defined in stream_manager.h — presumably it waits on the first stream;
 * confirm.
 * NOTE(review): launch errors are not checked here.
 */
template <typename scalar_t>
void fmoe_cuda_local_gather_impl(
        const scalar_t* output_buf,
        const long* d_pos,
        scalar_t* output,
        const size_t batch_size,
        const size_t out_feat,
        CudaStreamManager* smgr) {
    // A zero-block grid is an invalid launch configuration; it would leave
    // cudaErrorInvalidConfiguration pending for some unrelated later check.
    if (batch_size > 0) {
        batch_gather_kernel<scalar_t>
            <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos,
                    output_buf, output);
    }
    smgr->sync(1);
}