balancing.hip 5.44 KB
Newer Older
zhanggzh's avatar
zhanggzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <cstdio>
#include "../hip/balancing.cuh"
#include "../hip/global_exchange.h"
#include <torch/extension.h>

/*
 * Note: because of a limitation of the CUDA atomic operators, capacity must
 * be int32.
 */
/*
 * Launch fmoe_cuda_limit_by_capacity_impl on `expert_count` and `capacity`.
 *
 * expert_count is read through data_ptr<long>() and capacity through
 * data_ptr<int>(). Both are first passed through CHECK_INPUT (defined outside
 * this file). The result is a new tensor created with empty_like(expert_count)
 * that the implementation writes through a long pointer.
 */
torch::Tensor _limit_by_capacity(
        torch::Tensor expert_count, torch::Tensor capacity,
        long n_expert, long n_worker) {
    CHECK_INPUT(expert_count);
    CHECK_INPUT(capacity);

    // Output buffer; its contents are produced entirely by the impl call.
    torch::Tensor ack = torch::empty_like(expert_count);
    // The stream manager is looked up by expert_count's device index.
    auto stream_mgr = getCudaStreamManager(expert_count.device().index());

    long* count_ptr = expert_count.data_ptr<long>();
    int* capacity_ptr = capacity.data_ptr<int>();
    long* ack_ptr = ack.data_ptr<long>();
    fmoe_cuda_limit_by_capacity_impl(
            count_ptr, capacity_ptr, ack_ptr,
            n_expert, n_worker, stream_mgr);

    return ack;
}

/*
 * Launch fmoe_cuda_prune_gate_by_capacity_impl over every element of
 * `gate_idx`.
 *
 * gate_idx is read through data_ptr<long>() and expert_count through
 * data_ptr<int>(). The result is a new tensor with gate_idx's sizes, dtype and
 * device; gate_idx itself is only passed as the input pointer.
 */
torch::Tensor _prune_gate_by_capacity(
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker) {
    // The stream manager is looked up by expert_count's device index.
    auto stream_mgr = getCudaStreamManager(expert_count.device().index());
    const auto n_gates = gate_idx.numel();

    // Uninitialised output that mirrors gate_idx's sizes, dtype and device.
    torch::Tensor pruned = torch::empty(
            gate_idx.sizes(),
            torch::TensorOptions()
                .dtype(gate_idx.dtype())
                .device(gate_idx.device()));

    long* src_ptr = gate_idx.data_ptr<long>();
    long* dst_ptr = pruned.data_ptr<long>();
    int* count_ptr = expert_count.data_ptr<int>();
    fmoe_cuda_prune_gate_by_capacity_impl(
            src_ptr, dst_ptr, count_ptr,
            n_gates, n_expert, n_worker, stream_mgr);

    return pruned;
}

template<class T>
T* _cudamalloc(size_t sz) {
    T* dptr;
    hipMalloc(&dptr, sz * sizeof(T));
    return dptr;
}

template<class T>
T* _h2d(const T* hptr, T* dptr, size_t sz) {
    hipMemcpy(dptr, hptr, sz * sizeof(T), hipMemcpyHostToDevice);
    return dptr;
}
/*
 * Allocate a new device buffer for `sz` elements with _cudamalloc and fill it
 * from the host array `hptr` using the three-argument _h2d.
 *
 * Returns the new device pointer; this helper does not free it.
 */
template<class T>
T* _h2d(T* hptr, size_t sz) {
    return _h2d(hptr, _cudamalloc<T>(sz), sz);
}
template<class T>
T* _d2h(const T* dptr, T* hptr, size_t sz) {
    hipMemcpy(hptr, dptr, sz * sizeof(T), hipMemcpyDeviceToHost);
    return hptr;
}
/*
 * Allocate a host array of `sz` elements with new[] and fill it from the
 * device buffer `dptr` using the three-argument _d2h.
 *
 * Returns the new host array; the caller must release it with delete[].
 */
template<class T>
T* _d2h(const T* dptr, size_t sz) {
    return _d2h(dptr, new T[sz], sz);
}

#ifdef FMOE_USE_NCCL

#include <rccl/rccl.h>

/*
 * Counter update used while re-assigning samples. Expands in a scope that
 * provides `i`, `j`, `rank`, `lec`, `gec` and `cap`:
 *   - if i == rank, lec[j] is increased by count_;
 *   - if j == rank, gec[i] is increased by count_ and cap is decreased by it.
 * count_ may be evaluated more than once, so pass an expression without side
 * effects. Invoke as a statement, followed by a semicolon.
 */
#define UPDATE_COUNTERS(count_) do { \
    if (i == rank) { \
        lec[j] += (count_); \
    } \
    if (j == rank) { \
        gec[i] += (count_); \
        cap -= (count_); \
    } \
} while (0)

