balancing.cu 5.35 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
#include <cstdio>
Rick Ho's avatar
Rick Ho committed
2
#include "balancing.cuh"
Rick Ho's avatar
Rick Ho committed
3
#include "global_exchange.h"
Rick Ho's avatar
Rick Ho committed
4
5
6
7
8
#include <torch/extension.h>

/*
 * Note: because of the limits of CUDA atomic operations, the capacity tensor
 * must be int32.
 */
9
/*
 * Host-side wrapper around fmoe_cuda_limit_by_capacity_impl.
 *
 * expert_count is accessed as int64 and capacity as int32; both tensors must
 * pass CHECK_INPUT. A new tensor shaped like expert_count is allocated, passed
 * to the implementation (presumably as its output buffer) and returned.
 */
torch::Tensor _limit_by_capacity(
        torch::Tensor expert_count, torch::Tensor capacity,
        long n_expert, long n_worker) {
    CHECK_INPUT(expert_count);
    CHECK_INPUT(capacity);

    torch::Tensor accepted = torch::empty_like(expert_count);
    auto stream_mgr = getCudaStreamManager(expert_count.device().index());

    // Raw pointers handed to the CUDA implementation.
    long* requested_ptr = expert_count.data_ptr<long>();
    int* capacity_ptr = capacity.data_ptr<int>();
    long* accepted_ptr = accepted.data_ptr<long>();

    fmoe_cuda_limit_by_capacity_impl(
            requested_ptr, capacity_ptr, accepted_ptr,
            n_expert, n_worker, stream_mgr);
    return accepted;
}

24
/*
 * Host-side wrapper around fmoe_cuda_prune_gate_by_capacity_impl.
 *
 * gate_idx is accessed as int64 and expert_count as int32. A new tensor with
 * the same sizes, dtype and device as gate_idx is allocated, passed to the
 * implementation together with the inputs, and returned.
 */
torch::Tensor _prune_gate_by_capacity(
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker) {
    // Validate the tensors before their raw pointers are handed to the CUDA
    // implementation, the same way _limit_by_capacity validates its inputs.
    CHECK_INPUT(gate_idx);
    CHECK_INPUT(expert_count);

    auto smgr = getCudaStreamManager(expert_count.device().index());
    auto batch_size = gate_idx.numel();
    auto opt = torch::TensorOptions()
        .dtype(gate_idx.dtype())
        .device(gate_idx.device());
    auto new_gate_idx = torch::empty(gate_idx.sizes(), opt);
    fmoe_cuda_prune_gate_by_capacity_impl(
            gate_idx.data_ptr<long>(),
            new_gate_idx.data_ptr<long>(),
            expert_count.data_ptr<int>(),
            batch_size, n_expert, n_worker, smgr);
    return new_gate_idx;
}
Rick Ho's avatar
Rick Ho committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91

/*
 * Allocate device memory for `sz` elements of type T with cudaMalloc.
 *
 * Returns the device pointer. If cudaMalloc fails, the CUDA error string is
 * printed to stderr and nullptr is returned instead of an uninitialised
 * pointer.
 */
template<class T>
T* _cudamalloc(size_t sz) {
    T* dptr = nullptr;
    cudaError_t err = cudaMalloc(&dptr, sz * sizeof(T));
    if (err != cudaSuccess) {
        fprintf(stderr, "_cudamalloc: cudaMalloc of %zu bytes failed: %s\n",
                sz * sizeof(T), cudaGetErrorString(err));
        return nullptr;
    }
    return dptr;
}

/*
 * Copy `sz` elements of type T from host buffer `hptr` to device buffer
 * `dptr` with a blocking cudaMemcpy.
 *
 * Returns `dptr`. A failed copy is reported on stderr with the CUDA error
 * string; the return value is unchanged so existing callers keep working.
 */
