moe_compute_kernel.cu 12.9 KB
Newer Older
1
2
#include "moe_cuda_kernel.h"

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <vector>

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
7
8
#include <cuda.h>
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
9
#include <cublas_v2.h>
Jiezhong Qiu's avatar
Jiezhong Qiu committed
10
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
11

Rick Ho's avatar
Rick Ho committed
12
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
13

Rick Ho's avatar
Rick Ho committed
14
15
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
16

Rick Ho's avatar
Rick Ho committed
17
// Ceiling of _x_ / _y_ for non-negative _x_ and positive _y_.
// Written as (x + y - 1) / y so that CEIL(0, y) == 0 also for unsigned x;
// the former ((x - 1) / y + 1) wrapped around when x was an unsigned zero.
#define CEIL(_x_,_y_) (((_x_) + (_y_) - 1) / (_y_))
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
18

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19
20
// Builds a pointer table: ptrs[k] = base + stride * offset[k] for k in [0, n).
// 1-D launch, one thread per entry; threads whose index is >= n do nothing.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
        const long* offset, const scalar_t** ptrs) {
    const size_t k = blockIdx.x * blockDim.x + threadIdx.x;
    if (k >= n) {
        return;
    }
    ptrs[k] = base + offset[k] * stride;
}

29
30
// Row gather used by the local scatter: output row blockIdx.x receives a copy
// of input row pos[blockIdx.x]; every row holds `wid` contiguous elements.
// Launch one block per output row; the block's threads stride over the columns.
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    const scalar_t* src = inbuf + wid * pos[blockIdx.x];
    scalar_t* dst = oubuf + wid * blockIdx.x;
    for (int col = threadIdx.x; col < wid; col += blockDim.x) {
        dst[col] = src[col];
    }
}

40
41

// Accumulator type for column_reduce: float for float/half inputs and double
// for double inputs, so half-precision data is not summed in half precision.
template <typename scalar_t>
struct column_reduce_acc { typedef float type; };
template <>
struct column_reduce_acc<double> { typedef double type; };

