smart_schedule.h 14.1 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#ifndef SMART_SCHEDULE_H
#define SMART_SCHEDULE_H

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <nccl.h>

#include "../stream_manager.h"


// Enqueues a two-peer point-to-point exchange on `stream`.
//
// Sends `sendcount` rows of `d_model` elements from `sendbuf` to rank `t_send`.
// Receives `recvcount` rows of `d_model` elements into `recvbuf` from rank
// `t_recv`. The data is transferred as raw bytes (ncclChar), so the element
// type only affects the byte count. A zero count skips that direction entirely.
//
// Callers in this file issue these calls between ncclGroupStart() and
// ncclGroupEnd(), so matching send/recv pairs on different ranks are
// launched together. NCCL failures are reported through NCCL_SAFE_CALL,
// like every other NCCL call in this file, instead of being dropped.
template<typename scalar_t>
void _exchange_with(
        const scalar_t* sendbuf, size_t sendcount, int t_send,
        scalar_t* recvbuf, size_t recvcount, int t_recv,
        long d_model,
        cudaStream_t stream, ncclComm_t comm) {
    if (sendcount) {
        NCCL_SAFE_CALL(ncclSend(sendbuf, sendcount * d_model * sizeof(scalar_t),
                ncclChar, t_send, comm, stream));
    }
    if (recvcount) {
        NCCL_SAFE_CALL(ncclRecv(recvbuf, recvcount * d_model * sizeof(scalar_t),
                ncclChar, t_recv, comm, stream));
    }
}


// GEN_BASE(_step): first ranks of the two peer groups for pipeline step _step.
//   to_base   = first rank of group (group_rank + _step) mod n_groups
//   from_base = first rank of group (group_rank - _step) mod n_groups
// Requires `group_rank`, `n_groups` and `pipeline_gran` in scope. It declares
// `to_base` and `from_base` in the enclosing scope, which is why it is not
// wrapped in do { } while (0). The argument is parenthesized so that
// expressions such as `step + 1` expand correctly.
#define GEN_BASE(_step) \
    long to_base = (group_rank + (_step)) % n_groups * pipeline_gran; \
    long from_base = (group_rank + n_groups - (_step)) % n_groups * pipeline_gran;
// GEN_IDX: flat table indices for expert `ei` and the two peer ranks.
//   idx_*  = worker * num_expert + ei  (worker-major: the layout indexed by
//            local_expert_count / global_expert_count / stored_models / local_ptr)
//   gidx_* = ei * world_size + worker  (expert-major: the layout of global_ptr)
//   idx_self uses this rank as the worker.
// Requires `ei`, `rank_send`, `rank_recv`, `rank`, `num_expert` and `world_size`
// in scope.
#define GEN_IDX \
    int idx_send = ei + rank_send * num_expert; \
    int idx_recv = ei + rank_recv * num_expert; \
    int gidx_send = ei * world_size + rank_send; \
    int gidx_recv = ei * world_size + rank_recv; \
    int idx_self = ei +      rank * num_expert;

Rick Ho's avatar
Rick Ho committed
42

Rick Ho's avatar
Rick Ho committed
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// Builds host-side exclusive prefix-sum offset tables for the expert exchange.
// Each output array must hold num_expert * world_size + 1 entries.
//
// Slot i = worker * num_expert + expert (worker-major) indexes the input
// count tables and stored_models.
//   local_ptr        : prefix sum of local_expert_count in worker-major order,
//                      i.e. row offsets of this rank's tokens per slot.
//   local_global_ptr : prefix sum of local_expert_count over only the slots
//                      whose model is in stored_models (shadowed experts).
//   global_ptr       : prefix sum in expert-major order
//                      (expert * world_size + worker) of global_expert_count.
//                      A slot contributes 0 when this rank's own copy of that
//                      expert is in stored_models, because those tokens are
//                      not received.
//
// `inline` because this non-template function is defined in a header: without
// it, including the header from more than one translation unit produces a
// multiple-definition link error.
inline void _compute_ptrs(long num_expert, long rank, long world_size,
        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        int *local_ptr,
        int *global_ptr,
        int *local_global_ptr) {
    const long n_slots = num_expert * world_size;
    local_ptr[0] = global_ptr[0] = local_global_ptr[0] = 0;

    for (long i = 0; i < n_slots; ++i) {
        local_ptr[i + 1] = local_ptr[i] + local_expert_count[i];

        // Only tokens for fetched (shadowed) models are counted here.
        local_global_ptr[i + 1] = local_global_ptr[i]
            + (stored_models[i] ? local_expert_count[i] : 0);

        const long expert_idx = i % num_expert;
        const long worker_idx = i / num_expert;
        const long gp_idx = expert_idx * world_size + worker_idx;
        // Global tokens are received only if this rank's copy of the expert
        // was not shadowed. The counts are stored shifted by one and summed below.
        global_ptr[gp_idx + 1] = stored_models[rank * num_expert + expert_idx]
            ? 0 : global_expert_count[i];
    }
    for (long i = 0; i < n_slots; ++i) {
        global_ptr[i + 1] += global_ptr[i];
    }
}