template<class T>
T* _h2d(const T* hptr, T* dptr, size_t sz) {
    cudaError_t err = cudaMemcpy(dptr, hptr, sz * sizeof(T),
            cudaMemcpyHostToDevice);
    if (err != cudaSuccess) {
        fprintf(stderr, "_h2d: cudaMemcpy of %zu bytes failed: %s\n",
                sz * sizeof(T), cudaGetErrorString(err));
    }
    return dptr;
}
/*
 * Allocate a device buffer for `sz` elements with _cudamalloc, copy the host
 * buffer `hptr` into it, and return the device pointer. The caller releases
 * the buffer (cudaFree).
 */
template<class T>
T* _h2d(T* hptr, size_t sz) {
    T* const device_buf = _cudamalloc<T>(sz);
    _h2d(hptr, device_buf, sz);
    return device_buf;
}
/*
 * Copy `sz` elements of type T from device buffer `dptr` to host buffer
 * `hptr` with a blocking cudaMemcpy.
 *
 * Returns `hptr`. A failed copy is reported on stderr with the CUDA error
 * string; the return value is unchanged so existing callers keep working.
 */
template<class T>
T* _d2h(const T* dptr, T* hptr, size_t sz) {
    cudaError_t err = cudaMemcpy(hptr, dptr, sz * sizeof(T),
            cudaMemcpyDeviceToHost);
    if (err != cudaSuccess) {
        fprintf(stderr, "_d2h: cudaMemcpy of %zu bytes failed: %s\n",
                sz * sizeof(T), cudaGetErrorString(err));
    }
    return hptr;
}
/*
 * Allocate a host array of `sz` elements with new[], copy the device buffer
 * `dptr` into it, and return the host pointer. The caller owns the array and
 * must release it with delete[].
 */
template<class T>
T* _d2h(const T* dptr, size_t sz) {
    T* const host_buf = new T[sz];
    _d2h(dptr, host_buf, sz);
    return host_buf;
}

#ifdef FMOE_USE_NCCL

#include <nccl.h>

/*
 * Book-keeping for moving `amount_` samples from source worker i to target
 * worker j. Expands in a scope that provides i, j, rank, lec, gec and cap:
 *   - when i is this rank, lec[j] grows by the amount;
 *   - when j is this rank, gec[i] grows by the amount and cap shrinks by it.
 */
#define UPDATE_COUNTERS(amount_) do { \
    if (i == rank) { \
        lec[j] += (amount_); \
    } \
    if (j == rank) { \
        gec[i] += (amount_); \
        cap -= (amount_); \
    } \
} while (0)

