// balancing.cu
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>

#include "balancing.cuh"
#include "global_exchange.h"

#include <torch/extension.h>

/*
 * Limit per-expert sample counts by the available capacity.
 *
 * Because of the limits of CUDA atomic operators, `capacity` must be an int32
 * tensor; it is read through data_ptr<int>() below.
 *
 * expert_count: int64 tensor (read through data_ptr<long>()), validated with
 *               CHECK_INPUT (defined elsewhere; presumably CUDA + contiguous).
 * capacity:     int32 tensor, validated with CHECK_INPUT. It is passed to the
 *               kernel as a mutable pointer. NOTE(review): presumably consumed
 *               in place by the kernel; confirm in
 *               fmoe_cuda_limit_by_capacity_impl.
 * n_expert, n_worker: forwarded unchanged to the CUDA implementation.
 *
 * Returns a fresh tensor allocated with empty_like(expert_count), whose
 * contents are written by fmoe_cuda_limit_by_capacity_impl.
 */
torch::Tensor _limit_by_capacity(
        torch::Tensor expert_count, torch::Tensor capacity,
        long n_expert, long n_worker) {
    CHECK_INPUT(expert_count);
    CHECK_INPUT(capacity);

    // Output buffer with the same shape/dtype/device as the incoming counts.
    torch::Tensor limited = torch::empty_like(expert_count);
    const auto device_index = expert_count.device().index();
    auto stream_manager = getCudaStreamManager(device_index);

    long* counts_in = expert_count.data_ptr<long>();
    int* capacity_ptr = capacity.data_ptr<int>();
    long* counts_out = limited.data_ptr<long>();
    fmoe_cuda_limit_by_capacity_impl(
            counts_in, capacity_ptr, counts_out,
            n_expert, n_worker, stream_manager);
    return limited;
}

/*
 * Prune gate assignments according to the remaining per-expert capacity.
 *
 * gate_idx:     int64 tensor of expert indices (read through
 *               data_ptr<long>()).
 * expert_count: int32 tensor (read through data_ptr<int>()). It is passed to
 *               the kernel as a mutable pointer. NOTE(review): presumably
 *               consumed in place; confirm in
 *               fmoe_cuda_prune_gate_by_capacity_impl.
 * n_expert, n_worker: forwarded unchanged to the CUDA implementation.
 *
 * Returns a new tensor with gate_idx's shape, dtype and device, whose contents
 * are written by fmoe_cuda_prune_gate_by_capacity_impl.
 */
torch::Tensor _prune_gate_by_capacity(
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker) {
    // The kernel receives bare pointers plus numel(), so both inputs must be
    // device tensors with a flat contiguous layout. Validate them with the
    // same CHECK_INPUT macro used by _limit_by_capacity (defined elsewhere;
    // presumably checks CUDA + contiguous).
    CHECK_INPUT(gate_idx);
    CHECK_INPUT(expert_count);
    // smgr is selected from expert_count's device, while the output is
    // allocated on gate_idx's device; the kernel touches both, so they must
    // agree.
    TORCH_CHECK(gate_idx.device() == expert_count.device(),
            "_prune_gate_by_capacity: gate_idx and expert_count must be on "
            "the same device");

    auto smgr = getCudaStreamManager(expert_count.device().index());
    auto batch_size = gate_idx.numel();

    auto opt = torch::TensorOptions()
        .dtype(gate_idx.dtype())
        .device(gate_idx.device());
    auto new_gate_idx = torch::empty(gate_idx.sizes(), opt);

    fmoe_cuda_prune_gate_by_capacity_impl(
            gate_idx.data_ptr<long>(),
            new_gate_idx.data_ptr<long>(),
            expert_count.data_ptr<int>(),
            batch_size, n_expert, n_worker, smgr);
    return new_gate_idx;
}

/*
 * Allocate an uninitialized device buffer holding `sz` elements of T.
 *
 * The caller owns the returned pointer and must release it with cudaFree.
 * Throws std::runtime_error when `sz * sizeof(T)` would overflow size_t or
 * when cudaMalloc reports an error, instead of handing back an indeterminate
 * pointer.
 */
