smart_schedule.h 15.6 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
#ifndef SMART_SCHEDULE_H
#define SMART_SCHEDULE_H

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <nccl.h>

#include "../stream_manager.h"

Rick Ho's avatar
Rick Ho committed
14
15
16
17
18
// FMOE_SWE(stream, event): make `stream` wait on `event` via cudaStreamWaitEvent.
// When CUDA_VERSION is defined and below 110010 the three-argument form is
// used with an explicit flags value of 0; otherwise the two-argument form.
#if !defined(CUDA_VERSION) || (CUDA_VERSION >= 110010)
#define FMOE_SWE(stream_, event_) cudaStreamWaitEvent(stream_, event_)
#else
#define FMOE_SWE(stream_, event_) cudaStreamWaitEvent(stream_, event_, 0)
#endif
Rick Ho's avatar
Rick Ho committed
19
20

template<typename scalar_t>
Rick Ho's avatar
Rick Ho committed
21
void exchangeWith(
Rick Ho's avatar
Rick Ho committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
        const scalar_t* sendbuf, size_t sendcount, int t_send,
        scalar_t* recvbuf, size_t recvcount, int t_recv,
        long d_model,
        cudaStream_t stream, ncclComm_t comm) {
    if (sendcount) {
        ncclSend(sendbuf, sendcount * d_model * sizeof(scalar_t),
                ncclChar, t_send , comm, stream);
    }
    if (recvcount) {
        ncclRecv(recvbuf, recvcount * d_model * sizeof(scalar_t),
                ncclChar, t_recv, comm, stream);
    }
}


// GEN_BASE(step): declares, in the enclosing scope,
//   to_base   - first rank of the group `step` groups after this rank's group
//   from_base - first rank of the group `step` groups before it
// (both wrap around modulo n_groups). Requires `group_rank`, `n_groups` and
// `pipeline_gran` to be in scope. The argument is parenthesized so that an
// expression such as `s + 1` expands correctly inside `n_groups - (_step)`.
#define GEN_BASE(_step) \
    long to_base = (group_rank + (_step)) % n_groups * pipeline_gran; \
    long from_base = (group_rank + n_groups - (_step)) % n_groups * pipeline_gran;
// GEN_IDX: declares flat indices for expert `ei` in the enclosing scope.
//   idx_*  index arrays laid out [worker][expert] (ei + worker * num_expert)
//   gidx_* index arrays laid out [expert][worker] (ei * world_size + worker)
// Requires `ei`, `rank_send`, `rank_recv`, `rank`, `num_expert` and
// `world_size` to be in scope.
#define GEN_IDX \
    int idx_send = ei + rank_send * num_expert; \
    int idx_recv = ei + rank_recv * num_expert; \
    int gidx_send = ei * world_size + rank_send; \
    int gidx_recv = ei * world_size + rank_recv; \
    int idx_self = ei +      rank * num_expert;

Rick Ho's avatar
Rick Ho committed
47

Rick Ho's avatar
Rick Ho committed
48
49
50
/*
 * Build host-side prefix-sum offset tables for the fused MoE schedule.
 *
 * Inputs: local_expert_count, global_expert_count and stored_models each hold
 * num_expert * world_size entries indexed i = worker * num_expert + expert.
 * Outputs: local_ptr, global_ptr and local_global_ptr must each have room for
 * num_expert * world_size + 1 ints; entry 0 of each is set to 0.
 *
 *   local_ptr[i+1]        = local_ptr[i] + local_expert_count[i]
 *   local_global_ptr[i+1] = local_global_ptr[i]
 *                           + (stored_models[i] ? local_expert_count[i] : 0)
 *   global_ptr            = running sum, in [expert][worker] order
 *                           (k = expert * world_size + worker), of
 *                           global_expert_count, where every entry of expert e
 *                           counts as 0 when stored_models[rank*num_expert + e]
 *                           is set.
 *
 * Declared inline: this is a non-template function defined in a header, so
 * without `inline` every translation unit including the header would emit an
 * external definition and linking two of them would fail.
 *
 * NOTE(review): offsets are stored as int while counts are long; totals above
 * INT_MAX would be truncated.
 */
