// smart_schedule.h
#ifndef SMART_SCHEDULE_H
#define SMART_SCHEDULE_H

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <nccl.h>

#include "../stream_manager.h"


template<typename scalar_t>
Rick Ho's avatar
Rick Ho committed
16
void exchangeWith(
Rick Ho's avatar
Rick Ho committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
        const scalar_t* sendbuf, size_t sendcount, int t_send,
        scalar_t* recvbuf, size_t recvcount, int t_recv,
        long d_model,
        cudaStream_t stream, ncclComm_t comm) {
    if (sendcount) {
        ncclSend(sendbuf, sendcount * d_model * sizeof(scalar_t),
                ncclChar, t_send , comm, stream);
    }
    if (recvcount) {
        ncclRecv(recvbuf, recvcount * d_model * sizeof(scalar_t),
                ncclChar, t_recv, comm, stream);
    }
}


// GEN_BASE(step): declares two locals in the enclosing scope.
//   to_base   - first rank of the group that lies `step` groups after this
//               rank's group (wrapping around n_groups)
//   from_base - first rank of the group that lies `step` groups before it
// Expects `group_rank`, `n_groups` and `pipeline_gran` to be in scope at the
// expansion site. The argument is parenthesized so that an expression such as
// `s + 1` is subtracted as a whole in the from_base computation.
#define GEN_BASE(_step) \
    long to_base = (group_rank + (_step)) % n_groups * pipeline_gran; \
    long from_base = (group_rank + n_groups - (_step)) % n_groups * pipeline_gran;
// GEN_IDX: declares five index locals in the enclosing scope. Expects `ei`,
// `rank_send`, `rank_recv`, `rank`, `num_expert` and `world_size` to be in
// scope at the expansion site.
//   idx_send / idx_recv / idx_self - rank-major index (rank * num_expert + ei)
//       for the send peer, the receive peer and this rank; used in this file
//       to index the expert-count arrays, stored_models and local_ptr.
//   gidx_send / gidx_recv - expert-major index (ei * world_size + rank) for
//       the send / receive peer; used in this file to index global_ptr.
#define GEN_IDX \
    int idx_send = ei + rank_send * num_expert; \
    int idx_recv = ei + rank_recv * num_expert; \
    int gidx_send = ei * world_size + rank_send; \
    int gidx_recv = ei * world_size + rank_recv; \
    int idx_self = ei +      rank * num_expert;

/**
 * Build the row-offset tables used to slice the token buffers.
 *
 * All three outputs must have room for num_expert * world_size + 1 ints and
 * are filled as exclusive prefix sums (element 0 is 0).
 *
 * Input arrays are indexed rank-major: i = worker * num_expert + expert.
 *
 * @param local_ptr         prefix sum of local_expert_count, rank-major order
 * @param local_global_ptr  like local_ptr, but only entries with
 *                          stored_models[i] set contribute their count
 * @param global_ptr        prefix sum of global_expert_count re-ordered to
 *                          expert-major (expert * world_size + worker); an
 *                          entry contributes 0 when this rank's flag
 *                          stored_models[rank * num_expert + expert] is set
 *
 * Declared inline because it is a non-template function defined in a header.
 */
inline void computePtrs(long num_expert, long rank, long world_size,
        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        int *local_ptr,
        int *global_ptr,
        int *local_global_ptr) {
    const long n_slots = num_expert * world_size;
    local_ptr[0] = global_ptr[0] = local_global_ptr[0] = 0;

    for (long i = 0; i < n_slots; ++i) {
        local_ptr[i + 1] = local_ptr[i] + local_expert_count[i];

        // Only tokens routed to a stored (fetched) model are accumulated here.
        local_global_ptr[i + 1] = local_global_ptr[i];
        if (stored_models[i]) {
            local_global_ptr[i + 1] += local_expert_count[i];
        }

        const long expert_idx = i % num_expert;
        const long worker_idx = i / num_expert;
        const long gp_idx = expert_idx * world_size + worker_idx;
        // First pass stores the raw per-slot count in expert-major position;
        // nothing is received for an expert whose flag is set for this rank.
        if (stored_models[rank * num_expert + expert_idx]) {
            global_ptr[gp_idx + 1] = 0;
        } else {
            global_ptr[gp_idx + 1] = global_expert_count[i];
        }
    }

    // Second pass turns the per-slot counts into offsets.
    global_ptr[0] = 0;
    for (long i = 0; i < n_slots; ++i) {
        global_ptr[i + 1] += global_ptr[i];
    }
}