template<class T>
T* _cudamalloc(size_t sz) {
    // Guard the byte count: an overflowed product would silently allocate a
    // much smaller buffer than the caller asked for.
    if (sz > static_cast<size_t>(-1) / sizeof(T)) {
        throw std::runtime_error(
                "_cudamalloc: requested element count overflows size_t");
    }
    T* dptr = nullptr;
    const cudaError_t err = cudaMalloc(&dptr, sz * sizeof(T));
    if (err != cudaSuccess) {
        throw std::runtime_error(
                std::string("_cudamalloc: cudaMalloc failed: ")
                + cudaGetErrorString(err));
    }
    return dptr;
}

/*
 * Copy `sz` elements of T from host buffer `hptr` into the existing device
 * buffer `dptr` with cudaMemcpy. Returns `dptr`.
 *
 * Throws std::runtime_error if cudaMemcpy reports an error. The error may
 * also stem from earlier asynchronous work on the device.
 */
template<class T>
T* _h2d(const T* hptr, T* dptr, size_t sz) {
    const cudaError_t err = cudaMemcpy(dptr, hptr, sz * sizeof(T),
            cudaMemcpyHostToDevice);
    if (err != cudaSuccess) {
        throw std::runtime_error(
                std::string("_h2d: cudaMemcpy failed: ")
                + cudaGetErrorString(err));
    }
    return dptr;
}
/*
 * Allocate a device buffer of `sz` elements of T and fill it from the host
 * buffer `hptr`. Returns the new device pointer; the caller owns it and must
 * release it with cudaFree.
 *
 * If the copy throws, the freshly allocated device buffer is freed before the
 * exception propagates, so nothing leaks. The host pointer is taken as
 * `const T*`, which still accepts every `T*` argument.
 */
template<class T>
T* _h2d(const T* hptr, size_t sz) {
    T* dptr = _cudamalloc<T>(sz);
    try {
        return _h2d(hptr, dptr, sz);
    } catch (...) {
        cudaFree(dptr);
        throw;
    }
}
/*
 * Copy `sz` elements of T from device buffer `dptr` into the existing host
 * buffer `hptr` with cudaMemcpy. Returns `hptr`.
 *
 * Throws std::runtime_error if cudaMemcpy reports an error. The error may
 * also stem from earlier asynchronous work on the device.
 */
template<class T>
T* _d2h(const T* dptr, T* hptr, size_t sz) {
    const cudaError_t err = cudaMemcpy(hptr, dptr, sz * sizeof(T),
            cudaMemcpyDeviceToHost);
    if (err != cudaSuccess) {
        throw std::runtime_error(
                std::string("_d2h: cudaMemcpy failed: ")
                + cudaGetErrorString(err));
    }
    return hptr;
}
/*
 * Allocate a host array of `sz` elements of T with new[] and fill it from the
 * device buffer `dptr`. Returns the host pointer; the caller owns it and must
 * release it with delete[].
 *
 * If the copy throws, the host array is released before the exception
 * propagates, so nothing leaks.
 */
template<class T>
T* _d2h(const T* dptr, size_t sz) {
    T* hptr = new T[sz];
    try {
        return _d2h(dptr, hptr, sz);
    } catch (...) {
        delete [] hptr;
        throw;
    }
}

#ifdef FMOE_USE_NCCL

#include <nccl.h>

/*
 * Adjust the balancing counters for `__count__` samples of worker `i` that are
 * re-assigned to worker `j` (used by _swipe_once's re-assignment loop).
 *
 * The macro expands in place and reads the caller's `i`, `j`, `rank`, `lec`,
 * `gec` and `cap`:
 *  - on rank i: lec[j] grows by the count;
 *  - on rank j: gec[i] grows by the count and the local capacity `cap`
 *    shrinks by it.
 *
 * NOTE(review): `__count__` is evaluated up to three times, so pass an
 * expression without side effects. The expansion is a bare `{ ... }` block
 * (not `do { ... } while (0)`), so `UPDATE_COUNTERS(x);` leaves a stray `;`
 * that would break an unbraced if/else.
 */
#define UPDATE_COUNTERS(__count__) { \
    if (i == rank) { \
        lec[j] += (__count__); \
    } \
    if (j == rank) { \
        gec[i] += (__count__); \
        cap -= (__count__); \
    } \
}