inline void computePtrs(long num_expert, long rank, long world_size,
        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        int *local_ptr,
        int *global_ptr,
        int *local_global_ptr) {
    const long n_slots = num_expert * world_size;
    local_ptr[0] = global_ptr[0] = local_global_ptr[0] = 0;

    for (long i = 0; i < n_slots; ++i) {
        const long count = local_expert_count[i];
        local_ptr[i + 1] = local_ptr[i] + count;
        // if model fetched, add local tokens
        local_global_ptr[i + 1] = local_global_ptr[i] + (stored_models[i] ? count : 0);

        const long expert_idx = i % num_expert;
        const long worker_idx = i / num_expert;
        // if local model wasn't fetched, receive global tokens
        global_ptr[expert_idx * world_size + worker_idx + 1] =
            stored_models[rank * num_expert + expert_idx] ? 0 : global_expert_count[i];
    }
    // Turn per-slot counts into running offsets.
    for (long i = 0; i < n_slots; ++i) {
        global_ptr[i + 1] += global_ptr[i];
    }
}

Rick Ho's avatar
Rick Ho committed
82

Rick Ho's avatar
Rick Ho committed
83
template<typename scalar_t>
Rick Ho's avatar
Rick Ho committed
84
void computeFn(py::function fn, c10::Device device,
Rick Ho's avatar
Rick Ho committed
85
        scalar_t* inp_buf, scalar_t* out_buf,
86
        long expert_idx, long store_idx, long offset, long micro_batch_size, long d_model,
87
        CudaStreamManager* smgr) {
88
89
90
    if(micro_batch_size == 0) {
        return;
    }
Rick Ho's avatar
Rick Ho committed
91
92
93
94
95
96
97
98
    auto options = torch::TensorOptions()
        .dtype(c10::CppTypeToScalarType<scalar_t>::value)
        .device(device)
        .requires_grad(true);
    auto inp = torch::from_blob(inp_buf + offset * d_model,
            {micro_batch_size, d_model}, options);
    auto oup = torch::from_blob(out_buf + offset * d_model,
            {micro_batch_size, d_model}, options);
99
    smgr->use_default = true;
100
    fn(inp, oup, expert_idx, store_idx);
101
    smgr->use_default = false;
Rick Ho's avatar
Rick Ho committed
102
103
104
105
106
107
}


template<typename scalar_t>
void fmoe_cuda_fused_forward_impl(
        py::function forward_fn,
Rick Ho's avatar
Rick Ho committed
108
109
        py::function stash_fn,
        py::function pop_fn,
Rick Ho's avatar
Rick Ho committed
110
        c10::Device device,
Rick Ho's avatar
Rick Ho committed
111
        std::vector<torch::Tensor> params,
Rick Ho's avatar
Rick Ho committed
112

Rick Ho's avatar
Rick Ho committed
113
        scalar_t* input_buf,
Rick Ho's avatar
Rick Ho committed
114
115
116
117
        scalar_t* global_input_buf,
        scalar_t* global_output_buf,
        scalar_t* output_buf,

Rick Ho's avatar
Rick Ho committed
118
119
        const long* local_expert_count,
        const long* global_expert_count,
Rick Ho's avatar
Rick Ho committed
120
121
122
        const bool* stored_models,

        long d_model,
Rick Ho's avatar
Rick Ho committed
123
        long num_expert, long rank, long world_size, long expert_size,
Rick Ho's avatar
Rick Ho committed
124
        long pipeline_gran, CudaStreamManager* smgr) {
Rick Ho's avatar
Rick Ho committed
125
    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
126
    cudaStreamSynchronize(torch_stream);
Rick Ho's avatar
Rick Ho committed
127
128
129
130

    int *local_ptr = new int[num_expert * world_size + 1];
    int *global_ptr = new int[num_expert * world_size + 1];
    int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
Rick Ho's avatar
Rick Ho committed
131
    computePtrs(num_expert, rank, world_size,
Rick Ho's avatar
Rick Ho committed
132
133
134
135
136
137
138
139
140
141
142
            local_expert_count, global_expert_count, stored_models,
            local_ptr, global_ptr, local_global_ptr);

    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
    long n_groups = world_size / pipeline_gran;
    long group_rank = rank / pipeline_gran;

    cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
    cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
zms1999's avatar
zms1999 committed
143
    cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
Rick Ho's avatar
Rick Ho committed
144
145
146
    for (long i = 0; i < n_groups; ++i) {
        cudaEventCreate(input_ready + i);
        cudaEventCreate(output_ready + i);
zms1999's avatar
zms1999 committed
147
        cudaEventCreate(output_torch_ready + i);
Rick Ho's avatar
Rick Ho committed
148
149
    }

Rick Ho's avatar
Rick Ho committed
150
    // S_0 ... S_n
Rick Ho's avatar
Rick Ho committed
151
    for (long step = 0; step < n_groups; ++step) {
Rick Ho's avatar
Rick Ho committed
152
        for (long ei = 0; ei < num_expert; ++ei) {
Rick Ho's avatar
Rick Ho committed
153
154
155
156
157
158
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
Rick Ho's avatar
Rick Ho committed
159
                exchangeWith(input_buf + local_ptr[idx_send] * d_model,
Rick Ho's avatar
Rick Ho committed
160
161
162
                        local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                        global_input_buf + global_ptr[gidx_recv] * d_model,
                        global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
163
                        d_model, smgr->stream(num_expert), smgr->ncclcomm);
Rick Ho's avatar
Rick Ho committed
164
165
166
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
167
        cudaEventRecord(input_ready[step], smgr->stream(num_expert));
Rick Ho's avatar
Rick Ho committed
168
169
    }

Rick Ho's avatar
Rick Ho committed
170
171
172
173
174
175
176
177
178
    // Broadcast shadowed experts
    cudaEvent_t evt_get, *evt_shadow;
    if (params.size() > 0) {
        evt_shadow = new cudaEvent_t[params.size()];
    }
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
            if (i / num_expert == rank) {
                cudaEventCreate(&evt_get);
179
180
                cudaEventRecord(evt_get, smgr->stream(0));
                FMOE_SWE(smgr->stream(num_expert), evt_get);
Rick Ho's avatar
Rick Ho committed
181
                cudaEventDestroy(evt_get);
Rick Ho's avatar
Rick Ho committed
182
            }
Rick Ho's avatar
Rick Ho committed
183
            NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
Rick Ho's avatar
Rick Ho committed
184
                        expert_size * sizeof(scalar_t), ncclChar,
185
                        i / num_expert, smgr->ncclcomm, smgr->stream(num_expert)));
