// smart_schedule.h
// Pipelined (smart-scheduled) expert-parallel MoE forward/backward helpers.
#ifndef SMART_SCHEDULE_H
#define SMART_SCHEDULE_H

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <nccl.h>

#include "../stream_manager.h"

// FMOE_SWE(stream, event): make `stream` wait until `event` has completed.
//
// CUDA_VERSION is encoded as 1000 * major + 10 * minor (e.g. 11.1 -> 11010),
// so the former guard `CUDA_VERSION < 110010` held for every real toolkit and
// its two-argument `#else` branch was dead code. cudaStreamWaitEvent has always
// accepted an explicit `flags` argument, and 0 requests the default wait
// behaviour, so the three-argument form is used unconditionally -- exactly what
// the old guard compiled to in practice.
#define FMOE_SWE(s_, e_) cudaStreamWaitEvent((s_), (e_), 0)

// Post a point-to-point exchange of token rows with two peers on `stream`:
//   * send `sendcount` rows starting at `sendbuf` to rank `t_send`;
//   * receive `recvcount` rows into `recvbuf` from rank `t_recv`.
// A row is `d_model` elements of scalar_t. Data is moved as raw bytes
// (ncclChar), so any scalar_t works. A zero count skips that direction.
// Every call site in this file brackets this function with
// ncclGroupStart()/ncclGroupEnd(), so the send and receive are issued as part
// of one grouped NCCL operation.
// NCCL errors are checked with NCCL_SAFE_CALL, consistent with the other NCCL
// calls in this file, rather than being silently dropped.
template<typename scalar_t>
void exchangeWith(
        const scalar_t* sendbuf, size_t sendcount, int t_send,
        scalar_t* recvbuf, size_t recvcount, int t_recv,
        long d_model,
        cudaStream_t stream, ncclComm_t comm) {
    if (sendcount) {
        NCCL_SAFE_CALL(ncclSend(sendbuf, sendcount * d_model * sizeof(scalar_t),
                ncclChar, t_send, comm, stream));
    }
    if (recvcount) {
        NCCL_SAFE_CALL(ncclRecv(recvbuf, recvcount * d_model * sizeof(scalar_t),
                ncclChar, t_recv, comm, stream));
    }
}


// GEN_BASE(_step): declares `to_base` and `from_base`, the first rank of the
// peer group `_step` groups after / before this rank's group (cyclically over
// `n_groups` groups of `pipeline_gran` consecutive ranks). Callers iterate
// j in [0, pipeline_gran) and use `j + to_base` / `j + from_base` as peers.
// Expects `group_rank`, `n_groups` and `pipeline_gran` in scope at expansion.
#define GEN_BASE(_step) \
    long to_base = (group_rank + _step) % n_groups * pipeline_gran; \
    long from_base = (group_rank + n_groups - _step) % n_groups * pipeline_gran;
// GEN_IDX: declares slot indices for expert `ei`:
//   idx_send / idx_recv : worker-major slot (worker * num_expert + ei), the
//                         layout of local_expert_count, global_expert_count,
//                         stored_models and local_ptr;
//   gidx_send/gidx_recv : expert-major slot (ei * world_size + worker), the
//                         layout of global_ptr;
//   idx_self            : this rank's own worker-major slot for expert `ei`.
// Expects `ei`, `rank_send`, `rank_recv`, `rank`, `num_expert` and
// `world_size` in scope at expansion.
#define GEN_IDX \
    int idx_send = ei + rank_send * num_expert; \
    int idx_recv = ei + rank_recv * num_expert; \
    int gidx_send = ei * world_size + rank_send; \
    int gidx_recv = ei * world_size + rank_recv; \
    int idx_self = ei +      rank * num_expert;

