moe_compute_kernel.cu 13.1 KB
Newer Older
1
2
#include "moe_cuda_kernel.h"

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
7
8
#include <cuda.h>
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
9
#include <cublas_v2.h>
Jiezhong Qiu's avatar
Jiezhong Qiu committed
10
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
11

Rick Ho's avatar
Rick Ho committed
12
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
13

Rick Ho's avatar
Rick Ho committed
14
15
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
16

Rick Ho's avatar
Rick Ho committed
17
// Integer ceiling division: smallest k with k * _y_ >= _x_ (for _x_ >= 0, _y_ > 0).
// Written as (x + y - 1) / y so that CEIL(0, y) == 0 for both signed and
// unsigned operands (the "(x - 1) / y + 1" form yields 1, or wraps for unsigned 0).
#define CEIL(_x_,_y_) (((_x_)+(_y_)-1)/(_y_))
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
18

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19
20
/*
 * Fills a pointer table: ptrs[k] = base + stride * offset[k] for k in [0, n).
 * One thread per table entry along x; threads whose flat id is >= n do nothing,
 * so the launch only needs at least n threads in total.
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
        const long* offset, const scalar_t** ptrs) {
    // Flat 1-D thread id (evaluated in unsigned int, then widened).
    const size_t slot = blockIdx.x * blockDim.x + threadIdx.x;
    if (slot >= n) {
        return;
    }
    ptrs[slot] = base + offset[slot] * stride;
}

29
30
/*
 * Row copy used by the local scatter: output row b (= blockIdx.x) receives
 * input row pos[b]. Each row has `wid` elements and both buffers are dense
 * row-major.
 *
 * Launch: one block per output row (gridDim.x == number of rows); the threads
 * of a block stride over the row's `wid` elements.
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    const scalar_t* src = inbuf + wid * pos[blockIdx.x];
    scalar_t* dst = oubuf + wid * blockIdx.x;
    // size_t index: matches the type of `wid`, so wide rows cannot overflow it.
    for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
        dst[i] = src[i];
    }
}

40
41

/*
    Column-wise sum of a dense row-major m x n matrix:
        result[c] = sum over r of matrix[r * n + c],   for c in [0, n)

    Launch layout:
      - gridDim.x == 1. blockIdx.y selects a tile of blockDim.y columns, so
        gridDim.y must be >= ceil(n / blockDim.y).
      - threadIdx.y selects the column inside the tile. The blockDim.x threads
        that share a column each sum rows threadIdx.x, threadIdx.x + blockDim.x,
        ... and then combine their partial sums with a tree reduction.
      - Dynamic shared memory: blockDim.x * blockDim.y * sizeof(scalar_t) bytes.
      - Any blockDim.x works; a power of two is not required.
*/
template <typename scalar_t>
__global__
void column_reduce(const scalar_t * matrix, scalar_t * result,
    int m /* lines */, int n /* columns*/) {
    // Templated kernels cannot declare a typed extern __shared__ array directly,
    // so reserve raw bytes and reinterpret them.
    // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
    extern __shared__ __align__(sizeof(scalar_t)) unsigned char my_smem[];
    scalar_t *sdata = reinterpret_cast<scalar_t *>(my_smem);

    // One shared slot per thread; the blockDim.x slots of a column are contiguous.
    const unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x;
    const int col = blockIdx.y * blockDim.y + threadIdx.y;

    // Phase 1: each thread accumulates a strided subset of this column's rows.
    // Threads past the last column or row simply contribute 0.
    scalar_t acc = 0;
    if (col < n) {
        for (int row = threadIdx.x; row < m; row += blockDim.x) {
            acc += matrix[static_cast<size_t>(row) * n + col];
        }
    }
    sdata[tid] = acc;
    __syncthreads();

    // Phase 2: tree reduction along x. The trip count depends only on
    // blockDim.x, so every thread in the block reaches every __syncthreads().
    for (unsigned int s = 1; s < blockDim.x; s *= 2) {
        if (threadIdx.x % (2 * s) == 0 && threadIdx.x + s < blockDim.x) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (threadIdx.x == 0 && col < n) {
        result[col] = sdata[tid];
    }
}

Rick Ho's avatar
Rick Ho committed
85
/*
 * Host-side routing bookkeeping for one batch.
 *
 *   d_gate        device array of batch_size ints: the expert id of each row.
 *                 Ids are used directly as indices and must lie in
 *                 [0, num_expert); this is not validated here.
 *   expert_count  host array of num_expert ints (output): number of rows
 *                 routed to each expert.
 *   d_pos         device array of batch_size ints (output): d_pos[i] is the
 *                 slot of row i after grouping rows by expert id. Within one
 *                 expert, rows keep their original order.
 */
void moe_cuda_expert_count_impl(
        const int* d_gate,
        int* expert_count,
        int* d_pos,
        const size_t num_expert,
        const size_t batch_size) {
    for (size_t e = 0; e < num_expert; ++e) {
        expert_count[e] = 0;
    }
    if (batch_size == 0) {
        // Nothing was routed, so skip the device round-trips.
        return;
    }

    std::vector<int> gate(batch_size);
    checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
                cudaMemcpyDeviceToHost));

    // Histogram: rows per expert.
    for (size_t i = 0; i < batch_size; ++i) {
        ++expert_count[gate[i]];
    }

    // Exclusive prefix sum: first slot owned by each expert.
    std::vector<int> next_slot(num_expert);
    int running = 0;
    for (size_t e = 0; e < num_expert; ++e) {
        next_slot[e] = running;
        running += expert_count[e];
    }

    // Slots are handed out in row order, so rows stay ordered within an expert.
    std::vector<int> pos(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        pos[i] = next_slot[gate[i]]++;
    }

    checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
                cudaMemcpyHostToDevice));
}
120

