// local_exchange.cuh
#include "stream_manager.h"
#include "utils/helper_cuda.h"
#include "utils/fmoe_utils.h"
/*
 * Scatter the flat index of every non-negative gate entry into `pos`.
 *
 * Each thread handles one flat index `idx` in [0, numel). When gate[idx] is
 * non-negative, the thread atomically decrements cum_count[gate[idx]] and
 * writes `idx` to pos[old_value - 1], where old_value is the counter value
 * before the decrement. Entries whose gate value is negative are skipped.
 *
 * Launch layout: 1-D grid of 1-D blocks providing at least `numel` threads;
 * surplus threads fall out at the bounds check. No shared memory is used.
 *
 * Parameters:
 *   cum_count  counters indexed by gate value; modified in place.
 *              Presumably an inclusive running total of per-gate entry
 *              counts, so each decrement yields a distinct slot
 *              -- TODO confirm against the caller.
 *   gate       `numel` gate values; a non-negative value is used directly as
 *              an index into cum_count (no upper-bound check is done here).
 *   pos        output array receiving flat indices.
 *   numel      number of entries in `gate`.
 *   topk       not used by this kernel; kept so the signature is unchanged.
 *
 * Which thread obtains which slot for a given gate value depends on the
 * order in which the atomics are serviced.
 */
__global__
void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
        size_t numel, size_t topk) {
    // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
    // evaluated in 32-bit unsigned arithmetic and could wrap for large grids.
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < numel) {
        long gate_idx = gate[idx];
        if (gate_idx > -1) {
            // atomicSub returns the counter value *before* the decrement,
            // hence the "- 1" when turning it into a zero-based slot.
            int p = atomicSub(cum_count + gate_idx, 1);
            pos[p - 1] = (long)idx;
        }
    }
}

/*
 * Host-side launcher for assign_pos_kernel.
 *
 * Launches one thread per (batch_size * topk) gate entry, in blocks of 256
 * threads, on smgr->stream(0), then calls smgr->sync(1).
 *
 * Parameters:
 *   cum_count   counters passed through to the kernel (modified by it).
 *   gate        gate values, batch_size * topk entries.
 *   pos         output array passed through to the kernel.
 *   batch_size  multiplied by topk to obtain the number of gate entries.
 *   topk        see batch_size; also forwarded to the kernel.
 *   smgr        stream manager supplying the launch stream and the sync.
 *
 * The pointer arguments are handed straight to a kernel, so they are
 * presumably device pointers -- TODO confirm against the caller.
 */
void fmoe_cuda_assign_pos_impl(
        int* cum_count, const long* gate, long* pos,
        const size_t batch_size, const size_t topk,
        CudaStreamManager* smgr) {
    size_t numel = batch_size * topk;
    // With no entries the grid size would be zero, which is an invalid
    // launch configuration and leaves a pending CUDA error for whoever
    // checks next. Skip the launch but keep the synchronisation below.
    if (numel > 0) {
        assign_pos_kernel
            <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
            (cum_count, gate, pos, numel, topk);
    }
    // NOTE(review): the kernel is enqueued on stream(0) while sync receives
    // 1; presumably sync(n) waits on the first n streams -- confirm against
    // CudaStreamManager.
    smgr->sync(1);
}