#include <torch/extension.h>
#include <torch/torch.h>

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include "timer.hh"
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"

#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

// #define MOE_BREAKDOWN
// #define MOE_DEBUG

// One stream manager per host thread: owns the CUDA streams and cuBLAS
// handles used by the launchers below.
thread_local CudaStreamManager smgr;

// Fills ptrs[i] = base + stride * offset[i] for i in [0, n).
// Launch with any 1-D grid covering at least n threads; surplus threads
// return immediately.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	const size_t tid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
	if (tid >= n) {
		return;
	}
	ptrs[tid] = base + stride * offset[tid];
}

// Row scatter: copies row blockIdx.x of inbuf (row width `wid`) into row
// pos[blockIdx.x] of oubuf.  One thread block per row; threads stride across
// the row's elements.
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) {
	const scalar_t* src = inbuf + wid * blockIdx.x;
	scalar_t* dst = oubuf + wid * pos[blockIdx.x];
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		dst[i] = src[i];
	}
}

// Counts, on the host, how many rows of the batch are routed to each expert
// and computes a stable permutation that groups rows by expert.
//
// d_gate        device pointer, batch_size ints: expert id of each row.
// expert_count  HOST pointer, num_expert ints: filled with per-expert counts.
// d_pos         device pointer, batch_size ints: filled with each row's
//               destination slot in the expert-sorted buffer.
//
// Note: the gate is round-tripped through the host (two blocking cudaMemcpy
// calls), so this synchronizes with the device.
void moe_cuda_expert_count_impl(
        const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
        const size_t batch_size) {
	// std::vector instead of bare new[]: the original code leaked `pos`
	// (allocated with new[] but never delete[]'d), and wrote expert_ptr[0]
	// even when num_expert == 0.
	std::vector<int> gate(batch_size);
	std::vector<int> expert_ptr(num_expert, 0);
	std::vector<int> pos(batch_size);
	memset(expert_count, 0, sizeof(int) * num_expert);

	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	// Histogram: rows per expert.
	for (size_t i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	// Exclusive prefix sum: expert_ptr[e] = first slot owned by expert e.
	for (size_t i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	// Assign each row its slot; post-incrementing the per-expert cursor
	// keeps the permutation stable (original order within an expert).
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}
	checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));

	// Make sure the stream manager has at least num_expert streams/handles.
	ENSURE_SMGR(smgr, num_expert);
}

// Reorders `input` (batch_size x in_feat, row-major) into `input_buf` so that
// input row i lands at row d_pos[i] of the buffer.  One block of 256 threads
// per row, launched on the manager's stream 0 and synchronized before return.
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
		const int* d_pos,
		scalar_t* input_buf,
		const size_t batch_size,
		const size_t in_feat) {
	dim3 grid(batch_size);
	dim3 block(256);
	batch_scatter_kernel<scalar_t><<<grid, block, 0, smgr.streams[0]>>>(
			in_feat, d_pos, input, input_buf);
	smgr.sync(0);
}

// Row gather: copies row pos[blockIdx.x] of inbuf (row width `wid`) back to
// row blockIdx.x of oubuf — the inverse of batch_scatter_kernel.
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) {
	const scalar_t* src = inbuf + wid * pos[blockIdx.x];
	scalar_t* dst = oubuf + wid * blockIdx.x;
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		dst[i] = src[i];
	}
}

// Undoes moe_cuda_local_scatter_impl: row d_pos[i] of `output_buf`
// (batch_size x out_feat, row-major) is copied back to row i of `output`.
// One block of 256 threads per row on stream 0, synchronized before return.
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
		const int* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat) {
	dim3 grid(batch_size);
	dim3 block(256);
	batch_gather_kernel<scalar_t><<<grid, block, 0, smgr.streams[0]>>>(
			out_feat, d_pos, output_buf, output);
	smgr.sync(0);
}

// Per-expert GEMM: for each expert i with c = expert_count[i] rows, computes
// the row-major product  output_rows = input_rows * op(W_i)  where
// W_i = weight + i * in_feat * out_feat and rows of input_buf are already
// grouped by expert (see moe_cuda_local_scatter_impl).  Each expert's GEMM
// runs on its own cuBLAS handle/stream; all streams are synced before return.
//
// expert_count is a HOST pointer (filled by moe_cuda_expert_count_impl).
//
// Fix: `transb` was accepted but ignored (the weight op was hard-coded to
// CUBLAS_OP_T).  It is now honored, with the leading dimension chosen to
// match; existing callers pass CUBLAS_OP_T and see identical behavior.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
		const int* expert_count,
        scalar_t* output_buf,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert, 
		cublasOperation_t transb) {
	scalar_t alpha = 1, beta = 0;
	// weight is stored row-major [out_feat x in_feat] per expert, which
	// cuBLAS (column-major) sees as [in_feat x out_feat]: OP_T therefore
	// uses lda = in_feat, while OP_N needs lda = out_feat.
	const int ldw = (transb == CUBLAS_OP_T) ? (int)in_feat : (int)out_feat;
	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C.
		checkCudaErrors(cublasXgemm(smgr.handles[i],
				transb,
				CUBLAS_OP_N,
				out_feat, expert_count[i], in_feat,
				&alpha,
				weight + i * in_feat * out_feat, ldw,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));
		ptr += expert_count[i];
	}
	smgr.sync();
}