/*
 * One "swipe" round: re-route the samples in gate_idx so that no worker
 * receives more than its remaining capacity.
 *
 * Steps, as performed below:
 *   1. Count local samples per target worker (gate index / n_expert) and
 *      exchange the counts between workers.
 *   2. Accept incoming samples in worker order until the local capacity
 *      `cap` is used up; the remainder is recorded in drop_count.
 *   3. Send the accepted counts back, sum drop_count over all workers and
 *      gather every worker's remaining capacity.
 *   4. Hand the dropped samples to workers with spare capacity, in worker
 *      order, and rewrite the affected local gate indices to
 *      `worker * n_expert + bias`.
 *
 * Samples that cannot be placed because no worker has capacity left keep the
 * gate index -1.
 *
 * Returns {gate indices as a CPU tensor, copy of `capacity` holding the
 * remaining local capacity}. `capacity` must hold exactly one element.
 */
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity,
        long n_expert, long n_worker, long bias) {
    auto device_idx = gate_idx.device().index();
    auto smgr = getCudaStreamManager(device_idx);
    int rank = -1;
    TORCH_CHECK(ncclCommUserRank(smgr->ncclcomm, &rank) == ncclSuccess,
            "_swipe_once: failed to query the NCCL rank");
    // rank is used below as an offset into buffers of n_worker elements.
    TORCH_CHECK(rank >= 0 && rank < n_worker,
            "_swipe_once: NCCL rank ", rank,
            " is outside [0, n_worker) with n_worker = ", n_worker);
    cudaSetDevice(device_idx);

    auto capacity_new = capacity.clone();
    auto cap = capacity_new.item<long>();

    long batch_size = gate_idx.size(0);
    auto gate_idx_cpu = gate_idx.cpu();
    long* gidx = gate_idx_cpu.data_ptr<long>();

    /* Local count and exchange */
    // NOTE(review): gidx[i] / n_expert indexes lec without a range check;
    // presumably the caller guarantees indices in [0, n_expert * n_worker)
    // - confirm.
    std::vector<long> lec(n_worker);
    for (long i = 0; i < batch_size; ++i) {
        ++lec[gidx[i] / n_expert];
    }
    long *d_lec = _h2d(lec.data(), n_worker);
    long *d_gec = _cudamalloc<long>(n_worker);
    fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
    smgr->syncTorch();
    std::vector<long> gec(n_worker);
    _d2h(d_gec, gec.data(), n_worker);

    /* Limit number of incoming samples */
    std::vector<long> drop_count(n_worker);
    for (long i = 0; i < n_worker; ++i) {
        if (cap >= gec[i]) {
            drop_count[i] = 0;
            cap -= gec[i];
        } else {
            drop_count[i] = gec[i] - cap;
            gec[i] = cap;
            cap = 0;
        }
    }

    /* Send limit information back */
    _h2d(gec.data(), d_gec, n_worker);
    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
    smgr->syncTorch();
    _d2h(d_lec, lec.data(), n_worker);

    /* Sum the per-source drop counts over all workers */
    auto d_dropcount = _h2d(drop_count.data(), n_worker);
    ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
            smgr->ncclcomm, smgr->torchStream());
    smgr->syncTorch();
    _d2h(d_dropcount, drop_count.data(), n_worker);

    /* Gather the remaining capacity of every worker */
    auto d_gcap = _cudamalloc<long>(n_worker);
    _h2d(&cap, d_gcap + rank, 1);
    ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
            smgr->ncclcomm, smgr->torchStream());
    smgr->syncTorch();
    std::vector<long> gcap(n_worker);
    _d2h(d_gcap, gcap.data(), n_worker);

    /* Re-assign and update counters */
    for (long i = 0, j = 0; i < n_worker; ++i) {
        // Stop once every worker's spare capacity is used up (j == n_worker);
        // without this bound gcap[j] / lec[j] would be accessed out of range.
        while (drop_count[i] > 0 && j < n_worker) {
            if (drop_count[i] > gcap[j]) {
                drop_count[i] -= gcap[j];
                UPDATE_COUNTERS(gcap[j]);
                ++j;
            } else {
                gcap[j] -= drop_count[i];
                UPDATE_COUNTERS(drop_count[i]);
                break;
            }
        }
    }

    /* Keep as many samples per target worker as lec allows; mark the rest */
    for (long i = 0; i < batch_size; ++i) {
        auto widx = gidx[i] / n_expert;
        if (lec[widx] > 0) {
            --lec[widx];
        } else {
            gidx[i] = -1;
        }
    }

    /* Give marked samples the slots that are still left in lec */
    for (long i = 0, k = 0; i < batch_size; ++i) {
        if (gidx[i] != -1) {
            continue;
        }
        while (k < n_worker && lec[k] == 0) {
            ++k;
        }
        if (k == n_worker) {
            // No slot left anywhere: the remaining marked samples stay -1.
            break;
        }
        --lec[k];
        gidx[i] = k * n_expert + bias;
    }

    // fill_ writes through the tensor API, so it is valid for a CPU or a CUDA
    // capacity tensor (a raw host store is only valid for CPU memory).
    capacity_new.fill_(cap);

    cudaFree(d_dropcount);
    cudaFree(d_lec);
    cudaFree(d_gec);
    cudaFree(d_gcap);

    return {gate_idx_cpu, capacity_new};
}

#undef UPDATE_COUNTERS

#endif