moe_cuda_kernel.cu 9.05 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
10
11
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
Jiezhong Qiu committed
12
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
13

Rick Ho's avatar
Rick Ho committed
14
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
15

Rick Ho's avatar
Rick Ho committed
16
17
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
18

Rick Ho's avatar
Rick Ho committed
19
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
20

Rick Ho's avatar
Rick Ho committed
21
// #define MOE_BREAKDOWN
22
// #define MOE_DEBUG
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
23

24
thread_local CudaStreamManager smgr;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
25

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
26
27
// Fills `ptrs` with one device pointer per entry: `base` advanced by
// `stride` elements times that entry's `offset`.  Intended to build the
// pointer arrays consumed by batched GEMM calls.
// Launch as a 1-D grid covering n; extra threads exit via the bounds guard.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
	if (i >= n) {
		return;
	}
	ptrs[i] = base + stride * (size_t)offset[i];
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
36

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// Copies row blockIdx.x of `inbuf` into row pos[blockIdx.x] of `oubuf`.
// One block per row; threads stride across the `wid` elements of the row.
// Rows are `wid` elements wide in both buffers.
template <typename scalar_t>
__global__
void batch_scatter_kernel(int wid, int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	const scalar_t* src = inbuf + (size_t)wid * blockIdx.x;
	scalar_t* dst = oubuf + (size_t)wid * pos[blockIdx.x];
	for (int j = threadIdx.x; j < wid; j += blockDim.x) {
		dst[j] = src[j];
	}
}

// Inverse of batch_scatter_kernel: copies row pos[blockIdx.x] of `inbuf`
// back into row blockIdx.x of `oubuf`.
// One block per row; threads stride across the `wid` elements of the row.
template <typename scalar_t>
__global__
void batch_gather_kernel(int wid, int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	const scalar_t* src = inbuf + (size_t)wid * pos[blockIdx.x];
	scalar_t* dst = oubuf + (size_t)wid * blockIdx.x;
	for (int j = threadIdx.x; j < wid; j += blockDim.x) {
		dst[j] = src[j];
	}
}


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
60
// Single-dtype implementation of the MoE forward pass:
//   1. copy the gate assignment to the host and histogram rows per expert,
//   2. scatter input rows so rows routed to the same expert are contiguous,
//   3. run one cuBLAS GEMM per non-empty expert (T(B) x T(A) trick so the
//      column-major result is row-major from the caller's point of view),
//   4. gather the GEMM results back into the original row order.
//
// input   [batch_size x in_feat]               (device, row-major)
// d_gate  [batch_size]                         (device) expert id per row
// weight  [num_expert x out_feat x in_feat]    (device, row-major)
// output  [batch_size x out_feat]              (device, row-major)
// transb  currently unused -- the GEMM below hardcodes CUBLAS_OP_T on the
//         weight.  NOTE(review): the (disabled) backward path passes
//         CUBLAS_OP_N expecting it to take effect; confirm before enabling.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* d_gate,
        const scalar_t* weight,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		cublasOperation_t transb) {

	// Scratch buffers holding the expert-sorted copies of input/output.
	scalar_t *input_buf, *output_buf;
	checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
				in_feat));
	checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
				out_feat));

	int *gate = new int[batch_size];
	int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
	memset(expert_count, 0, sizeof(int) * num_expert);

	checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	// Histogram rows per expert, then exclusive prefix sum to get each
	// expert's starting offset in the sorted buffers.
	for (int i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	// pos[i] = destination row of input row i in the scattered buffer.
	// Note this pass consumes expert_ptr (post-increment), so expert_ptr
	// is no longer a valid prefix sum afterwards.
	int *pos = new int[batch_size];
	int *d_pos;
	checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));

	for (int i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}
	checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));

	batch_scatter_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
				input_buf);
	checkCudaErrors(cudaGetLastError());
	smgr.sync(0);

	scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C
		checkCudaErrors(cublasXgemm(smgr.handles[i],
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[i], in_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_count[i];
	}

	// Fix: the GEMMs run on one handle per expert; wait for all of their
	// streams before the gather on stream 0 reads output_buf (same sync
	// pattern moe_cuda_grad_weight uses).  Previously only stream 0 was
	// synced, racing the gather against the per-expert GEMMs.
	for (size_t i = 0; i < num_expert; ++i) {
		checkCudaErrors(cudaStreamSynchronize(*(smgr.streams + i)));
	}

	batch_gather_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
				output);
	checkCudaErrors(cudaGetLastError());
	smgr.sync(0);

	cudaFree(input_buf);
	cudaFree(output_buf);
	cudaFree(d_pos);
	delete [] pos;
	delete [] gate;
	// Fix: expert_count and expert_ptr were previously leaked on every call.
	delete [] expert_count;
	delete [] expert_ptr;
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
142
143
144
145
146
147
148
149
150
// Accumulates per-expert weight gradients:
//   grad_weight[gate[i]] += outer(grad_output[i], input[i])
// via one rank-1 GEMM per sample (beta = 1 accumulates; the caller
// zero-initializes grad_weight).
//
// grad_weight layout is [num_expert x out_feat x in_feat] row-major, i.e.
// element (o, j) lives at o * in_feat + j -- matching how the forward pass
// reads `weight` (OP_T with lda = in_feat).
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {

    int* gate_host = new int[batch_size];
    scalar_t alpha = 1, beta = 1;
    checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    for (size_t i=0; i<batch_size; ++i) {
        // Fix: the previous call computed grad_output_i x T(input_i) with
        // ldc = out_feat, i.e. a column-major out_feat x in_feat matrix,
        // which is the TRANSPOSE of the row-major [out_feat x in_feat]
        // layout the forward pass reads `weight` with.  Compute
        // C (col-major in_feat x out_feat, ldc = in_feat)
        //   = input_i x T(grad_output_i)
        // instead: element (j, o) lands at o * in_feat + j = x[j] * dy[o],
        // exactly dW[o][j] in the forward layout.
        checkCudaErrors(cublasXgemm(smgr.handles[0],
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            in_feat,
            out_feat,
            1,
            &alpha,
            input + i * in_feat,
            in_feat,
            grad_output + i * out_feat,
            out_feat,
            &beta,
            grad_weight + gate_host[i] * out_feat * in_feat,
            in_feat));
    }
    // All GEMMs above use handle 0, but sync every managed stream to be
    // safe before the caller reads grad_weight.
    for (size_t i=0; i<num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(*(smgr.streams + i)));
    }
    delete[] gate_host;
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
178