// Build host-side row-offset tables for the expert all-to-all.
//
// Slots are worker-major: slot i = worker * num_expert + expert, for
// i in [0, num_expert * world_size). All three output arrays must hold
// num_expert * world_size + 1 ints; entry 0 is always 0.
//   local_ptr        : exclusive prefix sum of local_expert_count, i.e. the
//                      first local row of each slot.
//   local_global_ptr : exclusive prefix sum of local_expert_count restricted
//                      to slots whose expert is fetched (stored_models[i]).
//   global_ptr       : exclusive prefix sum, in expert-major order
//                      (expert * world_size + worker), of global_expert_count;
//                      an expert of this rank that is fetched
//                      (stored_models[rank * num_expert + expert]) contributes
//                      0 rows, since no remote tokens are received for it.
//
// Declared `inline` because the definition lives in a header: a non-inline
// definition breaks the link (ODR) when the header is included by more than
// one translation unit.
inline void computePtrs(long num_expert, long rank, long world_size,
        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        int *local_ptr,
        int *global_ptr,
        int *local_global_ptr) {
    const long n_slots = num_expert * world_size;
    local_ptr[0] = global_ptr[0] = local_global_ptr[0] = 0;

    for (long i = 0; i < n_slots; ++i) {
        local_ptr[i + 1] = local_ptr[i] + local_expert_count[i];

        // Only tokens of fetched (shadowed) experts count toward this table.
        local_global_ptr[i + 1] = local_global_ptr[i];
        if (stored_models[i]) {
            local_global_ptr[i + 1] += local_expert_count[i];
        }

        // Scatter per-slot counts into expert-major order; the prefix sum
        // below turns them into offsets. Tokens are only received for
        // experts of this rank that were not fetched by others.
        long expert_idx = i % num_expert;
        long worker_idx = i / num_expert;
        long gp_idx = expert_idx * world_size + worker_idx;
        if (stored_models[rank * num_expert + expert_idx]) {
            global_ptr[gp_idx + 1] = 0;
        } else {
            global_ptr[gp_idx + 1] = global_expert_count[i];
        }
    }

    global_ptr[0] = 0;
    for (long i = 0; i < n_slots; ++i) {
        global_ptr[i + 1] += global_ptr[i];
    }
}
// Run the Python expert callback `fn` on one micro-batch.
//
// Rows [offset, offset + micro_batch_size) of `inp_buf` and `out_buf` (each
// row is d_model elements of scalar_t) are wrapped without copying
// (torch::from_blob) as tensors of shape {micro_batch_size, d_model} on
// `device`, then fn(inp, oup, expert_idx, store_idx) is called. Empty
// micro-batches are skipped.
//
// smgr->use_default is set to true for the duration of the call and reset to
// false afterwards -- also when `fn` throws (e.g. a Python exception
// surfacing as a C++ exception), so a failing callback cannot leave the
// stream manager stuck in that mode.
// NOTE(review): presumably use_default makes the stream manager use the
// default stream while Python code runs; confirm in stream_manager.h.
template<typename scalar_t>
void computeFn(py::function fn, c10::Device device,
        scalar_t* inp_buf, scalar_t* out_buf,
        long expert_idx, long store_idx, long offset, long micro_batch_size, long d_model,
        CudaStreamManager* smgr) {
    if (micro_batch_size == 0) {
        return;
    }
    auto options = torch::TensorOptions()
        .dtype(c10::CppTypeToScalarType<scalar_t>::value)
        .device(device)
        .requires_grad(true);
    auto inp = torch::from_blob(inp_buf + offset * d_model,
            {micro_batch_size, d_model}, options);
    auto oup = torch::from_blob(out_buf + offset * d_model,
            {micro_batch_size, d_model}, options);

    smgr->use_default = true;
    try {
        fn(inp, oup, expert_idx, store_idx);
    } catch (...) {
        smgr->use_default = false;
        throw;
    }
    smgr->use_default = false;
}


