#include <torch/extension.h>
#include <torch/torch.h>
#include <cstdio>
#include <iostream>
#include <vector>


#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include "timer.hh"

#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"

#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

// #define MOE_BREAKDOWN
// #define MOE_DEBUG

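// Per-thread stream/handle manager. This file only relies on the members used
// below: smgr.streams[i], smgr.handles[i], smgr.sync(i) / smgr.sync(), and the
// ENSURE_SMGR(smgr, num_expert) macro from cuda_stream_manager.h, which
// presumably makes sure at least num_expert streams/handles are available.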
thread_local CudaStreamManager smgr;

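// Writes ptrs[i] = base + stride * offset[i]. Not referenced elsewhere in this
// file; presumably kept for building device pointer arrays for batched GEMM
// variants.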
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) { 
	size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
	if (idx < n) {
		ptrs[idx] = base + stride * offset[idx];
	}
}

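// One block per row: copies the wid-element row blockIdx.x of inbuf to row
// pos[blockIdx.x] of oubuf, i.e. scatters each input row to its expert-sorted
// slot.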
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	inbuf += wid * blockIdx.x;
	oubuf += wid * pos[blockIdx.x];
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

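// Counting-sort bookkeeping on the host: copies the gate (expert id per row)
// back from the device, counts rows per expert, prefix-sums the counts into
// per-expert base offsets, and assigns every row its destination index in the
// expert-sorted buffer (written back to d_pos on the device).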
void moe_cuda_expert_count_impl(
        const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
        const size_t batch_size) {
    int *gate = new int[batch_size];
	int *expert_ptr = new int[num_expert];
	memset(expert_count, 0, sizeof(int) * num_expert);

	checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	for (int i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	int *pos = new int[batch_size];

	for (int i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}
	checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));
	delete [] gate;
	delete [] expert_ptr;
	delete [] pos;

	ENSURE_SMGR(smgr, num_expert);
}

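// Reorders the rows of `input` into `input_buf` according to d_pos so that all
// rows routed to the same expert become contiguous.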
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
		const int* d_pos,
		scalar_t* input_buf,
		const size_t batch_size,
		const size_t in_feat) {
	batch_scatter_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
				input_buf); 
	smgr.sync(0);
}

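// Inverse of batch_scatter_kernel: copies row pos[blockIdx.x] of inbuf back to
// row blockIdx.x of oubuf, restoring the original row order.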
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	inbuf += wid * pos[blockIdx.x];
	oubuf += wid * blockIdx.x;
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
		const int* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat) {
	batch_gather_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
				output); 
	smgr.sync(0);
}

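// Per-expert GEMM. For expert i with cnt = expert_count[i] rows, computes the
// row-major product O = I * W_i^T, where I is the expert's contiguous
// [cnt x in_feat] slice of input_buf and W_i is its [out_feat x in_feat]
// weight. cuBLAS is column-major, so a row-major matrix is read as its
// transpose; the OP_T/OP_N flags below are chosen so that the column-major
// result written to output_buf is exactly the row-major [cnt x out_feat] O.
// Each expert's GEMM is issued on its own cuBLAS handle from smgr.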
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
		const int* expert_count,
        scalar_t* output_buf,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {
	scalar_t alpha = 1, beta = 0; 

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C
		checkCudaErrors(cublasXgemm(smgr.handles[i],
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[i], in_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_count[i];
	}
	smgr.sync();
}

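// Per-expert backward GEMMs. For each expert slice: grad_input = grad_output *
// W_i (a [cnt x in_feat] row-major result) and grad_W_i = grad_output^T * input
// (a [out_feat x in_feat] row-major result). Experts that received no rows get
// their weight gradient zeroed explicitly, since no GEMM writes it.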
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
		const scalar_t* weight,
		const int* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {
	ENSURE_SMGR(smgr, num_expert);
    scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			cudaMemset(grad_weight + i * in_feat * out_feat, 0, 
					sizeof(scalar_t) * in_feat * out_feat);
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C

		// Backward input: g_i = w @ g_o
		checkCudaErrors(cublasXgemm(smgr.handles[i],
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: g_w = i @ g_o
		checkCudaErrors(cublasXgemm(smgr.handles[i],
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		ptr += expert_count[i];
	}
	smgr.sync();
}


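// Host-side entry points. These follow the usual torch extension pattern:
// read shapes from the tensors, allocate outputs, and dispatch to the *_impl
// functions above (AT_DISPATCH_FLOATING_TYPES handles the scalar type where
// needed).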
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor weight,
		torch::Tensor gate) {
	const auto batch_size = gate.size(0);
	const auto num_expert = weight.size(0);

	auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
	auto expert_count = torch::empty(num_expert, ec_options);

	auto pos_options = torch::TensorOptions()
		.device(gate.device())
		.dtype(torch::kInt32);
	auto pos = torch::empty(batch_size, pos_options);
	moe_cuda_expert_count_impl(
			gate.data_ptr<int>(),
			expert_count.data_ptr<int>(),
			pos.data_ptr<int>(),
			num_expert,
			batch_size);

	return {expert_count, pos};
}

std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
	torch::Tensor pos) {
	const auto batch_size = input.size(0);
    const auto in_feat = input.size(1);

	auto input_buf = torch::empty_like(input);

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda", 
			([&] {
		moe_cuda_local_scatter_impl<scalar_t>(
			input.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			input_buf.data_ptr<scalar_t>(),
			batch_size,
			in_feat);
	}));
	return {input_buf,};
}

std::vector<torch::Tensor> moe_cuda_local_gather(
	torch::Tensor output_buf,
	torch::Tensor pos) {
	const auto batch_size = output_buf.size(0);
    const auto out_feat = output_buf.size(1);

	auto output = torch::empty_like(output_buf);

    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda", 
			([&] {
		moe_cuda_local_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			batch_size,
			out_feat);
	}));
	return {output,};
}

std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
		torch::Tensor expert_count
		) {
	const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);
            
#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 
			num_expert, in_feat, out_feat);
#endif
	/*
    const int device = device_of(input).value().index();
    if (smgr.streams == NULL) {
        smgr.setup(num_expert, device);
    }
	*/
	auto out_options = torch::TensorOptions()
		.device(input_buf.device())
		.dtype(input_buf.dtype());
    auto output = torch::empty({batch_size, out_feat}, out_options);
    
    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda", 
			([&] {
		moe_cuda_forward_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			in_feat,
			out_feat,
			num_expert
		);
    }));
    
    return {output, };           
}

std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, // [batch_size x out_feat]
    torch::Tensor input_buf, // [batch_size x in_feat]
    torch::Tensor weight, // [num_expert x out_feat x in_feat]
	torch::Tensor expert_count
) {
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat}); 
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input_buf, grad_weight};
}


/*
int main() {
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
	data_t *input, *weight;
	data_t *output;
	size_t *gate;

	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;

    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
}
*/
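
/*
 * Illustrative only: a minimal pybind11 binding sketch for the entry points in
 * this file, assuming the real bindings live in a separate translation unit
 * (so this stays commented out, like the benchmark main() above). The Python
 * names below are hypothetical; the intended call order on the Python side
 * would be expert_count -> local_scatter -> forward -> local_gather, with
 * backward driving the autograd pass.

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("expert_count", &moe_cuda_expert_count,
          "count rows per expert and compute scatter positions");
    m.def("local_scatter", &moe_cuda_local_scatter,
          "reorder rows so each expert's rows are contiguous");
    m.def("local_gather", &moe_cuda_local_gather,
          "restore the original row order");
    m.def("forward", &moe_cuda_forward,
          "per-expert GEMM forward");
    m.def("backward", &moe_cuda_backward,
          "per-expert GEMM backward");
}
*/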