balancing.cu 4.95 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
#include "balancing.cuh"
Rick Ho's avatar
Rick Ho committed
2
#include "global_exchange.h"
Rick Ho's avatar
Rick Ho committed
3
4
5
6
7
#include <torch/extension.h>

/*
 * Clamp per-expert sample counts against per-expert capacities on the GPU.
 *
 * expert_count : CUDA tensor (validated by CHECK_INPUT), read through a
 *                `long` pointer.
 * capacity     : CUDA tensor (validated by CHECK_INPUT), read through an
 *                `int` pointer. It must be int32 because of a limitation of
 *                the CUDA atomic operators used by the implementation. It is
 *                handed over as a mutable pointer, so the implementation may
 *                update it in place (NOTE(review): confirm in the impl).
 * n_expert, n_worker : forwarded unchanged to
 *                fmoe_cuda_limit_by_capacity_impl.
 *
 * Returns a new tensor allocated like expert_count, written by
 * fmoe_cuda_limit_by_capacity_impl (presumably the accepted counts).
 */
torch::Tensor _limit_by_capacity(
        torch::Tensor expert_count, torch::Tensor capacity,
        long n_expert, long n_worker) {
    CHECK_INPUT(expert_count);
    CHECK_INPUT(capacity);

    // The stream manager is chosen by the device that holds the counts.
    const auto device_index = expert_count.device().index();
    auto stream_mgr = getCudaStreamManager(device_index);

    torch::Tensor accepted = torch::empty_like(expert_count);
    long* const count_ptr = expert_count.data_ptr<long>();
    int* const cap_ptr = capacity.data_ptr<int>();
    long* const accepted_ptr = accepted.data_ptr<long>();

    fmoe_cuda_limit_by_capacity_impl(
            count_ptr, cap_ptr, accepted_ptr,
            n_expert, n_worker, stream_mgr);
    return accepted;
}

23
/*
 * Apply per-expert budgets to a gate assignment on the GPU.
 *
 * gate_idx     : CUDA tensor of gate indices, read through a `long` pointer.
 * expert_count : CUDA tensor of per-expert budgets, read through an `int`
 *                pointer (int32). It is passed as a mutable pointer, so the
 *                implementation may consume it in place
 *                (NOTE(review): confirm in the impl).
 * n_expert, n_worker : forwarded to fmoe_cuda_prune_gate_by_capacity_impl.
 *
 * Returns a new tensor with gate_idx's shape, dtype and device, written by
 * fmoe_cuda_prune_gate_by_capacity_impl from gate_idx and expert_count.
 */
torch::Tensor _prune_gate_by_capacity(
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker) {
    // The implementation only receives raw pointers plus numel(), so both
    // inputs must be dense CUDA tensors. A strided view (e.g. one column of
    // a top-k index matrix) would otherwise be read in storage order rather
    // than logical order.
    CHECK_INPUT(gate_idx);
    CHECK_INPUT(expert_count);
    // The stream manager is taken from expert_count's device while the
    // kernel also reads gate_idx, so both must live on the same GPU.
    TORCH_CHECK(gate_idx.device() == expert_count.device(),
            "gate_idx and expert_count must be on the same device");

    auto new_gate_idx = torch::empty(gate_idx.sizes(), gate_idx.options());
    auto batch_size = gate_idx.numel();
    if (batch_size == 0) {
        // Nothing to prune. NOTE(review): this also avoids a zero-sized grid
        // (an invalid launch configuration) if the impl sizes its grid from
        // batch_size.
        return new_gate_idx;
    }

    auto smgr = getCudaStreamManager(expert_count.device().index());
    fmoe_cuda_prune_gate_by_capacity_impl(
            gate_idx.data_ptr<long>(),
            new_gate_idx.data_ptr<long>(),
            expert_count.data_ptr<int>(),
            batch_size, n_expert, n_worker, smgr);
    return new_gate_idx;
}
Rick Ho's avatar
Rick Ho committed
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