Rick Ho's avatar
Rick Ho committed
121
122
123
/*
 * Copies input row d_pos[b] into input_buf row b for every b < batch_size.
 * Each row has in_feat elements. The copy is issued on smgr->stream(0),
 * followed by smgr->sync(1).
 */
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
        const long* d_pos,
        scalar_t* input_buf,
        const long batch_size,
        const long in_feat,
        CudaStreamManager* smgr) {
    // A grid with zero blocks is an invalid launch configuration, so an empty
    // batch must not launch the kernel at all.
    if (batch_size > 0) {
        batch_scatter_kernel<scalar_t>
            <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
                    input_buf);
    }
    smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
134

Rick Ho's avatar
Rick Ho committed
135
136
/*
 * Row copy used by the local gather, the inverse of batch_scatter_kernel:
 * input row b (= blockIdx.x) is written to output row pos[b]. Each row has
 * `wid` elements and both buffers are dense row-major.
 *
 * Launch: one block per input row (gridDim.x == number of rows); the threads
 * of a block stride over the row's `wid` elements.
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    const scalar_t* src = inbuf + wid * blockIdx.x;
    scalar_t* dst = oubuf + wid * pos[blockIdx.x];
    // size_t index: matches the type of `wid`, so wide rows cannot overflow it.
    for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
        dst[i] = src[i];
    }
}

/*
 * Copies output_buf row b into output row d_pos[b] for every b < batch_size.
 * Each row has out_feat elements. The copy is issued on smgr->stream(0),
 * followed by smgr->sync(1).
 */
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
        const long* d_pos,
        scalar_t* output,
        const size_t batch_size,
        const size_t out_feat,
        CudaStreamManager* smgr) {
    // A grid with zero blocks is an invalid launch configuration, so an empty
    // batch must not launch the kernel at all.
    if (batch_size > 0) {
        batch_gather_kernel<scalar_t>
            <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
                    output);
    }
    smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
159