namespace {

// Throws (via TORCH_CHECK) when a CUDA runtime call reports an error.
inline void swipe_check_cuda(cudaError_t err, const char* what) {
    TORCH_CHECK(err == cudaSuccess, what, " failed: ",
            cudaGetErrorString(err));
}

// Throws (via TORCH_CHECK) when an NCCL call reports an error.
inline void swipe_check_nccl(ncclResult_t res, const char* what) {
    TORCH_CHECK(res == ncclSuccess, what, " failed: ",
            ncclGetErrorString(res));
}

// Owns a cudaMalloc'd array of `long` and cudaFree's it on scope exit, so the
// device scratch is released on every return path, including exceptions.
class SwipeDeviceBuffer {
public:
    explicit SwipeDeviceBuffer(size_t n) {
        swipe_check_cuda(cudaMalloc(&ptr_, n * sizeof(long)), "cudaMalloc");
    }
    ~SwipeDeviceBuffer() { cudaFree(ptr_); }
    SwipeDeviceBuffer(const SwipeDeviceBuffer&) = delete;
    SwipeDeviceBuffer& operator=(const SwipeDeviceBuffer&) = delete;
    long* get() const { return ptr_; }

private:
    long* ptr_ = nullptr;
};

// Host -> device copy of n longs. cudaMemcpy from pageable host memory may
// return before the DMA to the device has finished, so synchronize the device
// afterwards: later work can then read `dst` on any stream.
void swipe_to_device(long* dst, const long* src, size_t n) {
    swipe_check_cuda(cudaMemcpy(dst, src, n * sizeof(long),
            cudaMemcpyHostToDevice), "cudaMemcpy (H2D)");
    swipe_check_cuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize");
}

// Device -> host copy of n longs. Synchronize the device first so that work
// enqueued on smgr's stream(s) has finished. Those streams' flags are not
// visible here, so implicit legacy-default-stream ordering is not relied on.
void swipe_to_host(long* dst, const long* src, size_t n) {
    swipe_check_cuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize");
    swipe_check_cuda(cudaMemcpy(dst, src, n * sizeof(long),
            cudaMemcpyDeviceToHost), "cudaMemcpy (D2H)");
}

}  // namespace