template<typename scalar_t>
Rick Ho's avatar
Rick Ho committed
79
void computeFn(py::function fn, c10::Device device,
Rick Ho's avatar
Rick Ho committed
80
        scalar_t* inp_buf, scalar_t* out_buf,
Rick Ho's avatar
Rick Ho committed
81
        long idx, long offset, long micro_batch_size, long d_model,
82
        CudaStreamManager* smgr) {
83
84
85
    if(micro_batch_size == 0) {
        return;
    }
Rick Ho's avatar
Rick Ho committed
86
87
88
89
90
91
92
93
    auto options = torch::TensorOptions()
        .dtype(c10::CppTypeToScalarType<scalar_t>::value)
        .device(device)
        .requires_grad(true);
    auto inp = torch::from_blob(inp_buf + offset * d_model,
            {micro_batch_size, d_model}, options);
    auto oup = torch::from_blob(out_buf + offset * d_model,
            {micro_batch_size, d_model}, options);
94
    smgr->use_default = true;
Rick Ho's avatar
Rick Ho committed
95
    fn(inp, oup, idx);
96
    smgr->use_default = false;
Rick Ho's avatar
Rick Ho committed
97
98
99
100
101
102
}


template<typename scalar_t>
void fmoe_cuda_fused_forward_impl(
        py::function forward_fn,
Rick Ho's avatar
Rick Ho committed
103
104
        py::function stash_fn,
        py::function pop_fn,
Rick Ho's avatar
Rick Ho committed
105
        c10::Device device,
Rick Ho's avatar
Rick Ho committed
106
        std::vector<torch::Tensor> params,
Rick Ho's avatar
Rick Ho committed
107

Rick Ho's avatar
Rick Ho committed
108
        scalar_t* input_buf,
Rick Ho's avatar
Rick Ho committed
109
110
111
112
        scalar_t* global_input_buf,
        scalar_t* global_output_buf,
        scalar_t* output_buf,

Rick Ho's avatar
Rick Ho committed
113
114
        const long* local_expert_count,
        const long* global_expert_count,
Rick Ho's avatar
Rick Ho committed
115
116
117
        const bool* stored_models,

        long d_model,
Rick Ho's avatar
Rick Ho committed
118
        long num_expert, long rank, long world_size, long expert_size,
Rick Ho's avatar
Rick Ho committed
119
        long pipeline_gran, CudaStreamManager* smgr) {
Rick Ho's avatar
Rick Ho committed
120
    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
Rick Ho's avatar
Rick Ho committed
121
122
123
124

    int *local_ptr = new int[num_expert * world_size + 1];
    int *global_ptr = new int[num_expert * world_size + 1];
    int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
Rick Ho's avatar
Rick Ho committed
125
    computePtrs(num_expert, rank, world_size,
Rick Ho's avatar
Rick Ho committed
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
            local_expert_count, global_expert_count, stored_models,
            local_ptr, global_ptr, local_global_ptr);

    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
    long n_groups = world_size / pipeline_gran;
    long group_rank = rank / pipeline_gran;

    cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
    cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
    for (long i = 0; i < n_groups; ++i) {
        cudaEventCreate(input_ready + i);
        cudaEventCreate(output_ready + i);
    }

Rick Ho's avatar
Rick Ho committed
142
    // S_0 ... S_n
Rick Ho's avatar
Rick Ho committed
143
    for (long step = 0; step < n_groups; ++step) {
Rick Ho's avatar
Rick Ho committed
144
        for (long ei = 0; ei < num_expert; ++ei) {
Rick Ho's avatar
Rick Ho committed
145
146
147
148
149
150
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
Rick Ho's avatar
Rick Ho committed
151
                exchangeWith(input_buf + local_ptr[idx_send] * d_model,
Rick Ho's avatar
Rick Ho committed
152
153
154
155
156
157
158
159
160
161
                        local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                        global_input_buf + global_ptr[gidx_recv] * d_model,
                        global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
        cudaEventRecord(input_ready[step], smgr->stream(0));
    }

Rick Ho's avatar
Rick Ho committed
162
163
164
165
166
167
168
169
170
171
172
    // Broadcast shadowed experts
    cudaEvent_t evt_get, *evt_shadow;
    if (params.size() > 0) {
        evt_shadow = new cudaEvent_t[params.size()];
    }
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
            if (i / num_expert == rank) {
                cudaEventCreate(&evt_get);
                cudaEventRecord(evt_get, torch_stream);
                cudaStreamWaitEvent(smgr->stream(1), evt_get);
Rick Ho's avatar
Rick Ho committed
173
                cudaEventDestroy(evt_get);
Rick Ho's avatar
Rick Ho committed
174
            }
Rick Ho's avatar
Rick Ho committed
175
            NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
Rick Ho's avatar
Rick Ho committed
176
177
178
179
180
181
182
183
184
                        expert_size * sizeof(scalar_t), ncclChar,
                        i / num_expert, smgr->ncclcomm, smgr->stream(0)));
            cudaEventCreate(evt_shadow + si);
            cudaEventRecord(evt_shadow[si], smgr->stream(0));
            ++si;
        }
    }

    // C_0 ... C_n