Rick Ho's avatar
Rick Ho committed
77

Rick Ho's avatar
Rick Ho committed
78
template<typename scalar_t>
Rick Ho's avatar
Rick Ho committed
79
void _compute_fn(py::function fn, c10::Device device,
Rick Ho's avatar
Rick Ho committed
80
        scalar_t* inp_buf, scalar_t* out_buf,
Rick Ho's avatar
Rick Ho committed
81
        long idx, long offset, long micro_batch_size, long d_model,
82
        CudaStreamManager* smgr) {
Rick Ho's avatar
Rick Ho committed
83
84
85
86
87
88
89
90
    auto options = torch::TensorOptions()
        .dtype(c10::CppTypeToScalarType<scalar_t>::value)
        .device(device)
        .requires_grad(true);
    auto inp = torch::from_blob(inp_buf + offset * d_model,
            {micro_batch_size, d_model}, options);
    auto oup = torch::from_blob(out_buf + offset * d_model,
            {micro_batch_size, d_model}, options);
91
    smgr->use_default = true;
Rick Ho's avatar
Rick Ho committed
92
    fn(inp, oup, idx);
93
    smgr->use_default = false;
Rick Ho's avatar
Rick Ho committed
94
95
96
97
98
99
}


template<typename scalar_t>
void fmoe_cuda_fused_forward_impl(
        py::function forward_fn,
Rick Ho's avatar
Rick Ho committed
100
101
        py::function stash_fn,
        py::function pop_fn,
Rick Ho's avatar
Rick Ho committed
102
        c10::Device device,
Rick Ho's avatar
Rick Ho committed
103
        std::vector<torch::Tensor> params,
Rick Ho's avatar
Rick Ho committed
104

Rick Ho's avatar
Rick Ho committed
105
        scalar_t* input_buf,
Rick Ho's avatar
Rick Ho committed
106
107
108
109
110
111
112
113
114
        scalar_t* global_input_buf,
        scalar_t* global_output_buf,
        scalar_t* output_buf,

        const long* local_expert_count, 
        const long* global_expert_count, 
        const bool* stored_models,

        long d_model,
Rick Ho's avatar
Rick Ho committed
115
        long num_expert, long rank, long world_size, long expert_size,
Rick Ho's avatar
Rick Ho committed
116
        long pipeline_gran, CudaStreamManager* smgr) {
Rick Ho's avatar
Rick Ho committed
117
    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
Rick Ho's avatar
Rick Ho committed
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138

    int *local_ptr = new int[num_expert * world_size + 1];
    int *global_ptr = new int[num_expert * world_size + 1];
    int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
    _compute_ptrs(num_expert, rank, world_size,
            local_expert_count, global_expert_count, stored_models,
            local_ptr, global_ptr, local_global_ptr);

    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
    long n_groups = world_size / pipeline_gran;
    long group_rank = rank / pipeline_gran;

    cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
    cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
    for (long i = 0; i < n_groups; ++i) {
        cudaEventCreate(input_ready + i);
        cudaEventCreate(output_ready + i);
    }

Rick Ho's avatar
Rick Ho committed
139
    // S_0 ... S_n
Rick Ho's avatar
Rick Ho committed
140
    for (long step = 0; step < n_groups; ++step) {
Rick Ho's avatar
Rick Ho committed
141
        for (long ei = 0; ei < num_expert; ++ei) {
Rick Ho's avatar
Rick Ho committed
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
                _exchange_with(input_buf + local_ptr[idx_send] * d_model,
                        local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                        global_input_buf + global_ptr[gidx_recv] * d_model,
                        global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
        cudaEventRecord(input_ready[step], smgr->stream(0));
    }

Rick Ho's avatar
Rick Ho committed
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
    // Broadcast shadowed experts
    cudaEvent_t evt_get, *evt_shadow;
    if (params.size() > 0) {
        evt_shadow = new cudaEvent_t[params.size()];
    }
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
            if (i / num_expert == rank) {
                cudaEventCreate(&evt_get);
                cudaEventRecord(evt_get, torch_stream);
                cudaStreamWaitEvent(smgr->stream(1), evt_get);
            }
            NCCL_SAFE_CALL(ncclBcast(params[si].data_ptr<void>(), 
                        expert_size * sizeof(scalar_t), ncclChar,
                        i / num_expert, smgr->ncclcomm, smgr->stream(0)));
            cudaEventCreate(evt_shadow + si);
            cudaEventRecord(evt_shadow[si], smgr->stream(0));
            ++si;
        }
    }

    // C_0 ... C_n
Rick Ho's avatar
Rick Ho committed
181
    for (long step = 0; step < n_groups; ++step) {
Rick Ho's avatar
Rick Ho committed
182
        cudaStreamWaitEvent(torch_stream, input_ready[step], 0);
Rick Ho's avatar
Rick Ho committed
183
184
185
186
187
188
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
            long micro_batch_size = global_ptr[ei * world_size + 
                (from_base + pipeline_gran)] - offset;
            
Rick Ho's avatar
Rick Ho committed
189
            _compute_fn(forward_fn, device,
Rick Ho's avatar
Rick Ho committed
190
                    global_input_buf, global_output_buf,
Rick Ho's avatar
Rick Ho committed
191
                    step, offset, micro_batch_size, d_model, smgr);
Rick Ho's avatar
Rick Ho committed
192
        }
Rick Ho's avatar
Rick Ho committed
193
        cudaEventRecord(output_ready[step], torch_stream);
Rick Ho's avatar
Rick Ho committed
194
195
    }

