moe_cuda_kernel.cu 8.31 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
10
11
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
Jiezhong Qiu committed
12
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
13

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
14
// #include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
15

Rick Ho's avatar
Rick Ho committed
16
17
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
18

Rick Ho's avatar
Rick Ho committed
19
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
20

21
thread_local CudaStreamManager smgr;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
22

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
23
24
// Builds the per-sample pointer array consumed by cublasXgemmBatched:
//   ptrs[i] = base + stride * offset[i]
// Here `base` is the packed expert weight buffer, `stride` is the element
// count of one expert's weight matrix, and `offset[i]` is the expert index
// selected by the gate for sample i.
//
// Launch: any 1-D grid/block. The grid-stride loop makes the kernel correct
// for every launch configuration, including n larger than one grid's worth
// of threads, without changing results for the existing CEIL(n,256) launch.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const int* offset, const scalar_t** ptrs) {
	size_t step = (size_t)gridDim.x * blockDim.x;
	for (size_t idx = threadIdx.x + (size_t)blockDim.x * blockIdx.x; idx < n;
			idx += step) {
		ptrs[idx] = base + stride * offset[idx];
	}
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
32

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
33
// Computes one linear layer of the MoE for every sample in the batch:
// for each sample i, output[i] = input[i] * op(weight[gate[i]]), realized as
// a single batched GEMM of batch_size independent (1 x in_feat) x
// (in_feat x out_feat) products.
//
// Arguments:
//   input      [batch_size x in_feat]             (device)
//   gate       [batch_size] expert index per row  (device, int)
//   weight     [num_expert x out_feat x in_feat]  (device)
//   output     [batch_size x out_feat]            (device, written)
//   transb     CUBLAS_OP_T for the forward pass (row-major weight seen as
//              transposed by column-major cuBLAS); CUBLAS_OP_N lets the
//              backward pass reuse this routine to compute grad_input.
//
// Synchronizes *(smgr.streams) before returning, so `output` is ready.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* gate,
        const scalar_t* weight,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        cublasOperation_t transb) {

    checkCudaErrors(cublasSetStream(smgr.handle, *(smgr.streams)));

    // Setup Aarray, Barray and Carray: host-side pointer lists for A (input
    // rows) and C (output rows); B (per-sample weight pointers) is filled
    // directly on the device by generate_ptr_offset_kernel.
	std::vector<const scalar_t*> aptrs;
    std::vector<scalar_t*> cptrs;
    aptrs.reserve(batch_size);
    cptrs.reserve(batch_size);

    const scalar_t **Aarray;
    const scalar_t **Barray;
    scalar_t **Carray;
    // NOTE(review): per-call cudaMalloc/cudaFree of these scratch arrays is
    // synchronous and slow; a cached workspace would avoid it.
	checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*)));
    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*)));
    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*)));

	for (size_t i=0; i<batch_size; ++i) {
        aptrs.push_back(input + in_feat * i);
        cptrs.push_back(output + out_feat * i);
	}
	checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const
					scalar_t*), cudaMemcpyHostToDevice));
	checkCudaErrors(cudaMemcpy(Carray, cptrs.data(), batch_size *
				sizeof(scalar_t*), cudaMemcpyHostToDevice));

	// Fill Barray[i] = weight + (out_feat*in_feat) * gate[i] on the GEMM
	// stream, so the batched GEMM below is ordered after it.
	dim3 griddim(CEIL(batch_size, 256)); dim3 blockdim(256);
	generate_ptr_offset_kernel<<<griddim, blockdim, 0,
		*(smgr.streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
	// Kernel launches do not return a status; surface bad-config errors here.
	checkCudaErrors(cudaGetLastError());

	scalar_t alpha = 1, beta = 0;

	// Column-major batched GEMM: C_i (1 x out_feat) = A_i (1 x in_feat) *
	// op(B_i). With transb == CUBLAS_OP_T, B_i is the row-major
	// (out_feat x in_feat) expert matrix, ld = out_feat; with CUBLAS_OP_N
	// it is read untransposed, ld = in_feat. lda = ldc = 1 since m == 1.
	checkCudaErrors(cublasXgemmBatched(smgr.handle,
				CUBLAS_OP_N,
				transb,
				1, out_feat, in_feat,
				&alpha,
				Aarray, 1,
				Barray, (transb == CUBLAS_OP_T) ? out_feat : in_feat,
				&beta,
				Carray, 1,
				batch_size));

    checkCudaErrors(cudaStreamSynchronize(*(smgr.streams)));
    checkCudaErrors(cudaFree(Aarray));
    checkCudaErrors(cudaFree(Barray));
    checkCudaErrors(cudaFree(Carray));
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
92
93
94
95
96
97
98
99
100
// Accumulates the weight gradient for every expert:
//   grad_weight[gate[i]] += outer(grad_output[i], input[i])   for each i.
//
// One rank-1 GEMM per sample, issued on the stream of the sample's expert
// (smgr.streams[gate[i]]) so updates to the same expert matrix are serialized
// on one stream while different experts proceed concurrently.
//
// Preconditions:
//   - grad_weight must be zero-initialized by the caller (beta == 1 makes
//     every GEMM accumulate into it).
//   - every gate value is in [0, num_expert), since it indexes both
//     smgr.streams and grad_weight.
// Synchronizes all num_expert streams before returning.
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {

    // std::vector instead of raw new[]/delete[]: no leak if a
    // checkCudaErrors call below aborts or unwinds mid-function.
    std::vector<int> gate_host(batch_size);
    scalar_t alpha = 1, beta = 1;
    checkCudaErrors(cudaMemcpy(gate_host.data(), gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    for (size_t i=0; i<batch_size; ++i) {
        checkCudaErrors(cublasSetStream(smgr.handle, *(smgr.streams + gate_host[i])));
        // Column-major: (out_feat x 1) * (1 x in_feat) rank-1 update of the
        // selected expert's (row-major out_feat x in_feat) gradient block.
        checkCudaErrors(cublasXgemm(smgr.handle,
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            out_feat,
            in_feat,
            1,
            &alpha,
            grad_output + i * out_feat,
            out_feat,
            input + i * in_feat,
            in_feat,
            &beta,
            grad_weight + gate_host[i] * out_feat * in_feat,
            out_feat));
    }
    // Wait for every expert's stream; grad_weight is complete afterwards.
    for (size_t i=0; i<num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(*(smgr.streams + i)));
    }
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
128

Jiezhong Qiu's avatar
Jiezhong Qiu committed
129
// Torch entry point for the MoE forward pass.
//
// Args:
//   input  [batch_size x in_feat]  activations
//   gate   [batch_size]            int expert index per sample
//   weight [num_expert x out_feat x in_feat] expert weight matrices
// Returns: {output} with output of shape [batch_size x out_feat].
//
// NOTE(review): assumes all tensors are contiguous CUDA tensors on the same
// device and gate has dtype int32 — no checks are performed here; confirm
// against the Python-side caller.
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight) {
    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);
            
#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
    // Lazily create one cuBLAS handle + num_expert streams per thread
    // (smgr is thread_local) the first time this thread runs a kernel.
    const int device = device_of(input).value().index();
    if (smgr.streams == NULL) {
        smgr.setup(num_expert, device);
    }
    auto output = input.new_zeros({batch_size, out_feat});
    
    // CUBLAS_OP_T: the row-major [out_feat x in_feat] expert matrices are
    // read as transposed by column-major cuBLAS, giving
    // output[i] = input[i] * weight[gate[i]]^T (row-major view).
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
                    out_feat,
                    num_expert,
                    CUBLAS_OP_T
                );
    }));
    
    return {output, };           
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
164
// Torch entry point for the MoE backward pass.
//
// Args:
//   grad_output [batch_size x out_feat] upstream gradient
//   input       [batch_size x in_feat]  saved forward activations
//   gate        [batch_size]            int expert index per sample
//   weight      [num_expert x out_feat x in_feat]
// Returns: {grad_input [batch_size x in_feat],
//           grad_weight [num_expert x out_feat x in_feat]}.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x out_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
    // Same lazy per-thread stream/handle setup as in moe_cuda_forward.
    const int device = device_of(input).value().index();
    if (smgr.streams == NULL) {
        smgr.setup(num_expert, device);
    }

    // new_zeros matters: moe_cuda_grad_weight accumulates (beta == 1) into
    // grad_weight, so it must start at zero.
    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    // grad_input is easy to compute, exactly the same as forward:
    // the impl is reused with in_feat/out_feat swapped and CUBLAS_OP_N so it
    // computes grad_input[i] = grad_output[i] * weight[gate[i]].
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N
        );
    }));

    // Per-expert rank-1 accumulation of the weight gradients.
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
217
218

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
219
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
220
221
222
223
224
225
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
226
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
227
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
228
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
229

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
230
231
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
232
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
233
234
235
236
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
237

Jiezhong Qiu's avatar
Jiezhong Qiu committed
238
239
240
241
242
243
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
244
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
245
246
247
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
248
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
249
250
251
252
253
254
255
256
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
257
}
Rick Ho's avatar
Rick Ho committed
258
*/