Jiezhong Qiu's avatar
Jiezhong Qiu committed
179
// Torch-facing entry point for the MoE forward pass.
//   input  [batch_size x in_feat]
//   gate   [batch_size] (int expert id per row)
//   weight [num_expert x out_feat x in_feat]
// Returns {output} with output [batch_size x out_feat].
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight
		) {
    // Shapes are taken from the tensors themselves; weight fixes the
    // expert count and both feature dimensions.
    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    // Lazily create the per-expert streams/handles on first use.
    const int device = device_of(input).value().index();
    if (smgr.streams == NULL) {
        smgr.setup(num_expert, device);
    }

    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
                    out_feat,
                    num_expert,
					CUBLAS_OP_T
                );
    }));

    return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
215
// Torch-facing backward pass.  Computes grad_weight via
// moe_cuda_grad_weight; grad_input is allocated (zeros) but its
// computation is currently disabled -- see the commented block below.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x out_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    // Lazily create the per-expert streams/handles on first use.
    const int device = device_of(input).value().index();
    if (smgr.streams == NULL) {
        smgr.setup(num_expert, device);
    }

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    // grad_input would reuse the forward kernel (swapping in/out feature
    // dims and the weight transpose), but that path is known-broken and
    // stays disabled for now.
	/* TODO: Backward currently broken
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N
        );
    }));
	*/

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
270
271

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
272
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
273
274
275
276
277
278
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
279
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
280
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
281
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
282

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
283
284
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
285
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
286
287
288
289
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
290

Jiezhong Qiu's avatar
Jiezhong Qiu committed
291
292
293
294
295
296
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
297
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
298
299
300
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
301
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
302
303
304
305
306
307
308
309
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
310
}
Rick Ho's avatar
Rick Ho committed
311
*/