/*
    Sums every column of a row-major m x n matrix:
        result[c] = sum over r of matrix[r * n + c],  for c in [0, n).

    Launch layout:
      - Block blockIdx.y handles the blockDim.y columns starting at
        blockIdx.y * blockDim.y (threadIdx.y selects the column).
      - The blockDim.x threads of one column each sum rows threadIdx.x,
        threadIdx.x + blockDim.x, ...; the partial sums are then tree-reduced
        in shared memory.
      - blockIdx.x is not used; launch with gridDim.x == 1.
    Shared memory: a static buffer of 1024 accumulators (the per-block thread
    limit), so any dynamic shared-memory size passed at launch is unused.
    Every thread reaches every __syncthreads() (barriers are block-uniform).
*/
template <typename scalar_t>
__global__
void column_reduce(const scalar_t * matrix, scalar_t * result,
    int m /* lines */, int n /* columns*/) {
    typedef typename column_reduce_acc<scalar_t>::type acc_t;
    __shared__ acc_t sdata[1024];

    const size_t rows = m > 0 ? static_cast<size_t>(m) : 0;
    const size_t cols = n > 0 ? static_cast<size_t>(n) : 0;
    const unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x;
    const size_t col = static_cast<size_t>(blockIdx.y) * blockDim.y + threadIdx.y;

    // Phase 1: each thread sums a strided subset of the rows of its column.
    // size_t indexing avoids the int overflow of n * m on large matrices.
    acc_t partial = 0;
    if (col < cols) {
        for (size_t row = threadIdx.x; row < rows; row += blockDim.x) {
            partial += static_cast<acc_t>(matrix[row * cols + col]);
        }
    }
    sdata[tid] = partial;
    __syncthreads();

    // Phase 2: tree reduction along x. The loop bound depends only on
    // blockDim.x, so all threads execute the same number of barriers.
    for (unsigned int s = 1; s < blockDim.x; s *= 2) {
        if (threadIdx.x % (2 * s) == 0 && threadIdx.x + s < blockDim.x) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (threadIdx.x == 0 && col < cols) {
        result[col] = static_cast<scalar_t>(sdata[tid]);
    }
}

Rick Ho's avatar
Rick Ho committed
80
// Counts how many samples the gate routes to each expert. It also computes,
// for every sample, its row index once samples are grouped by expert.
//
// d_gate       device array [batch_size]; d_gate[i] is the expert of sample i.
// expert_count host array [num_expert]; receives the per-expert counts.
// d_pos        device array [batch_size]; receives each sample's slot in the
//              expert-grouped order (samples keep their relative order
//              within an expert).
//
// Throws std::out_of_range if a gate value lies outside [0, num_expert).
void moe_cuda_expert_count_impl(
        const int* d_gate,
        int* expert_count,
        int* d_pos,
        const size_t num_expert,
        const size_t batch_size) {
    for (size_t e = 0; e < num_expert; ++e) {
        expert_count[e] = 0;
    }
    if (batch_size == 0) {
        return;
    }

    // Host staging buffers. std::vector frees them on every exit path
    // (the previous raw new[] position buffer was never deleted).
    std::vector<int> gate(batch_size);
    checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
                cudaMemcpyDeviceToHost));

    for (size_t i = 0; i < batch_size; ++i) {
        const int g = gate[i];
        // An out-of-range gate would otherwise index past expert_count.
        if (g < 0 || static_cast<size_t>(g) >= num_expert) {
            throw std::out_of_range(
                    "moe_cuda_expert_count: gate value out of range");
        }
        ++expert_count[g];
    }

    // Exclusive prefix sum: first output slot of each expert.
    std::vector<int> expert_ptr(num_expert, 0);
    for (size_t e = 1; e < num_expert; ++e) {
        expert_ptr[e] = expert_ptr[e - 1] + expert_count[e - 1];
    }

    std::vector<int> pos(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        pos[i] = expert_ptr[gate[i]]++;
    }
    checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
                cudaMemcpyHostToDevice));
}
115

Rick Ho's avatar
Rick Ho committed
116
117
118
// Copies input row d_pos[r] into input_buf row r for r in [0, batch_size),
// each row being in_feat elements. The kernel runs on smgr->stream(0),
// and smgr->sync(1) is called before returning.
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
        const long* d_pos,
        scalar_t* input_buf,
        const long batch_size,
        const long in_feat,
        CudaStreamManager* smgr) {
    // One block per row: a zero-block grid is an invalid launch
    // configuration, so an empty batch is simply a no-op.
    if (batch_size <= 0) {
        return;
    }
    batch_scatter_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
                input_buf);
    // Kernel launches do not return errors; surface bad launches here.
    checkCudaErrors(cudaGetLastError());
    smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
129

Rick Ho's avatar
Rick Ho committed
130
131
// Inverse of batch_scatter_kernel: input row blockIdx.x is written to output
// row pos[blockIdx.x]; every row holds `wid` contiguous elements.
// Launch one block per input row; the block's threads stride over the columns.
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    const scalar_t* src = inbuf + wid * blockIdx.x;
    scalar_t* dst = oubuf + wid * pos[blockIdx.x];
    for (int col = threadIdx.x; col < wid; col += blockDim.x) {
        dst[col] = src[col];
    }
}

// Copies output_buf row r into output row d_pos[r] for r in [0, batch_size),
// each row being out_feat elements. The kernel runs on smgr->stream(0),
// and smgr->sync(1) is called before returning.
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
        const long* d_pos,
        scalar_t* output,
        const size_t batch_size,
        const size_t out_feat,
        CudaStreamManager* smgr) {
    // One block per row: a zero-block grid is an invalid launch
    // configuration, so an empty batch is simply a no-op.
    if (batch_size == 0) {
        return;
    }
    batch_gather_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
                output);
    // Kernel launches do not return errors; surface bad launches here.
    checkCudaErrors(cudaGetLastError());
    smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
154