Rick Ho's avatar
Rick Ho committed
186
            cudaEventCreate(evt_shadow + si);
187
            cudaEventRecord(evt_shadow[si], smgr->stream(num_expert));
Rick Ho's avatar
Rick Ho committed
188
189
190
191
192
            ++si;
        }
    }

    // C_0 ... C_n
Rick Ho's avatar
Rick Ho committed
193
    for (long step = 0; step < n_groups; ++step) {
194
        FMOE_SWE(smgr->stream(0), input_ready[step]);
Rick Ho's avatar
Rick Ho committed
195
        FMOE_SWE(torch_stream, input_ready[step]);
Rick Ho's avatar
Rick Ho committed
196
197
198
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
Rick Ho's avatar
Rick Ho committed
199
            long micro_batch_size = global_ptr[ei * world_size +
Rick Ho's avatar
Rick Ho committed
200
                (from_base + pipeline_gran)] - offset;
Rick Ho's avatar
Rick Ho committed
201
            computeFn(forward_fn, device,
Rick Ho's avatar
Rick Ho committed
202
                    global_input_buf, global_output_buf,
203
                    (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
Rick Ho's avatar
Rick Ho committed
204
        }
205
        cudaEventRecord(output_ready[step], smgr->stream(0));
zms1999's avatar
zms1999 committed
206
        cudaEventRecord(output_torch_ready[step], torch_stream);
Rick Ho's avatar
Rick Ho committed
207
208
    }

Rick Ho's avatar
Rick Ho committed
209
210
211
    // Compute over shadowed experts
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
212
            FMOE_SWE(smgr->stream(0), evt_shadow[si]);
Rick Ho's avatar
Rick Ho committed
213
            FMOE_SWE(torch_stream, evt_shadow[si]);
214
            stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
Rick Ho's avatar
Rick Ho committed
215
216
            long offset = local_ptr[i];
            long micro_batch_size = local_expert_count[i];
Rick Ho's avatar
Rick Ho committed
217
            computeFn(forward_fn, device,
Rick Ho's avatar
Rick Ho committed
218
                    input_buf, output_buf,
219
                    0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
Rick Ho's avatar
Rick Ho committed
220
221
222
            ++si;
        }
    }
223
    pop_fn(0);
Rick Ho's avatar
Rick Ho committed
224
225

    // R_0 ... R_n