/*
 * Allocate device memory for `sz` elements of T with cudaMalloc.
 * Throws (TORCH_CHECK) when the allocation fails instead of returning an
 * uninitialised pointer. The caller owns the result and must cudaFree it.
 */
template<class T>
T* _cudamalloc(size_t sz) {
    T* dptr = nullptr;
    const cudaError_t err = cudaMalloc(&dptr, sz * sizeof(T));
    TORCH_CHECK(err == cudaSuccess, "cudaMalloc of ", sz * sizeof(T),
            " bytes failed: ", cudaGetErrorString(err));
    return dptr;
}

/*
 * Copy `sz` elements from host buffer `hptr` into existing device buffer
 * `dptr` with cudaMemcpy. Throws (TORCH_CHECK) if the copy reports an
 * error. Returns `dptr`.
 */
template<class T>
T* _h2d(const T* hptr, T* dptr, size_t sz) {
    const cudaError_t err = cudaMemcpy(dptr, hptr, sz * sizeof(T),
            cudaMemcpyHostToDevice);
    TORCH_CHECK(err == cudaSuccess, "host-to-device cudaMemcpy failed: ",
            cudaGetErrorString(err));
    return dptr;
}
/*
 * Allocate a device buffer for `sz` elements and fill it from host buffer
 * `hptr`. The host buffer is only read, so it is taken as const (callers
 * passing a non-const pointer still deduce the same T). If the copy throws,
 * the freshly allocated device buffer is released before rethrowing.
 * The caller owns the returned pointer and must cudaFree it.
 */
template<class T>
T* _h2d(const T* hptr, size_t sz) {
    T* dptr = _cudamalloc<T>(sz);
    try {
        return _h2d(hptr, dptr, sz);
    } catch (...) {
        cudaFree(dptr);
        throw;
    }
}
/*
 * Copy `sz` elements from device buffer `dptr` into existing host buffer
 * `hptr` with cudaMemcpy. Throws (TORCH_CHECK) if the copy reports an
 * error. Returns `hptr`.
 */
template<class T>
T* _d2h(const T* dptr, T* hptr, size_t sz) {
    const cudaError_t err = cudaMemcpy(hptr, dptr, sz * sizeof(T),
            cudaMemcpyDeviceToHost);
    TORCH_CHECK(err == cudaSuccess, "device-to-host cudaMemcpy failed: ",
            cudaGetErrorString(err));
    return hptr;
}
/*
 * Read `sz` elements of T from device buffer `dptr` into a newly allocated
 * host array and return it. The array comes from new[], so the caller must
 * release it with delete[].
 */
template<class T>
T* _d2h(const T* dptr, size_t sz) {
    return _d2h(dptr, new T[sz], sz);
}

#ifdef FMOE_USE_NCCL

#include <nccl.h>