Rick Ho's avatar
Rick Ho committed
155
156
157
158
// Per-expert linear layer over expert-grouped rows.
// Rows [row_offset, row_offset + expert_count[e]) of input_buf belong to
// expert e and are multiplied by the transpose of that expert's
// out_feat x in_feat weight slice, writing the same rows of output_buf.
// expert_count is dereferenced on the host. When has_bias is set, GEMM
// accumulates into output_buf (beta = 1), which the caller pre-fills.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* output_buf,
        const bool has_bias,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1;
    scalar_t beta = has_bias ? 1 : 0;

    int row_offset = 0;  // first row of the current expert's slice
    for (int e = 0; e < num_expert; ++e) {
        const long rows = expert_count[e];
        if (rows != 0) {
            // cuBLAS is column-major: compute T(B) x T(A) = T(C) so that C
            // comes out row-major.
            checkCudaErrors(cublasXgemm(
                    smgr->handle(e),
                    CUBLAS_OP_T, CUBLAS_OP_N,
                    out_feat, rows, in_feat,
                    &alpha,
                    weight + e * in_feat * out_feat, in_feat,
                    input_buf + row_offset * in_feat, in_feat,
                    &beta,
                    output_buf + out_feat * row_offset, out_feat));
            row_offset += rows;
        }
    }
    smgr->sync(num_expert);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
190
// Backward pass of moe_cuda_forward_impl for expert-grouped rows.
// For each expert e with rows [ptr, ptr + expert_count[e]):
//   grad_input_buf rows  = grad_output rows x W_e       (handle(e) stream)
//   grad_weight[e]       = grad_output^T x input rows   (handle(e) stream)
//   grad_bias[e]         = column sums of grad_output rows (stream(0), if has_bias)
// Experts with no rows get zeroed grad_weight / grad_bias slices.
// expert_count is dereferenced on the host. batch_size is currently unused.
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        scalar_t* grad_bias,
        const bool has_bias,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

    // column_reduce layout for the bias gradient: 32 columns per block (y),
    // 32 threads cooperating on each column (x).
    dim3 block_threads(32, 32);
    dim3 grid_threads(1, out_feat / 32 + (out_feat % 32 ? 1 : 0));

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            // No samples were routed to this expert: its gradients are zero.
            checkCudaErrors(cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                    sizeof(scalar_t) * in_feat * out_feat));
            checkCudaErrors(cudaMemset(grad_bias + i * out_feat, 0,
                    sizeof(scalar_t) * out_feat));
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C

        // Backward input: g_i = w @ g_o
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                in_feat, expert_count[i], out_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_input_buf + in_feat * ptr, in_feat
                ));

        // Backward weight: g_w = i @ g_o
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                in_feat, out_feat, expert_count[i],
                &alpha,
                input_buf + in_feat * ptr, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_weight + i * in_feat * out_feat, in_feat
                ));

        // out_feat == 0 would give a zero-block grid (invalid configuration).
        if (has_bias && out_feat > 0) {
            column_reduce<scalar_t>
                <<<grid_threads, block_threads, sizeof(scalar_t)*1024, smgr->stream(0)>>>
                (
                    grad_output_buf + ptr * out_feat,
                    grad_bias + i * out_feat,
                    expert_count[i],
                    out_feat
                );
            // Kernel launches do not return errors; surface bad launches here.
            checkCudaErrors(cudaGetLastError());
        }

        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
Rick Ho's avatar
Rick Ho committed
262
263


Rick Ho's avatar
Rick Ho committed
264
// Returns {expert_count, pos} for an int32 gate tensor of length batch_size:
// expert_count is an int32 tensor of length num_expert, allocated with
// default TensorOptions (host). pos is an int32 tensor of length
// batch_size on the gate's device.
std::vector<torch::Tensor> moe_cuda_expert_count(
        torch::Tensor gate,
        size_t num_expert) {
    const auto batch_size = gate.size(0);

    auto expert_count = torch::empty(num_expert,
            torch::TensorOptions().dtype(torch::kInt32));
    auto pos = torch::empty(batch_size,
            torch::TensorOptions()
                .device(gate.device())
                .dtype(torch::kInt32));

    moe_cuda_expert_count_impl(
            gate.data_ptr<int>(),
            expert_count.data_ptr<int>(),
            pos.data_ptr<int>(),
            num_expert,
            batch_size);
    return {expert_count, pos};
}