Rick Ho's avatar
Rick Ho committed
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
    // Compute over shadowed experts
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
            stash_fn(params[si], si);
            cudaStreamWaitEvent(torch_stream, evt_shadow[si], 0);
            long offset = local_ptr[i];
            long micro_batch_size = local_expert_count[i];
            _compute_fn(forward_fn, device,
                    input_buf, output_buf,
                    n_groups + si, offset, micro_batch_size, d_model, smgr);
            ++si;
        }
    }
    pop_fn();

    // R_0 ... R_n
Rick Ho's avatar
Rick Ho committed
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
                _exchange_with(global_output_buf + global_ptr[gidx_send] * d_model,
                        global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                        output_buf + local_ptr[idx_recv] * d_model,
                        local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }

    delete [] local_ptr;
    delete [] global_ptr;
    delete [] local_global_ptr;
    checkCudaErrors(cudaGetLastError());
    for (long i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
    }
Rick Ho's avatar
Rick Ho committed
239
240
241
    for (unsigned i = 0; i < params.size(); ++i) {
        cudaEventDestroy(evt_shadow[i]);
    }
Rick Ho's avatar
Rick Ho committed
242
243
    delete [] input_ready;
    delete [] output_ready;
Rick Ho's avatar
Rick Ho committed
244
245
246
    if (params.size()) {
        delete [] evt_shadow;
    }
Rick Ho's avatar
Rick Ho committed
247
248
249
250
251
252
}