Rick Ho's avatar
Rick Ho committed
160
161
162
163
/*
 * Per-expert linear layer on expert-grouped rows.
 *
 * Rows of input_buf / output_buf are grouped by expert: expert i owns
 * expert_count[i] consecutive rows. For each expert with at least one row,
 * its output slice is computed as
 *     out = x @ W_i^T  (+ existing contents of output_buf when has_bias)
 * where W_i is weight[i], a row-major out_feat x in_feat matrix.
 *
 * expert_count is dereferenced on the host, so it must be host-accessible.
 * The GEMM for expert i is issued through smgr->handle(i), followed by
 * smgr->sync(num_expert).
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* output_buf,
        const bool has_bias,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    // With has_bias, beta = 1 accumulates the product onto whatever output_buf
    // already holds (the caller pre-fills it with per-row bias). Otherwise
    // beta = 0 overwrites the output.
    scalar_t alpha = 1, beta = has_bias ? 1 : 0;

    // `row` is the first row of expert i's slice. It uses the same type as
    // expert_count so that summing large counts cannot overflow an int.
    long row = 0;
    for (size_t i = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            continue;
        }
        // Buffers are row-major while cuBLAS is column-major:
        // use T(B) x T(A) = T(C) to produce row-major C.
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_T,
                CUBLAS_OP_N,
                out_feat, expert_count[i], in_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                input_buf + row * in_feat, in_feat,
                &beta,
                output_buf + out_feat * row, out_feat
                ));

        row += expert_count[i];
    }
    smgr->sync(num_expert);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
195
/*
 * Backward pass of the per-expert linear layer (forward: out = x @ W_i^T + b_i).
 *
 * Rows are grouped by expert: expert i owns expert_count[i] consecutive rows of
 * grad_output_buf / input_buf / grad_input_buf. For each expert it computes:
 *     grad_input  = g_o @ W_i
 *     grad_weight = g_o^T @ x          (row-major out_feat x in_feat)
 *     grad_bias   = column sums of g_o (only when has_bias)
 * Experts with no rows get zero grad_weight and grad_bias.
 *
 * expert_count is dereferenced on the host, so it must be host-accessible.
 * batch_size is not used; it is kept for interface compatibility.
 */
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        scalar_t* grad_bias,
        const bool has_bias,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    (void)batch_size;
    scalar_t alpha = 1, beta = 0;

    // Launch shape for the bias reduction (column_reduce): blockDim.x threads
    // share each column's row sum, and each block covers blockDim.y columns.
    const dim3 block_threads(32, 32);
    const dim3 grid_threads(1, (out_feat + block_threads.y - 1) / block_threads.y);
    // column_reduce keeps one partial sum per thread in dynamic shared memory.
    const size_t reduce_smem =
            sizeof(scalar_t) * block_threads.x * block_threads.y;

    // `row` is the first row of expert i's slice. It uses the same type as
    // expert_count so that summing large counts cannot overflow an int.
    long row = 0;
    for (size_t i = 0; i < num_expert; ++i) {
        scalar_t* expert_grad_weight = grad_weight + i * in_feat * out_feat;
        scalar_t* expert_grad_bias = grad_bias + i * out_feat;

        if (expert_count[i] == 0) {
            // No rows were routed to this expert, so its gradients are zero.
            checkCudaErrors(cudaMemset(expert_grad_weight, 0,
                        sizeof(scalar_t) * in_feat * out_feat));
            checkCudaErrors(cudaMemset(expert_grad_bias, 0,
                        sizeof(scalar_t) * out_feat));
            continue;
        }
        // Buffers are row-major while cuBLAS is column-major:
        // use T(B) x T(A) = T(C) to produce row-major C.

        // Backward input: g_i = g_o @ W   ([count x out] @ [out x in])
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                in_feat, expert_count[i], out_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                grad_output_buf + row * out_feat, out_feat,
                &beta,
                grad_input_buf + in_feat * row, in_feat
                ));

        // Backward weight: g_w = g_o^T @ x   ([out x count] @ [count x in])
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                in_feat, out_feat, expert_count[i],
                &alpha,
                input_buf + in_feat * row, in_feat,
                grad_output_buf + row * out_feat, out_feat,
                &beta,
                expert_grad_weight, in_feat
                ));

        // Backward bias: column sums of this expert's g_o rows. When
        // out_feat == 0 there are no columns, and the grid would be empty.
        if (has_bias && out_feat > 0) {
            column_reduce<scalar_t>
                <<<grid_threads, block_threads, reduce_smem, smgr->stream(0)>>>(
                    grad_output_buf + row * out_feat,
                    expert_grad_bias,
                    static_cast<int>(expert_count[i]),
                    static_cast<int>(out_feat));
            // Kernel launches return no status; surface configuration errors here.
            checkCudaErrors(cudaGetLastError());
        }

        row += expert_count[i];
    }
    smgr->sync(num_expert);
}
Rick Ho's avatar
Rick Ho committed
267
268


Rick Ho's avatar
Rick Ho committed
269
/*
 * Counts rows per expert and computes each row's slot in expert-grouped order.
 *
 *   gate        int32 tensor of batch_size expert ids (read as a device
 *               pointer by the impl).
 *   num_expert  number of experts.
 * Returns {expert_count, pos}:
 *   expert_count  int32 CPU tensor [num_expert].
 *   pos           int32 tensor [batch_size] on gate's device.
 */
std::vector<torch::Tensor> moe_cuda_expert_count(
        torch::Tensor gate,
        size_t num_expert) {
    // The impl copies batch_size consecutive ints starting at gate's data
    // pointer, so a strided view must be compacted first. contiguous() returns
    // the tensor itself when it is already dense.
    gate = gate.contiguous();
    const auto batch_size = gate.size(0);

    // No device is set, so counts live on the host, where the impl fills them.
    auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
    auto expert_count = torch::empty(num_expert, ec_options);

    auto pos_options = torch::TensorOptions()
        .device(gate.device())
        .dtype(torch::kInt32);
    auto pos = torch::empty(batch_size, pos_options);
    moe_cuda_expert_count_impl(
            gate.data_ptr<int>(),
            expert_count.data_ptr<int>(),
            pos.data_ptr<int>(),
            num_expert,
            batch_size);

    return {expert_count, pos};
}