/*
 * One round of capacity-limited "swipe" balancing over n_worker NCCL ranks.
 *
 * gate_idx : CUDA tensor of global expert indices, read as `long`. Index g
 *            belongs to worker g / n_expert; a sample re-routed to worker k
 *            receives index k * n_expert + bias.
 * capacity : single-element tensor; capacity.item() is this rank's remaining
 *            capacity before the round.
 * n_expert : experts per worker.
 * n_worker : number of ranks; must equal the size of smgr->ncclcomm.
 * bias     : local expert that re-routed samples are assigned to on their
 *            new worker.
 *
 * Steps (host-driven; every device round trip is synchronised):
 *   1. Count local samples per destination worker (lec) and exchange the
 *      counts, so this rank learns how many it receives from each worker
 *      (gec).
 *   2. Accept incoming samples greedily in worker order until capacity runs
 *      out; the excess per source worker is recorded in drop_count.
 *   3. Send the accepted counts back; lec[w] becomes the number of local
 *      samples worker w accepted.
 *   4. All-reduce drop_count (dropped samples per source worker) and
 *      all-gather every worker's remaining capacity (gcap).
 *   5. Because those inputs are now identical on all ranks, every rank runs
 *      the same greedy assignment of dropped samples to workers with spare
 *      capacity.
 *   6. Rewrite the indices of this rank's samples accordingly.
 *
 * Returns {updated gate indices as a contiguous CPU tensor, a clone of
 * `capacity` holding this rank's remaining capacity after the round}.
 * Samples for which no worker has spare capacity are left as -1.
 */
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity,
        long n_expert, long n_worker, long bias) {
    TORCH_CHECK(n_expert > 0 && n_worker > 0,
            "n_expert and n_worker must be positive");

    auto cuda_check = [](cudaError_t err, const char* what) {
        TORCH_CHECK(err == cudaSuccess, what, " failed: ",
                cudaGetErrorString(err));
    };
    auto nccl_check = [](ncclResult_t res, const char* what) {
        TORCH_CHECK(res == ncclSuccess, what, " failed: ",
                ncclGetErrorString(res));
    };

    // Make gate_idx's device current before touching streams or memory.
    auto device_idx = gate_idx.device().index();
    cuda_check(cudaSetDevice(device_idx), "cudaSetDevice");
    auto smgr = getCudaStreamManager(device_idx);
    auto stream = smgr->stream();

    int rank, n_ranks;
    nccl_check(ncclCommUserRank(smgr->ncclcomm, &rank), "ncclCommUserRank");
    nccl_check(ncclCommCount(smgr->ncclcomm, &n_ranks), "ncclCommCount");
    // The all-gather below writes one slot per rank into an n_worker buffer.
    TORCH_CHECK(n_ranks == n_worker, "n_worker (", n_worker,
            ") does not match the NCCL communicator size (", n_ranks, ")");

    long cap = capacity.item<long>();

    // Dense host copy of the indices; rewritten in place and returned.
    auto gate_idx_cpu = gate_idx.cpu().contiguous();
    long* gidx = gate_idx_cpu.data_ptr<long>();
    const long batch_size = gate_idx_cpu.numel();
    const long n_global = n_expert * n_worker;

    // 1a. lec[w] = number of local samples routed to worker w.
    std::vector<long> lec(n_worker, 0);
    for (long i = 0; i < batch_size; ++i) {
        TORCH_CHECK(gidx[i] >= 0 && gidx[i] < n_global, "gate index ",
                gidx[i], " is outside [0, ", n_global, ")");
        ++lec[gidx[i] / n_expert];
    }

    // Device scratch: four n_worker-sized buffers in one allocation, released
    // on every exit path (including exceptions).
    long* dev_raw = nullptr;
    cuda_check(cudaMalloc(&dev_raw, 4 * n_worker * sizeof(long)),
            "cudaMalloc");
    std::unique_ptr<long, void (*)(long*)> dev_guard(
            dev_raw, [](long* p) { cudaFree(p); });
    long* d_lec = dev_raw;
    long* d_gec = dev_raw + n_worker;
    long* d_drop = dev_raw + 2 * n_worker;
    long* d_gcap = dev_raw + 3 * n_worker;

    // Copies and the explicit NCCL calls below share `stream`, so they run in
    // issue order; to_host waits for the stream before the host reads.
    auto to_device = [&](long* dst, const long* src, long n) {
        cuda_check(cudaMemcpyAsync(dst, src, n * sizeof(long),
                cudaMemcpyHostToDevice, stream), "cudaMemcpyAsync (H2D)");
    };
    auto to_host = [&](long* dst, const long* src, long n) {
        cuda_check(cudaMemcpyAsync(dst, src, n * sizeof(long),
                cudaMemcpyDeviceToHost, stream), "cudaMemcpyAsync (D2H)");
        cuda_check(cudaStreamSynchronize(stream), "cudaStreamSynchronize");
    };
    // The stream used inside fmoe_cuda_expert_exchange_impl is not visible
    // here, so fence around it: its input is on the device before it starts
    // and the whole device is idle before its output is read.
    auto exchange = [&](long* src, long* dst) {
        cuda_check(cudaStreamSynchronize(stream), "cudaStreamSynchronize");
        fmoe_cuda_expert_exchange_impl(src, dst, 1, n_worker, smgr);
        cuda_check(cudaDeviceSynchronize(), "cudaDeviceSynchronize");
    };

    // 1b. gec[w] = number of samples worker w sends to this rank.
    std::vector<long> gec(n_worker);
    to_device(d_lec, lec.data(), n_worker);
    exchange(d_lec, d_gec);
    to_host(gec.data(), d_gec, n_worker);

    // 2. Accept incoming samples greedily; record the excess per source.
    std::vector<long> drop_count(n_worker, 0);
    for (long w = 0; w < n_worker; ++w) {
        if (cap >= gec[w]) {
            cap -= gec[w];
        } else {
            drop_count[w] = gec[w] - cap;
            gec[w] = cap;
            cap = 0;
        }
    }

    // 3. Return the accepted counts; lec[w] = local samples worker w accepted.
    to_device(d_gec, gec.data(), n_worker);
    exchange(d_gec, d_lec);
    to_host(lec.data(), d_lec, n_worker);

    // 4a. drop_count[w] = samples of worker w dropped by any receiver.
    to_device(d_drop, drop_count.data(), n_worker);
    nccl_check(ncclAllReduce(d_drop, d_drop, n_worker, ncclInt64, ncclSum,
            smgr->ncclcomm, stream), "ncclAllReduce");
    to_host(drop_count.data(), d_drop, n_worker);

    // 4b. gcap[w] = remaining capacity of worker w (in-place all-gather:
    //     this rank's value sits at d_gcap + rank).
    std::vector<long> gcap(n_worker);
    to_device(d_gcap + rank, &cap, 1);
    nccl_check(ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
            smgr->ncclcomm, stream), "ncclAllGather");
    to_host(gcap.data(), d_gcap, n_worker);

    // 5. Hand dropped samples of each source worker i to workers j with spare
    //    capacity, in index order. Stops once all spare capacity is used.
    for (long i = 0, j = 0; i < n_worker && j < n_worker; ++i) {
        while (drop_count[i] > 0 && j < n_worker) {
            const long moved =
                    drop_count[i] < gcap[j] ? drop_count[i] : gcap[j];
            if (i == rank) {
                lec[j] += moved;   // this rank sends `moved` more samples to j
            }
            if (j == rank) {
                cap -= moved;      // this rank receives them
            }
            drop_count[i] -= moved;
            gcap[j] -= moved;
            if (gcap[j] == 0) {
                ++j;
            }
        }
    }

    // 6a. Keep the first lec[w] samples routed to worker w; mark the rest -1.
    //     A worker only has spare capacity if it accepted everything sent to
    //     it, so what remains in lec afterwards is the re-routed quota.
    for (long i = 0; i < batch_size; ++i) {
        long& quota = lec[gidx[i] / n_expert];
        if (quota > 0) {
            --quota;
        } else {
            gidx[i] = -1;
        }
    }
    // 6b. Fill the re-routed quotas with the marked samples, sending them to
    //     expert `bias` of the new worker. Leftovers stay -1.
    for (long i = 0, k = 0; i < batch_size; ++i) {
        if (gidx[i] != -1) {
            continue;
        }
        while (k < n_worker && lec[k] == 0) {
            ++k;
        }
        if (k == n_worker) {
            break;
        }
        --lec[k];
        gidx[i] = k * n_expert + bias;
    }

    // Report the capacity left on this rank so the caller can chain rounds.
    auto capacity_left = capacity.clone();
    capacity_left.fill_(static_cast<int64_t>(cap));
    return {gate_idx_cpu, capacity_left};
}

#endif