/*
 * Host-side re-balancing pass over `gate_idx` that respects `capacity`.
 *
 * Steps, as implemented below:
 *   1. Count, per worker, how many entries of gate_idx map to it
 *      (worker = gate_idx / n_expert) and exchange the counts through
 *      fmoe_cuda_expert_exchange_impl.
 *   2. Clamp the received counts against the remaining capacity `cap`,
 *      recording how many samples from each worker were refused.
 *   3. Exchange the clamped counts back, all-reduce the refused counts and
 *      all-gather every rank's left-over capacity.
 *   4. Hand refused samples to workers that still have capacity and rewrite
 *      the affected gate indices to `worker * n_expert + bias`.
 *
 * Parameters:
 *   gate_idx  - tensor read through data_ptr<long>(); the tensor returned by
 *               gate_idx.cpu() is edited in place and returned.
 *   capacity  - single-element tensor read with item<long>().
 *   n_expert  - divisor mapping a gate index to its worker; must be > 0.
 *   n_worker  - number of workers; must be > 0 and greater than this rank.
 *   bias      - offset added to `worker * n_expert` for re-assigned samples.
 *
 * Returns {gate_idx.cpu() with the rewritten indices, clone of capacity
 * holding the remaining capacity}.
 *
 * Samples that cannot be placed because no worker has capacity left keep the
 * gate index -1.
 */
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity,
        long n_expert, long n_worker, long bias) {
    TORCH_CHECK(n_expert > 0 && n_worker > 0,
            "n_expert and n_worker must be positive, got ",
            n_expert, " and ", n_worker);

    auto device_idx = gate_idx.device().index();
    auto smgr = getCudaStreamManager(device_idx);
    // -1 keeps the range check below meaningful if the query writes nothing.
    int rank = -1;
    ncclCommUserRank(smgr->ncclcomm, &rank);
    // rank is used as an offset into n_worker-sized buffers further down.
    TORCH_CHECK(rank >= 0 && rank < n_worker,
            "communicator rank ", rank, " is outside [0, ", n_worker, ")");
    hipSetDevice(device_idx);

    auto capacity_new = capacity.clone();
    auto cap = capacity_new.item<long>();

    long batch_size = gate_idx.size(0);
    auto gate_idx_cpu = gate_idx.cpu();
    long* gidx = gate_idx_cpu.data_ptr<long>();

    /* Local count and exchange */
    std::vector<long> lec(n_worker);  // value-initialised to zero
    for (long i = 0; i < batch_size; ++i) {
        const long widx = gidx[i] / n_expert;
        // Reject indices that would write outside lec.
        TORCH_CHECK(widx >= 0 && widx < n_worker,
                "gate index ", gidx[i], " maps to worker ", widx,
                ", which is outside [0, ", n_worker, ")");
        ++lec[widx];
    }
    long *d_lec = _h2d(lec.data(), n_worker);
    long *d_gec = _cudamalloc<long>(n_worker);
    fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
    smgr->syncTorch();
    std::vector<long> gec(n_worker);
    _d2h(d_gec, gec.data(), n_worker);

    /* Limit number of incoming samples */
    std::vector<long> drop_count(n_worker);  // value-initialised to zero
    for (long i = 0; i < n_worker; ++i) {
        if (cap >= gec[i]) {
            drop_count[i] = 0;
            cap -= gec[i];
        } else {
            // Accept what still fits and refuse the rest.
            drop_count[i] = gec[i] - cap;
            gec[i] = cap;
            cap = 0;
        }
    }

    /* Send limit information back */
    _h2d(gec.data(), d_gec, n_worker);
    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
    smgr->syncTorch();
    _d2h(d_lec, lec.data(), n_worker);

    /* Sum the refused counts over all ranks */
    auto d_dropcount = _h2d(drop_count.data(), n_worker);
    ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
            smgr->ncclcomm, smgr->torchStream());
    smgr->syncTorch();
    _d2h(d_dropcount, drop_count.data(), n_worker);

    /* Gather every rank's remaining capacity; this rank's value sits at
     * offset `rank` of the receive buffer before the all-gather. */
    auto d_gcap = _cudamalloc<long>(n_worker);
    _h2d(&cap, d_gcap + rank, 1);
    ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
            smgr->ncclcomm, smgr->torchStream());
    smgr->syncTorch();
    std::vector<long> gcap(n_worker);
    _d2h(d_gcap, gcap.data(), n_worker);

    /* Re-assign and update counters */
    for (long i = 0, j = 0; i < n_worker; ++i) {
        // `j < n_worker` stops the walk once every worker's capacity has been
        // used up; without it gcap and lec would be indexed out of range.
        while (drop_count[i] > 0 && j < n_worker) {
            if (drop_count[i] > gcap[j]) {
                drop_count[i] -= gcap[j];
                UPDATE_COUNTERS(gcap[j]);
                ++j;
            } else {
                gcap[j] -= drop_count[i];
                UPDATE_COUNTERS(drop_count[i]);
                break;
            }
        }
    }

    /* Keep as many samples per worker as lec allows; mark the rest with -1 */
    for (long i = 0; i < batch_size; ++i) {
        auto widx = gidx[i] / n_expert;
        if (lec[widx] > 0) {
            --lec[widx];
        } else {
            gidx[i] = -1;
        }
    }

    /* Give marked samples to the workers that still have a positive count */
    for (long i = 0, k = 0; i < batch_size; ++i) {
        if (gidx[i] != -1) {
            continue;
        }
        while (k < n_worker && lec[k] == 0) {
            ++k;
        }
        if (k >= n_worker) {
            // Nothing left to hand out; remaining samples stay at -1.
            break;
        }
        --lec[k];
        gidx[i] = k * n_expert + bias;
    }

    // fill_ instead of writing through data_ptr, so the write-back never
    // dereferences a device pointer on the host.
    capacity_new.fill_(cap);

    hipFree(d_dropcount);
    hipFree(d_lec);
    hipFree(d_gec);
    hipFree(d_gcap);

    return {gate_idx_cpu, capacity_new};
}

#undef UPDATE_COUNTERS

#endif