Rick Ho's avatar
Rick Ho committed
291
292
/*
 * Builds input_buf with input_buf[b] = input[pos[b]] for b < pos.size(0).
 *   input  [rows x in_feat] floating-point tensor.
 *   pos    int64 tensor of row indices into input.
 * Returns {input_buf}, shaped [pos.size(0) x in_feat] on input's device.
 */
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
    torch::Tensor pos) {
    // The kernel addresses rows as `input + in_feat * row` and reads pos[b]
    // directly, so both tensors must be dense. contiguous() is a no-op when
    // they already are.
    input = input.contiguous();
    pos = pos.contiguous();

    auto smgr = getCudaStreamManager(input.device().index());
    const auto batch_size = pos.size(0);
    const auto in_feat = input.size(1);

    auto opt = torch::TensorOptions()
        .dtype(input.dtype())
        .device(input.device());
    auto input_buf = torch::empty({batch_size, in_feat}, opt);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda", 
            ([&] {
        moe_cuda_local_scatter_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            pos.data_ptr<long>(),
            input_buf.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            smgr);
    }));
    return {input_buf,};
}
Jiezhong Qiu's avatar
Jiezhong Qiu committed
315

Rick Ho's avatar
Rick Ho committed
316
/*
 * Inverse of moe_cuda_local_scatter: output[pos[b]] = output_buf[b] for
 * b < pos.size(0).
 *   output_buf  [rows x out_feat] floating-point tensor.
 *   pos         int64 tensor of destination row indices.
 * Returns {output}, shaped [pos.size(0) x out_feat] on output_buf's device.
 * NOTE(review): assumes pos is a permutation of [0, pos.size(0)). Otherwise
 * some output rows are never written and stay uninitialized (torch::empty).
 */
std::vector<torch::Tensor> moe_cuda_local_gather(
    torch::Tensor output_buf,
    torch::Tensor pos) {
    // The kernel addresses rows as `output_buf + out_feat * row` and reads
    // pos[b] directly, so both tensors must be dense.
    output_buf = output_buf.contiguous();
    pos = pos.contiguous();

    auto smgr = getCudaStreamManager(output_buf.device().index());
    const auto batch_size = pos.size(0);
    const auto out_feat = output_buf.size(1);

    auto opt = torch::TensorOptions()
        .dtype(output_buf.dtype())
        .device(output_buf.device());
    auto output = torch::empty({batch_size, out_feat}, opt);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda", 
            ([&] {
        moe_cuda_local_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            pos.data_ptr<long>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            smgr);
    }));
    return {output,};
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
340

Jiezhong Qiu's avatar
Jiezhong Qiu committed
341
/*
 * Forward pass of the per-expert linear layer on expert-grouped rows.
 *   input_buf     [batch_size x in_feat], rows grouped by expert.
 *   expert_count  int64 rows per expert; read on the host by the impl.
 *   weight        [num_expert x out_feat x in_feat].
 *   bias          optional; when present, each row's output starts from
 *                 bias repeated expert_count[i] times for expert i.
 * Returns {output}, shaped [batch_size x out_feat].
 */
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor expert_count,
        torch::Tensor weight,
        at::optional<torch::Tensor> bias
        ) {
    // The impl does raw row-major pointer arithmetic on these tensors.
    input_buf = input_buf.contiguous();
    expert_count = expert_count.contiguous();
    weight = weight.contiguous();

    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 
            num_expert, in_feat, out_feat);
#endif

    torch::Tensor output;
    if (bias.has_value()) {
        // Pre-fill each row with its expert's bias. The impl then accumulates
        // the GEMM on top (beta = 1).
        output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0);
    } else {
        auto out_options = torch::TensorOptions()
            .device(input_buf.device())
            .dtype(input_buf.dtype());
        output = torch::empty({batch_size, out_feat}, out_options);
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda", 
            ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<long>(),
            output.data_ptr<scalar_t>(),
            bias.has_value(),
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));

    return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
387
/*
 * Backward pass of the per-expert linear layer.
 * Returns {grad_input_buf, grad_weight, grad_bias}, shaped
 * [batch_size x in_feat], [num_expert x out_feat x in_feat] and
 * [num_expert x out_feat]. grad_bias is only computed for experts with rows
 * when bias is present.
 */
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, 	// [batch_size x out_feat]
    torch::Tensor input_buf, 		// [batch_size x in_feat]
    torch::Tensor expert_count,     // rows per expert; read on the host by the impl
    torch::Tensor weight, 			// [num_expert x out_feat x in_feat]
    at::optional<torch::Tensor> bias
) {
    // Gradients arriving from autograd are frequently non-contiguous, and the
    // impl does raw row-major pointer arithmetic. contiguous() is a no-op for
    // tensors that are already dense.
    grad_output_buf = grad_output_buf.contiguous();
    input_buf = input_buf.contiguous();
    expert_count = expert_count.contiguous();
    weight = weight.contiguous();

    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
            "out_feat (d_ffn)=%ld\n",
            batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat}); 
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
    auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<long>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            grad_bias.data_ptr<scalar_t>(),
            bias.has_value(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));

    return {grad_input_buf, grad_weight, grad_bias};
}