// Pipelined forward pass of an expert-parallel MoE layer with optional
// expert "shadowing" (stored_models[i] == true for fetched expert slots).
//
// NCCL traffic runs on smgr->stream(0); expert computation is ordered against
// the current torch stream; CUDA events order the two.
//   S_0..S_n : per peer group, send local tokens / receive remote tokens
//              (input_buf -> global_input_buf); input_ready[step] recorded.
//   Bcast    : broadcast every shadowed expert's parameters
//              (expert_size elements of scalar_t) from rank i / num_expert
//              into params[si]; one event per broadcast.
//   C_0..C_n : forward_fn over the rows received at each step
//              (global_input_buf -> global_output_buf); output_ready[step].
//   Shadow   : forward_fn locally for shadowed experts on this rank's own rows
//              (input_buf -> output_buf), after stash_fn(params[si], si, 0).
//   R_0..R_n : return results to their source ranks
//              (global_output_buf -> output_buf).
//
// `params` must contain one tensor per true entry of stored_models, in slot
// order (it is indexed by the running shadow counter `si`).
//
// Host-side tables and event-handle arrays live in std::vector so they are
// freed even if a Python callback throws; the shadow-broadcast events are
// tracked by exactly how many were created, and all of them are destroyed.
// NOTE(review): n_groups = world_size / pipeline_gran truncates; ranks beyond
// n_groups * pipeline_gran are never exchanged with -- confirm callers pass a
// pipeline_gran that divides world_size.
template<typename scalar_t>
void fmoe_cuda_fused_forward_impl(
        py::function forward_fn,
        py::function stash_fn,
        py::function pop_fn,
        c10::Device device,
        std::vector<torch::Tensor> params,

        scalar_t* input_buf,
        scalar_t* global_input_buf,
        scalar_t* global_output_buf,
        scalar_t* output_buf,

        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,

        long d_model,
        long num_expert, long rank, long world_size, long expert_size,
        long pipeline_gran, CudaStreamManager* smgr) {
    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();

    const long n_slots = num_expert * world_size;
    std::vector<int> local_ptr(n_slots + 1);
    std::vector<int> global_ptr(n_slots + 1);
    std::vector<int> local_global_ptr(n_slots + 1); // local fetched models tracker
    computePtrs(num_expert, rank, world_size,
            local_expert_count, global_expert_count, stored_models,
            local_ptr.data(), global_ptr.data(), local_global_ptr.data());

    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
    if (pipeline_gran < 1) {
        // Guard the divisions below against a zero/negative granularity.
        pipeline_gran = 1;
    }
    long n_groups = world_size / pipeline_gran;
    long group_rank = rank / pipeline_gran;

    std::vector<cudaEvent_t> input_ready(n_groups);
    std::vector<cudaEvent_t> output_ready(n_groups);
    for (long i = 0; i < n_groups; ++i) {
        cudaEventCreate(&input_ready[i]);
        cudaEventCreate(&output_ready[i]);
    }

    // S_0 ... S_n
    for (long step = 0; step < n_groups; ++step) {
        for (long ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
                // Shadowed slots move no tokens: they are computed locally.
                exchangeWith(input_buf + local_ptr[idx_send] * d_model,
                        local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                        global_input_buf + global_ptr[gidx_recv] * d_model,
                        global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
        cudaEventRecord(input_ready[step], smgr->stream(0));
    }

    // Broadcast shadowed experts
    std::vector<cudaEvent_t> evt_shadow;
    evt_shadow.reserve(params.size());
    for (long i = 0, si = 0; i < n_slots; ++i) {
        if (stored_models[i]) {
            if (i / num_expert == rank) {
                // The owner's parameters are produced on the torch stream;
                // make the NCCL stream wait for them before broadcasting.
                cudaEvent_t evt_get;
                cudaEventCreate(&evt_get);
                cudaEventRecord(evt_get, torch_stream);
                FMOE_SWE(smgr->stream(0), evt_get);
                cudaEventDestroy(evt_get);
            }
            NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
                        expert_size * sizeof(scalar_t), ncclChar,
                        i / num_expert, smgr->ncclcomm, smgr->stream(0)));
            cudaEvent_t evt_done;
            cudaEventCreate(&evt_done);
            cudaEventRecord(evt_done, smgr->stream(0));
            evt_shadow.push_back(evt_done);
            ++si;
        }
    }

    // C_0 ... C_n
    for (long step = 0; step < n_groups; ++step) {
        FMOE_SWE(torch_stream, input_ready[step]);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
            long micro_batch_size = global_ptr[ei * world_size +
                (from_base + pipeline_gran)] - offset;
            computeFn(forward_fn, device,
                    global_input_buf, global_output_buf,
                    (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
        }
        cudaEventRecord(output_ready[step], torch_stream);
    }

    // Compute over shadowed experts
    for (long i = 0, si = 0; i < n_slots; ++i) {
        if (stored_models[i]) {
            FMOE_SWE(torch_stream, evt_shadow[si]);
            stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
            long offset = local_ptr[i];
            long micro_batch_size = local_expert_count[i];
            computeFn(forward_fn, device,
                    input_buf, output_buf,
                    0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
            ++si;
        }
    }
    pop_fn(0);

    // R_0 ... R_n
    for (long step = 0; step < n_groups; ++step) {
        FMOE_SWE(smgr->stream(0), output_ready[step]);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
                exchangeWith(global_output_buf + global_ptr[gidx_send] * d_model,
                        global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                        output_buf + local_ptr[idx_recv] * d_model,
                        local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }

    smgr->sync(1);
    checkCudaErrors(cudaGetLastError());

    for (long i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
    }
    // Destroy exactly the shadow events that were created above.
    for (cudaEvent_t evt : evt_shadow) {
        cudaEventDestroy(evt);
    }
}


// Pipelined backward pass; mirror of fmoe_cuda_fused_forward_impl.
//
// NCCL traffic runs on smgr->stream(0); expert computation is ordered against
// the current torch stream; CUDA events order the two.
//   S_0..S_n : exchange output gradients (grad_out -> global_grad_out);
//              input_ready[step] recorded on stream(0).
//   Shadow   : for each shadowed slot i (stored_models[i]): stash_fn(si, 0),
//              backward_fn on this rank's own rows (grad_out -> grad_in),
//              collect_fn(si, i / num_expert, 0); on the owner rank an event
//              evt_reduce[expert] is recorded on stream(0).
//   C_0..C_n : backward_fn on received rows
//              (global_grad_out -> global_grad_in); output_ready[step].
//   Grads    : the owner rank waits on evt_reduce[expert] and calls
//              set_grad_fn(si, expert).
//   R_0..R_n : return input gradients (global_grad_in -> grad_in).
//
// Host-side tables and event-handle arrays live in std::vector so they are
// freed even if a Python callback throws.
// NOTE(review): n_groups = world_size / pipeline_gran truncates; confirm
// callers pass a pipeline_gran that divides world_size.
template<typename scalar_t>
void fmoe_cuda_fused_backward_impl(
        py::function backward_fn,
        py::function stash_fn,
        py::function pop_fn,
        py::function collect_fn,
        py::function set_grad_fn,
        c10::Device device,

        scalar_t* grad_out,
        scalar_t* global_grad_out,
        scalar_t* global_grad_in,
        scalar_t* grad_in,

        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        long d_model,
        long num_expert, long rank, long world_size,
        long pipeline_gran, CudaStreamManager* smgr) {
    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();

    const long n_slots = num_expert * world_size;
    std::vector<int> local_ptr(n_slots + 1);
    std::vector<int> global_ptr(n_slots + 1);
    std::vector<int> local_global_ptr(n_slots + 1); // local fetched models tracker
    computePtrs(num_expert, rank, world_size,
            local_expert_count, global_expert_count, stored_models,
            local_ptr.data(), global_ptr.data(), local_global_ptr.data());

    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
    if (pipeline_gran < 1) {
        // Guard the divisions below against a zero/negative granularity.
        pipeline_gran = 1;
    }
    long n_groups = world_size / pipeline_gran;
    long group_rank = rank / pipeline_gran;

    std::vector<cudaEvent_t> input_ready(n_groups);
    std::vector<cudaEvent_t> output_ready(n_groups);
    for (long i = 0; i < n_groups; ++i) {
        cudaEventCreate(&input_ready[i]);
        cudaEventCreate(&output_ready[i]);
    }

    // S_0 ... S_n
    for (long step = 0; step < n_groups; ++step) {
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
                exchangeWith(grad_out + local_ptr[idx_send] * d_model,
                        local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                        global_grad_out + global_ptr[gidx_recv] * d_model,
                        global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
        cudaEventRecord(input_ready[step], smgr->stream(0));
    }

    // Shadowed experts backward and reduce. evt_reduce[e] is only created
    // (and later destroyed) when this rank owns shadowed expert e.
    std::vector<cudaEvent_t> evt_reduce(num_expert);
    for (long i = 0, si = 0; i < n_slots; ++i) {
        if (stored_models[i]) {
            stash_fn(si, 0);
            long offset = local_ptr[i];
            long micro_batch_size = local_expert_count[i];
            computeFn(backward_fn, device,
                    grad_out, grad_in,
                    0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
            collect_fn(si, i / num_expert, 0);
            if (i / num_expert == rank) {
                cudaEventCreate(&evt_reduce[i % num_expert]);
                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
            }
            ++si;
        }
    }
    pop_fn(0);

    // C_0 ... C_n
    for (long step = 0; step < n_groups; ++step) {
        FMOE_SWE(torch_stream, input_ready[step]);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
            long micro_batch_size = global_ptr[ei * world_size +
                (from_base + pipeline_gran)] - offset;
            computeFn(backward_fn, device,
                    global_grad_out, global_grad_in,
                    (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
        }
        cudaEventRecord(output_ready[step], torch_stream);
    }

    // Collect gradients for shadowed experts
    for (long i = 0, si = 0; i < n_slots; ++i) {
        if (stored_models[i]) {
            if (i / num_expert == rank) {
                FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
                set_grad_fn(si, i % num_expert);
            }
            ++si;
        }
    }

    // R_0 ... R_n
    for (long step = 0; step < n_groups; ++step) {
        FMOE_SWE(smgr->stream(0), output_ready[step]);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
                exchangeWith(global_grad_in + global_ptr[gidx_send] * d_model,
                        global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                        grad_in + local_ptr[idx_recv] * d_model,
                        local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }

    smgr->sync(1);
    checkCudaErrors(cudaGetLastError());

    for (long i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
    }
    // Mirror the creation condition used in the shadow loop above.
    for (long ei = 0; ei < num_expert; ++ei) {
        if (stored_models[ei + rank * num_expert]) {
            cudaEventDestroy(evt_reduce[ei]);
        }
    }
}

#endif  // SMART_SCHEDULE_H