Rick Ho's avatar
Rick Ho committed
226
    for (long step = 0; step < n_groups; ++step) {
227
        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
zms1999's avatar
zms1999 committed
228
        FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
Rick Ho's avatar
Rick Ho committed
229
230
231
232
233
234
235
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
Rick Ho's avatar
Rick Ho committed
236
                exchangeWith(global_output_buf + global_ptr[gidx_send] * d_model,
Rick Ho's avatar
Rick Ho committed
237
238
239
                        global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                        output_buf + local_ptr[idx_recv] * d_model,
                        local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
240
                        d_model, smgr->stream(num_expert), smgr->ncclcomm);
Rick Ho's avatar
Rick Ho committed
241
242
243
244
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }
245
    smgr->sync(num_expert + 1);
Rick Ho's avatar
Rick Ho committed
246
247
248
249
250
251
252
253

    delete [] local_ptr;
    delete [] global_ptr;
    delete [] local_global_ptr;
    checkCudaErrors(cudaGetLastError());
    for (long i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
zms1999's avatar
zms1999 committed
254
        cudaEventDestroy(output_torch_ready[i]);
Rick Ho's avatar
Rick Ho committed
255
    }
Rick Ho's avatar
Rick Ho committed
256
257
258
    for (unsigned i = 0; i < params.size(); ++i) {
        cudaEventDestroy(evt_shadow[i]);
    }
Rick Ho's avatar
Rick Ho committed
259
260
    delete [] input_ready;
    delete [] output_ready;
zms1999's avatar
zms1999 committed
261
    delete [] output_torch_ready;
Rick Ho's avatar
Rick Ho committed
262
263
264
265
266
267
}


template<typename scalar_t>
void fmoe_cuda_fused_backward_impl(
        py::function backward_fn,
Rick Ho's avatar
Rick Ho committed
268
269
270
271
        py::function stash_fn,
        py::function pop_fn,
        py::function collect_fn,
        py::function set_grad_fn,
Rick Ho's avatar
Rick Ho committed
272
        c10::Device device,
Rick Ho's avatar
Rick Ho committed
273

Rick Ho's avatar
Rick Ho committed
274
        scalar_t* grad_out,
Rick Ho's avatar
Rick Ho committed
275
276
277
278
        scalar_t* global_grad_out,
        scalar_t* global_grad_in,
        scalar_t* grad_in,

Rick Ho's avatar
Rick Ho committed
279
280
        const long* local_expert_count,
        const long* global_expert_count,
Rick Ho's avatar
Rick Ho committed
281
        const bool* stored_models,
Rick Ho's avatar
Rick Ho committed
282
        long d_model,
Rick Ho's avatar
Rick Ho committed
283
284
        long num_expert, long rank, long world_size,
        long pipeline_gran, CudaStreamManager* smgr) {
Rick Ho's avatar
Rick Ho committed
285
    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
286
    cudaStreamSynchronize(torch_stream);
Rick Ho's avatar
Rick Ho committed
287
288
289
290
291

    int *local_ptr = new int[num_expert * world_size + 1];
    int *global_ptr = new int[num_expert * world_size + 1];
    int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker

Rick Ho's avatar
Rick Ho committed
292
    computePtrs(num_expert, rank, world_size,
Rick Ho's avatar
Rick Ho committed
293
294
295
296
297
298
299
300
301
302
            local_expert_count, global_expert_count, stored_models,
            local_ptr, global_ptr, local_global_ptr);
    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
    long n_groups = world_size / pipeline_gran;
    long group_rank = rank / pipeline_gran;

    cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
    cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
zms1999's avatar
zms1999 committed
303
    cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
Rick Ho's avatar
Rick Ho committed
304
305
306
    for (long i = 0; i < n_groups; ++i) {
        cudaEventCreate(input_ready + i);
        cudaEventCreate(output_ready + i);
zms1999's avatar
zms1999 committed
307
        cudaEventCreate(output_torch_ready + i);
Rick Ho's avatar
Rick Ho committed
308
309
    }

Rick Ho's avatar
Rick Ho committed
310
    // S_0 ... S_n
Rick Ho's avatar
Rick Ho committed
311
312
313
314
315
316
317
318
    for (long step = 0; step < n_groups; ++step) {
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
Rick Ho's avatar
Rick Ho committed
319
                exchangeWith(grad_out + local_ptr[idx_send] * d_model,
Rick Ho's avatar
Rick Ho committed
320
321
322
                        local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                        global_grad_out + global_ptr[gidx_recv] * d_model,
                        global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
323
                        d_model, smgr->stream(num_expert), smgr->ncclcomm);
Rick Ho's avatar
Rick Ho committed
324
325
326
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
327
        cudaEventRecord(input_ready[step], smgr->stream(num_expert));
Rick Ho's avatar
Rick Ho committed
328
329
    }

Rick Ho's avatar
Rick Ho committed
330
331
332
333
    // Shadowed experts backward and reduce
    cudaEvent_t *evt_reduce = new cudaEvent_t[num_expert];
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
334
            stash_fn(si, 0);
Rick Ho's avatar
Rick Ho committed
335
336
337
338
            long offset = local_ptr[i];
            long micro_batch_size = local_expert_count[i];
            computeFn(backward_fn, device,
                    grad_out, grad_in,
339
340
                    0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
            collect_fn(si, i / num_expert, 0);
Rick Ho's avatar
Rick Ho committed
341
342
            if (i / num_expert == rank) {
                cudaEventCreate(evt_reduce + i % num_expert);
343
                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(num_expert));
Rick Ho's avatar
Rick Ho committed
344
345
346
347
            }
            ++si;
        }
    }