Rick Ho's avatar
Rick Ho committed
185
    for (long step = 0; step < n_groups; ++step) {
Rick Ho's avatar
Rick Ho committed
186
        cudaStreamWaitEvent(torch_stream, input_ready[step], 0);
Rick Ho's avatar
Rick Ho committed
187
188
189
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
Rick Ho's avatar
Rick Ho committed
190
            long micro_batch_size = global_ptr[ei * world_size +
Rick Ho's avatar
Rick Ho committed
191
                (from_base + pipeline_gran)] - offset;
Rick Ho's avatar
Rick Ho committed
192
            computeFn(forward_fn, device,
Rick Ho's avatar
Rick Ho committed
193
                    global_input_buf, global_output_buf,
Rick Ho's avatar
Rick Ho committed
194
                    step, offset, micro_batch_size, d_model, smgr);
Rick Ho's avatar
Rick Ho committed
195
        }
Rick Ho's avatar
Rick Ho committed
196
        cudaEventRecord(output_ready[step], torch_stream);
Rick Ho's avatar
Rick Ho committed
197
198
    }

Rick Ho's avatar
Rick Ho committed
199
200
201
202
203
204
205
    // Compute over shadowed experts
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
            stash_fn(params[si], si);
            cudaStreamWaitEvent(torch_stream, evt_shadow[si], 0);
            long offset = local_ptr[i];
            long micro_batch_size = local_expert_count[i];
Rick Ho's avatar
Rick Ho committed
206
            computeFn(forward_fn, device,
Rick Ho's avatar
Rick Ho committed
207
208
209
210
211
212
213
214
                    input_buf, output_buf,
                    n_groups + si, offset, micro_batch_size, d_model, smgr);
            ++si;
        }
    }
    pop_fn();

    // R_0 ... R_n
Rick Ho's avatar
Rick Ho committed
215
216
217
218
219
220
221
222
223
    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
Rick Ho's avatar
Rick Ho committed
224
                exchangeWith(global_output_buf + global_ptr[gidx_send] * d_model,
Rick Ho's avatar
Rick Ho committed
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
                        global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                        output_buf + local_ptr[idx_recv] * d_model,
                        local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }

    delete [] local_ptr;
    delete [] global_ptr;
    delete [] local_global_ptr;
    checkCudaErrors(cudaGetLastError());
    for (long i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
    }
Rick Ho's avatar
Rick Ho committed
242
243
244
    for (unsigned i = 0; i < params.size(); ++i) {
        cudaEventDestroy(evt_shadow[i]);
    }
Rick Ho's avatar
Rick Ho committed
245
246
247
248
249
250
251
252
    delete [] input_ready;
    delete [] output_ready;
}