template<typename scalar_t>
void fmoe_cuda_fused_backward_impl(
        py::function backward_fn,
Rick Ho's avatar
Rick Ho committed
253
        c10::Device device,
Rick Ho's avatar
Rick Ho committed
254

Rick Ho's avatar
Rick Ho committed
255
        scalar_t* grad_out,
Rick Ho's avatar
Rick Ho committed
256
257
258
259
260
261
262
        scalar_t* global_grad_out,
        scalar_t* global_grad_in,
        scalar_t* grad_in,

        const long* local_expert_count, 
        const long* global_expert_count, 
        const bool* stored_models,
Rick Ho's avatar
Rick Ho committed
263
        long d_model,
Rick Ho's avatar
Rick Ho committed
264
265
        long num_expert, long rank, long world_size,
        long pipeline_gran, CudaStreamManager* smgr) {
Rick Ho's avatar
Rick Ho committed
266
    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
Rick Ho's avatar
Rick Ho committed
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315

    int *local_ptr = new int[num_expert * world_size + 1];
    int *global_ptr = new int[num_expert * world_size + 1];
    int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker

    _compute_ptrs(num_expert, rank, world_size,
            local_expert_count, global_expert_count, stored_models,
            local_ptr, global_ptr, local_global_ptr);
   
    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
    long n_groups = world_size / pipeline_gran;
    long group_rank = rank / pipeline_gran;

    cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
    cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
    for (long i = 0; i < n_groups; ++i) {
        cudaEventCreate(input_ready + i);
        cudaEventCreate(output_ready + i);
    }

    for (long step = 0; step < n_groups; ++step) {
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
                _exchange_with(grad_out + local_ptr[idx_send] * d_model,
                        local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                        global_grad_out + global_ptr[gidx_recv] * d_model,
                        global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
        cudaEventRecord(input_ready[step], smgr->stream(0));
    }

    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
            long micro_batch_size = global_ptr[ei * world_size + 
                (from_base + pipeline_gran)] - offset;

Rick Ho's avatar
Rick Ho committed
316
317
            _compute_fn(backward_fn, device,
                    global_grad_out, global_grad_in,
Rick Ho's avatar
Rick Ho committed
318
                    step, offset, micro_batch_size, d_model, smgr);
Rick Ho's avatar
Rick Ho committed
319
        }
Rick Ho's avatar
Rick Ho committed
320
        cudaEventRecord(output_ready[step], torch_stream);
Rick Ho's avatar
Rick Ho committed
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
    }

    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
                _exchange_with(global_grad_in + global_ptr[gidx_send] * d_model,
                        global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                        grad_in + local_ptr[idx_recv] * d_model,
                        local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }

    checkCudaErrors(cudaGetLastError());

    /* TODO: Shadowing support
    int offset = global_ptr[world_size * num_expert];
    for (int j = 0; j < world_size; j++) {
        
        for (int i = 0; i < num_expert; i++) {
            int idx = j * num_expert + i;
            if (!stored_models[idx])
                continue;
            
            weight1 = params[j][0][0].data_ptr<scalar_t>();
            weight2 = params[j][0][last].data_ptr<scalar_t>();    
            grad_weight1 = params[j][0][0].mutable_grad().data_ptr<scalar_t>();
            grad_weight2 = params[j][0][last].mutable_grad().data_ptr<scalar_t>();
            
            auto stream = 2 + (idx % (SMGR_N_STREAMS- 2));

            _compute_mlp_backward(
                original_input_buf + local_ptr[idx] * d_model, weight1, weight2,
                middle_buf + (offset + local_global_ptr[idx]) * d_hidden, output_buf, grad_out + local_ptr[idx] * d_model,
                grad_middle + (offset + local_global_ptr[idx]) * d_hidden, grad_weight1, grad_weight2, grad_in + local_ptr[idx] * d_model,
                i,
                0, local_expert_count[idx],
                d_model, d_hidden, 0, // we never consider it to be the first since it's already initialized to zero and we are lazy
                smgr->stream(stream), smgr->handle(stream));

        }
    }
    */


    delete [] local_ptr;
    delete [] global_ptr;
    delete [] local_global_ptr;
    checkCudaErrors(cudaGetLastError());
    for (long i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
    }
    delete [] input_ready;
    delete [] output_ready;
}

#endif  // SMART_SCHEDULE_H