348
    pop_fn(0);
Rick Ho's avatar
Rick Ho committed
349
350

    // C_0 ... C_n
Rick Ho's avatar
Rick Ho committed
351
    for (long step = 0; step < n_groups; ++step) {
352
        FMOE_SWE(smgr->stream(0), input_ready[step]);
353
        FMOE_SWE(torch_stream, input_ready[step]);
Rick Ho's avatar
Rick Ho committed
354
355
356
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
Rick Ho's avatar
Rick Ho committed
357
            long micro_batch_size = global_ptr[ei * world_size +
Rick Ho's avatar
Rick Ho committed
358
359
                (from_base + pipeline_gran)] - offset;

Rick Ho's avatar
Rick Ho committed
360
            computeFn(backward_fn, device,
Rick Ho's avatar
Rick Ho committed
361
                    global_grad_out, global_grad_in,
362
                    (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
Rick Ho's avatar
Rick Ho committed
363
        }
364
        cudaEventRecord(output_ready[step], smgr->stream(0));
zms1999's avatar
zms1999 committed
365
        cudaEventRecord(output_torch_ready[step], torch_stream);
Rick Ho's avatar
Rick Ho committed
366
367
    }

Rick Ho's avatar
Rick Ho committed
368
369
370
371
    // Collect gradients for shadowed experts
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
            if (i / num_expert == rank) {
372
                FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
Rick Ho's avatar
Rick Ho committed
373
                FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
374
                set_grad_fn(si, i % num_expert);
Rick Ho's avatar
Rick Ho committed
375
376
377
378
379
380
            }
            ++si;
        }
    }

    // R_0 ... R_n
Rick Ho's avatar
Rick Ho committed
381
    for (long step = 0; step < n_groups; ++step) {
382
        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
zms1999's avatar
zms1999 committed
383
        FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
Rick Ho's avatar
Rick Ho committed
384
385
386
387
388
389
390
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
Rick Ho's avatar
Rick Ho committed
391
                exchangeWith(global_grad_in + global_ptr[gidx_send] * d_model,
Rick Ho's avatar
Rick Ho committed
392
393
394
                        global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                        grad_in + local_ptr[idx_recv] * d_model,
                        local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
395
                        d_model, smgr->stream(num_expert), smgr->ncclcomm);
Rick Ho's avatar
Rick Ho committed
396
397
398
399
400
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }

401
    smgr->sync(num_expert + 1);
Rick Ho's avatar
Rick Ho committed
402
403
404
405
406
407
408
409
410
    checkCudaErrors(cudaGetLastError());

    delete [] local_ptr;
    delete [] global_ptr;
    delete [] local_global_ptr;
    checkCudaErrors(cudaGetLastError());
    for (long i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
zms1999's avatar
zms1999 committed
411
        cudaEventDestroy(output_torch_ready[i]);
Rick Ho's avatar
Rick Ho committed
412
413
414
    }
    delete [] input_ready;
    delete [] output_ready;
zms1999's avatar
zms1999 committed
415
    delete [] output_torch_ready;
Rick Ho's avatar
Rick Ho committed
416
417
418
419
420
421
    for (long i = 0; i < num_expert; ++i) {
        if (stored_models[i + rank * num_expert]) {
            cudaEventDestroy(evt_reduce[i]);
        }
    }
    delete [] evt_reduce;
Rick Ho's avatar
Rick Ho committed
422
423
424
}

#endif  // SMART_SCHEDULE_H