template<typename scalar_t>
void fmoe_cuda_fused_backward_impl(
        py::function backward_fn,
Rick Ho's avatar
Rick Ho committed
253
254
255
256
        py::function stash_fn,
        py::function pop_fn,
        py::function collect_fn,
        py::function set_grad_fn,
Rick Ho's avatar
Rick Ho committed
257
        c10::Device device,
Rick Ho's avatar
Rick Ho committed
258

Rick Ho's avatar
Rick Ho committed
259
        scalar_t* grad_out,
Rick Ho's avatar
Rick Ho committed
260
261
262
263
        scalar_t* global_grad_out,
        scalar_t* global_grad_in,
        scalar_t* grad_in,

Rick Ho's avatar
Rick Ho committed
264
265
        const long* local_expert_count,
        const long* global_expert_count,
Rick Ho's avatar
Rick Ho committed
266
        const bool* stored_models,
Rick Ho's avatar
Rick Ho committed
267
        long d_model,
Rick Ho's avatar
Rick Ho committed
268
269
        long num_expert, long rank, long world_size,
        long pipeline_gran, CudaStreamManager* smgr) {
Rick Ho's avatar
Rick Ho committed
270
    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
Rick Ho's avatar
Rick Ho committed
271
272
273
274
275

    int *local_ptr = new int[num_expert * world_size + 1];
    int *global_ptr = new int[num_expert * world_size + 1];
    int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker

Rick Ho's avatar
Rick Ho committed
276
    computePtrs(num_expert, rank, world_size,
Rick Ho's avatar
Rick Ho committed
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
            local_expert_count, global_expert_count, stored_models,
            local_ptr, global_ptr, local_global_ptr);
    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
    long n_groups = world_size / pipeline_gran;
    long group_rank = rank / pipeline_gran;

    cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
    cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
    for (long i = 0; i < n_groups; ++i) {
        cudaEventCreate(input_ready + i);
        cudaEventCreate(output_ready + i);
    }

Rick Ho's avatar
Rick Ho committed
292
    // S_0 ... S_n
Rick Ho's avatar
Rick Ho committed
293
294
295
296
297
298
299
300
    for (long step = 0; step < n_groups; ++step) {
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
Rick Ho's avatar
Rick Ho committed
301
                exchangeWith(grad_out + local_ptr[idx_send] * d_model,
Rick Ho's avatar
Rick Ho committed
302
303
304
305
306
307
308
309
310
311
                        local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                        global_grad_out + global_ptr[gidx_recv] * d_model,
                        global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
        cudaEventRecord(input_ready[step], smgr->stream(0));
    }

Rick Ho's avatar
Rick Ho committed
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
    // Shadowed experts backward and reduce
    cudaEvent_t *evt_reduce = new cudaEvent_t[num_expert];
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
            stash_fn(si);
            long offset = local_ptr[i];
            long micro_batch_size = local_expert_count[i];
            computeFn(backward_fn, device,
                    grad_out, grad_in,
                    n_groups + si, offset, micro_batch_size, d_model, smgr);
            collect_fn(si, i / num_expert);
            if (i / num_expert == rank) {
                cudaEventCreate(evt_reduce + i % num_expert);
                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
            }
            ++si;
        }
    }
    pop_fn();

    // C_0 ... C_n
Rick Ho's avatar
Rick Ho committed
333
334
335
336
337
    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
Rick Ho's avatar
Rick Ho committed
338
            long micro_batch_size = global_ptr[ei * world_size +
Rick Ho's avatar
Rick Ho committed
339
340
                (from_base + pipeline_gran)] - offset;

Rick Ho's avatar
Rick Ho committed
341
            computeFn(backward_fn, device,
Rick Ho's avatar
Rick Ho committed
342
                    global_grad_out, global_grad_in,
Rick Ho's avatar
Rick Ho committed
343
                    step, offset, micro_batch_size, d_model, smgr);
Rick Ho's avatar
Rick Ho committed
344
        }
Rick Ho's avatar
Rick Ho committed
345
        cudaEventRecord(output_ready[step], torch_stream);
Rick Ho's avatar
Rick Ho committed
346
347
    }

Rick Ho's avatar
Rick Ho committed
348
349
350
351
352
353
354
355
356
357
358
359
    // Collect gradients for shadowed experts
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
            if (i / num_expert == rank) {
                cudaStreamWaitEvent(torch_stream, evt_reduce[i % num_expert], 0);
                set_grad_fn(si);
            }
            ++si;
        }
    }

    // R_0 ... R_n
Rick Ho's avatar
Rick Ho committed
360
361
362
363
364
365
366
367
368
    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
Rick Ho's avatar
Rick Ho committed
369
                exchangeWith(global_grad_in + global_ptr[gidx_send] * d_model,
Rick Ho's avatar
Rick Ho committed
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
                        global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                        grad_in + local_ptr[idx_recv] * d_model,
                        local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }

    checkCudaErrors(cudaGetLastError());

    delete [] local_ptr;
    delete [] global_ptr;
    delete [] local_global_ptr;
    checkCudaErrors(cudaGetLastError());
    for (long i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
    }
    delete [] input_ready;
    delete [] output_ready;
Rick Ho's avatar
Rick Ho committed
391
392
393
394
395
396
    for (long i = 0; i < num_expert; ++i) {
        if (stored_models[i + rank * num_expert]) {
            cudaEventDestroy(evt_reduce[i]);
        }
    }
    delete [] evt_reduce;
Rick Ho's avatar
Rick Ho committed
397
398
399
}

#endif  // SMART_SCHEDULE_H