/*
 * One round of capacity-limited re-balancing of gate assignments across the
 * workers of smgr->ncclcomm. This is a collective call: every rank must call
 * it, because it issues NCCL all-reduce / all-gather and count exchanges.
 *
 * gate_idx: expert indices in [0, n_expert * n_worker); `index / n_expert` is
 *           the destination worker. Any shape; all numel() entries are
 *           processed. The input is never modified.
 * capacity: single-element tensor holding how many samples this rank may
 *           still accept.
 * n_expert: experts per worker. n_worker: number of workers.
 * bias:     expert offset used when a sample is redirected to worker k (its
 *           new index is k * n_expert + bias).
 *
 * Steps:
 *  1. Count local samples per destination worker (lec) and exchange the
 *     counts (fmoe_cuda_expert_exchange_impl; presumably an all-to-all of
 *     per-worker counts, so gec[w] = samples worker w sends here).
 *  2. Accept incoming samples in worker order until `cap` runs out; the
 *     remainder of each sender is recorded in drop_count.
 *  3. Send the accepted counts back, all-reduce the drop counts and
 *     all-gather every rank's remaining capacity.
 *  4. Re-assign dropped samples to workers with spare capacity. The inputs of
 *     this step are identical on every rank, so all ranks agree on the result.
 *  5. Rewrite local indices: keep samples within each destination's accepted
 *     quota and redirect the rest.
 *
 * Returns {new gate indices as an int64 CPU tensor, updated capacity}.
 */
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity,
        long n_expert, long n_worker, long bias) {
    TORCH_CHECK(n_expert > 0 && n_worker > 0,
            "_swipe_once: n_expert and n_worker must be positive");
    auto device_idx = gate_idx.device().index();
    auto smgr = getCudaStreamManager(device_idx);
    int rank;
    swipe_check_nccl(ncclCommUserRank(smgr->ncclcomm, &rank),
            "ncclCommUserRank");
    swipe_check_cuda(cudaSetDevice(device_idx), "cudaSetDevice");

    auto capacity_new = capacity.clone();
    long cap = capacity_new.item<long>();

    // Work on a private, contiguous int64 CPU copy. A plain .cpu() returns
    // `gate_idx` itself for a CPU input, and the loops below rewrite the
    // indices in place.
    auto gate_idx_cpu = gate_idx.to(torch::Device(torch::kCPU), torch::kLong,
            /*non_blocking=*/false, /*copy=*/true,
            torch::MemoryFormat::Contiguous);
    long* gidx = gate_idx_cpu.data_ptr<long>();
    const long batch_size = gate_idx_cpu.numel();
    const size_t nw = static_cast<size_t>(n_worker);
    const long n_total_expert = n_expert * n_worker;

    /* Local count and exchange */
    std::vector<long> lec(nw, 0);
    for (long i = 0; i < batch_size; ++i) {
        // Out-of-range indices would write outside lec (and -1 would be
        // confused with the "dropped" marker used below).
        TORCH_CHECK(gidx[i] >= 0 && gidx[i] < n_total_expert,
                "_swipe_once: gate index ", gidx[i],
                " out of range [0, ", n_total_expert, ")");
        ++lec[gidx[i] / n_expert];
    }

    // Device scratch laid out as [lec | gec | drop_count | gcap].
    SwipeDeviceBuffer d_scratch(4 * nw);
    long* d_lec = d_scratch.get();
    long* d_gec = d_lec + nw;
    long* d_dropcount = d_gec + nw;
    long* d_gcap = d_dropcount + nw;

    swipe_to_device(d_lec, lec.data(), nw);
    fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
    std::vector<long> gec(nw);
    swipe_to_host(gec.data(), d_gec, nw);

    /* Limit number of incoming samples */
    std::vector<long> drop_count(nw, 0);
    for (size_t i = 0; i < nw; ++i) {
        if (cap >= gec[i]) {
            cap -= gec[i];
        } else {
            drop_count[i] = gec[i] - cap;
            gec[i] = cap;
            cap = 0;
        }
    }

    /* Send limit information back */
    swipe_to_device(d_gec, gec.data(), nw);
    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
    swipe_to_host(lec.data(), d_lec, nw);

    // Total dropped samples per sending worker, summed over all receivers.
    swipe_to_device(d_dropcount, drop_count.data(), nw);
    swipe_check_nccl(ncclAllReduce(d_dropcount, d_dropcount, nw, ncclInt64,
            ncclSum, smgr->ncclcomm, smgr->stream()), "ncclAllReduce");
    swipe_to_host(drop_count.data(), d_dropcount, nw);

    // In-place all-gather of every rank's remaining capacity.
    swipe_to_device(d_gcap + rank, &cap, 1);
    swipe_check_nccl(ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
            smgr->ncclcomm, smgr->stream()), "ncclAllGather");
    std::vector<long> gcap(nw);
    swipe_to_host(gcap.data(), d_gcap, nw);

    /* Re-assign and update counters */
    // Records `count` samples of worker i being moved to worker j.
    auto update_counters = [&](long i, long j, long count) {
        if (i == rank) {
            lec[j] += count;
        }
        if (j == rank) {
            gec[i] += count;
            cap -= count;
        }
    };
    for (long i = 0, j = 0; i < n_worker; ++i) {
        while (drop_count[i] > 0) {
            // drop_count and gcap are identical on every rank, so this check
            // fails on all ranks together rather than reading past gcap.
            TORCH_CHECK(j < n_worker, "_swipe_once: not enough global "
                    "capacity to re-assign dropped samples");
            if (drop_count[i] > gcap[j]) {
                drop_count[i] -= gcap[j];
                update_counters(i, j, gcap[j]);
                ++j;
            } else {
                gcap[j] -= drop_count[i];
                update_counters(i, j, drop_count[i]);
                break;
            }
        }
    }

    // Keep samples within each destination worker's accepted quota; mark
    // the rest as dropped (-1).
    for (long i = 0; i < batch_size; ++i) {
        const long widx = gidx[i] / n_expert;
        if (lec[widx] > 0) {
            --lec[widx];
        } else {
            gidx[i] = -1;
        }
    }
    // Redirect dropped samples to workers that still have quota for this rank.
    for (long i = 0, k = 0; i < batch_size; ++i) {
        if (gidx[i] != -1) {
            continue;
        }
        while (k < n_worker && lec[k] == 0) {
            ++k;
        }
        TORCH_CHECK(k < n_worker,
                "_swipe_once: no remaining quota to place a dropped sample");
        gidx[i] = k * n_expert + bias;
        // The quota is per worker, so decrement lec[k]. Indexing lec with the
        // expert index k * n_expert + bias decrements the wrong counter and
        // can run past its n_worker entries.
        --lec[k];
    }

    // fill_ writes through the tensor's own dtype/device, unlike assigning via
    // data_ptr<long>(), which requires an int64 CPU tensor.
    capacity_new.fill_(cap);
    return {gate_idx_cpu, capacity_new};
}

#undef UPDATE_COUNTERS

#endif