Rick Ho's avatar
Rick Ho committed
286
287
// Returns {input_buf}, where input_buf row r is input row pos[r].
// input_buf has pos.size(0) rows and input.size(1) columns, with input's
// dtype and device. pos is read as int64 (data_ptr<long>).
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
    torch::Tensor pos) {
    auto smgr = getCudaStreamManager(input.device().index());
    const auto batch_size = pos.size(0);
    const auto in_feat = input.size(1);

    auto input_buf = torch::empty({batch_size, in_feat},
            torch::TensorOptions()
                .dtype(input.dtype())
                .device(input.device()));

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda",
            ([&] {
        moe_cuda_local_scatter_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            pos.data_ptr<long>(),
            input_buf.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            smgr);
    }));
    return {input_buf};
}
Jiezhong Qiu's avatar
Jiezhong Qiu committed
310

Rick Ho's avatar
Rick Ho committed
311
// Returns {output}, where output row pos[r] is output_buf row r.
// output has pos.size(0) rows and output_buf.size(1) columns, with
// output_buf's dtype and device. pos is read as int64 (data_ptr<long>).
std::vector<torch::Tensor> moe_cuda_local_gather(
    torch::Tensor output_buf,
    torch::Tensor pos) {
    auto smgr = getCudaStreamManager(output_buf.device().index());
    const auto batch_size = pos.size(0);
    const auto out_feat = output_buf.size(1);

    auto output = torch::empty({batch_size, out_feat},
            torch::TensorOptions()
                .dtype(output_buf.dtype())
                .device(output_buf.device()));

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda",
            ([&] {
        moe_cuda_local_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            pos.data_ptr<long>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            smgr);
    }));
    return {output};
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
335

Jiezhong Qiu's avatar
Jiezhong Qiu committed
336
// Expert-wise linear forward.
//   input_buf    [batch_size x in_feat], rows grouped by expert
//   expert_count per-expert row counts, read on the host as int64
//   weight       [num_expert x out_feat x in_feat]
//   bias         optional; when present, each expert's bias row is repeated
//                expert_count[e] times to seed the output before the GEMM
// Returns {output} of shape [batch_size x out_feat].
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor expert_count,
        torch::Tensor weight,
        at::optional<torch::Tensor> bias
        ) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
            num_expert, in_feat, out_feat);
#endif

    torch::Tensor output;
    if (!bias.has_value()) {
        output = torch::empty({batch_size, out_feat},
                torch::TensorOptions()
                    .device(input_buf.device())
                    .dtype(input_buf.dtype()));
    } else {
        const auto& bias_rows = bias.value();
        output = bias_rows.repeat_interleave(
                expert_count.to(bias_rows.device()), 0);
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
            ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<long>(),
            output.data_ptr<scalar_t>(),
            bias.has_value(),
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));
    return {output};
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
382
// Backward of moe_cuda_forward.
// Returns {grad_input_buf [batch_size x in_feat],
//          grad_weight    [num_expert x out_feat x in_feat],
//          grad_bias      [num_expert x out_feat]}.
// grad_bias is only computed when bias is present.
// NOTE(review): otherwise its rows for non-empty experts are left
// uninitialized (new_empty); callers presumably ignore it in that case.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, 	// [batch_size x out_feat]
    torch::Tensor input_buf, 		// [batch_size x in_feat]
    torch::Tensor expert_count,		// per-expert row counts, read on the host as int64
    torch::Tensor weight, 			// [num_expert x out_feat x in_feat]
    at::optional<torch::Tensor> bias
) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
            "out_feat (d_ffn)=%ld\n",
            batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
    auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<long>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            grad_bias.data_ptr<scalar_t>(),
            bias.has_value(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));

    return {grad_input_buf, grad_weight, grad_bias};
}