// Accumulates per-expert weight gradients for a layer whose forward pass is
// out = in * W^T with W row-major [out_feat x in_feat], one W per expert.
//
// For every sample i the rank-1 update  dW[gate[i]] += grad_output_i x input_i^T
// is issued as a k=1 GEMM on the cuBLAS handle of the sample's expert, so
// updates to different experts can overlap while updates to the same expert
// stay ordered on that expert's stream.  grad_weight must be zero-initialized
// by the caller (beta == 1 accumulates into it).
//
// NOTE(review): the original issued every GEMM on handles[0] (yet synced all
// num_expert streams) and wrote dW transposed — column-major
// [out_feat x in_feat], i.e. row-major [in_feat x out_feat] — into a buffer
// declared [num_expert x out_feat x in_feat].  Both are corrected below;
// confirm against the Python-side consumer of grad_weight.
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {
    std::vector<int> gate_host(batch_size);
    scalar_t alpha = 1, beta = 1;
    checkCudaErrors(cudaMemcpy(gate_host.data(), gate,
                batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    for (size_t i = 0; i < batch_size; ++i) {
        // Row-major dW [out_feat x in_feat] is column-major
        // [in_feat x out_feat]: compute input_i (in_feat x 1) times
        // grad_output_i^T (1 x out_feat) with ldc = in_feat.
        checkCudaErrors(cublasXgemm(smgr.handles[gate_host[i]],
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            in_feat,
            out_feat,
            1,
            &alpha,
            input + i * in_feat,
            in_feat,
            grad_output + i * out_feat,
            out_feat,
            &beta,
            grad_weight + gate_host[i] * out_feat * in_feat,
            in_feat));
    }
    // Wait for every expert's stream: updates were spread across handles.
    for (size_t i = 0; i < num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(smgr.streams[i]));
    }
}
// Torch entry point: given gate[batch] (int32, on device) and weight (only
// its first dimension, num_expert, is read), returns
//   {expert_count (int32, CPU), pos (int32, same device as gate)}.
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor weight,
		torch::Tensor gate) {
	const auto num_expert = weight.size(0);
	const auto batch_size = gate.size(0);

	// expert_count stays on the host: the impl fills it with plain CPU writes.
	auto expert_count = torch::empty(num_expert,
			torch::TensorOptions().dtype(torch::kInt32));
	// pos lives next to the gate tensor (device memory).
	auto pos = torch::empty(batch_size, torch::TensorOptions()
			.dtype(torch::kInt32)
			.device(gate.device()));

	moe_cuda_expert_count_impl(
			gate.data_ptr<int>(),
			expert_count.data_ptr<int>(),
			pos.data_ptr<int>(),
			num_expert,
			batch_size);

	return {expert_count, pos};
}

// Torch entry point: permutes the rows of `input` into expert-grouped order
// using the slot map `pos` produced by moe_cuda_expert_count.
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
	torch::Tensor pos) {
	const auto n_rows = input.size(0);
	const auto row_width = input.size(1);

	auto scattered = torch::empty_like(input);

	AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda", 
			([&] {
		moe_cuda_local_scatter_impl<scalar_t>(
			input.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			scattered.data_ptr<scalar_t>(),
			n_rows,
			row_width);
	}));
	return {scattered};
}

// Torch entry point: inverse of moe_cuda_local_scatter — moves row pos[i] of
// `output_buf` back to row i of the result.
std::vector<torch::Tensor> moe_cuda_local_gather(
	torch::Tensor output_buf,
	torch::Tensor pos) {
	const auto n_rows = output_buf.size(0);
	const auto row_width = output_buf.size(1);

	auto gathered = torch::empty_like(output_buf);

	AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda", 
			([&] {
		moe_cuda_local_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			gathered.data_ptr<scalar_t>(),
			n_rows,
			row_width);
	}));
	return {gathered};
}

// Torch entry point for the expert-parallel linear layer.
// input_buf:    [batch_size x in_feat], rows already grouped by expert.
// weight:       [num_expert x out_feat x in_feat].
// expert_count: [num_expert], int32, host-resident per-expert row counts.
// Returns {output}, output shape [batch_size x out_feat].
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
		torch::Tensor expert_count
		) {
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 
			num_expert, in_feat, out_feat);
#endif

	auto output = torch::empty({batch_size, out_feat},
			torch::TensorOptions()
				.device(input_buf.device())
				.dtype(input_buf.dtype()));

	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda", 
			([&] {
		moe_cuda_forward_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			in_feat,
			out_feat,
			num_expert,
			CUBLAS_OP_T
		);
	}));

	return {output};
}

// Torch entry point for the backward pass.
// grad_output: [batch_size x out_feat]
// input:       [batch_size x in_feat]
// gate:        [batch_size], int32 expert assignment per row
// weight:      [num_expert x out_feat x in_feat]
// Returns {grad_input, grad_weight}.
//
// NOTE: the grad_input computation is currently disabled (the original marks
// it "Backward currently broken"), so grad_input is returned as zeros.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output,
    torch::Tensor input,
    torch::Tensor gate,
    torch::Tensor weight
) {
    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    // Lazily create one stream + cuBLAS handle per expert on first use.
    const int device = device_of(input).value().index();
    if (smgr.streams == NULL) {
        smgr.setup(num_expert, device);
    }

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat});

    /* TODO: grad_input is not computed yet.  The original dispatch through
       moe_cuda_forward_impl (with swapped feature dims and CUBLAS_OP_N) is
       disabled because backward is known to be broken. */

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}

/*
int main() {
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
	data_t *input, *weight;
	data_t *output;
	size_t *gate;

	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));

    size_t nt = 16;
    double tsum = 0, tmax = 0;

    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    }
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);

    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
}
*/