// moe_cuda_kernel.cu
#include <torch/extension.h>
#include <torch/torch.h>
#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
#include <c10/cuda/CUDAGuard.h>

#include "timer.hh"

#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"

// Ceiling integer division: number of chunks of size _y_ needed to cover _x_.
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

// #define MOE_BREAKDOWN
// #define MOE_DEBUG

// Process-wide stream/cuBLAS-handle manager shared by every entry point below.
// thread_local CudaStreamManager smgr;
// TODO: handle stream manager faults with torch threads
CudaStreamManager smgr;
// Fills an array of device pointers: ptrs[idx] = base + stride * offset[idx].
// Expects a 1-D launch; each thread writes one pointer, n is the pointer count.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) { 
	size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
	if (idx < n) {  // guard the grid tail
		ptrs[idx] = base + stride * offset[idx];
	}
}

// Row scatter: copies row blockIdx.x of inbuf into row pos[blockIdx.x] of
// oubuf. wid is the row width in elements. Launch with one block per row;
// the block's threads stride across the row's columns.
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	inbuf += wid * blockIdx.x;
	oubuf += wid * pos[blockIdx.x];
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

// Counts how many rows are routed to each expert, and computes for every row
// its destination slot in the expert-sorted buffer.
//   d_gate        device array [batch_size]: expert index chosen for each row
//   expert_count  host array  [num_expert]: filled with per-expert row counts
//   d_pos         device array [batch_size]: filled with scatter positions
// Blocking: performs synchronous D2H/H2D copies of the gate/pos arrays.
// Also ensures the stream manager has at least num_expert streams/handles.
void moe_cuda_expert_count_impl(
        const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
        const size_t batch_size) {
	// std::vector instead of bare new[]: the original leaked `pos` (it was
	// never delete[]d) and leaked everything if a checkCudaErrors aborted.
	std::vector<int> gate(batch_size);
	std::vector<int> expert_ptr(num_expert);
	memset(expert_count, 0, sizeof(int) * num_expert);

	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	for (size_t i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	// Exclusive prefix sum: expert_ptr[e] = first slot owned by expert e.
	expert_ptr[0] = 0;
	for (size_t i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	std::vector<int> pos(batch_size);
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;  // stable order within each expert
	}
	checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));

	ENSURE_SMGR(smgr, num_expert);
}

// Reorders rows of `input` into expert-sorted order: row i of input is copied
// to row d_pos[i] of input_buf. Synchronous — waits on stream 0 before return.
//   input      [batch_size x in_feat] device buffer
//   d_pos      device array [batch_size] of destination row indices
//   input_buf  [batch_size x in_feat] device buffer receiving scattered rows
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
		const int* d_pos,
		scalar_t* input_buf,
		const size_t batch_size,
		const size_t in_feat) {
	// One block per row; 256 threads stride across the in_feat columns.
	batch_scatter_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr.stream(0)>>>(in_feat, d_pos, input,
				input_buf); 
	smgr.sync(0);
}

// Row gather (inverse of batch_scatter_kernel): copies row pos[blockIdx.x] of
// inbuf into row blockIdx.x of oubuf. wid is the row width in elements; launch
// one block per output row, threads stride across the columns.
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	const scalar_t* src = inbuf + wid * pos[blockIdx.x];
	scalar_t* dst = oubuf + wid * blockIdx.x;
	for (int col = threadIdx.x; col < wid; col += blockDim.x) {
		dst[col] = src[col];
	}
}

// Inverse of moe_cuda_local_scatter_impl: row d_pos[i] of output_buf is copied
// back to row i of output, restoring the original row order. Synchronous —
// waits on stream 0 before returning.
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
		const int* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat) {
	// One block per row; 256 threads stride across the out_feat columns.
	batch_gather_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr.stream(0)>>>(out_feat, d_pos, output_buf,
				output); 
	smgr.sync(0);
}

// Per-expert forward GEMMs. For each expert i with c = expert_count[i] rows,
// multiplies the contiguous slice input_buf[ptr : ptr+c] by weight[i] into
// output_buf[ptr : ptr+c], issuing each expert's GEMM on its own cuBLAS
// handle, then waits for all of them.
//   input_buf    device buffer, rows grouped by expert, row width in_feat
//   weight       device buffer [num_expert x out_feat x in_feat], row-major
//   expert_count HOST array [num_expert] (dereferenced on the CPU below)
//   output_buf   device buffer, row width out_feat
// NOTE(review): assumes smgr already holds >= num_expert handles — set up by
// moe_cuda_expert_count_impl's ENSURE_SMGR; confirm the call order upstream.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
		const int* expert_count,
        scalar_t* output_buf,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {
	scalar_t alpha = 1, beta = 0; 

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			continue;  // no rows routed to this expert, nothing to compute
		}
		// cuBLAS is column-major: use T(B) x T(A) = T(C) to produce row-major C
		checkCudaErrors(cublasXgemm(smgr.handles[i],
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[i], in_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_count[i];  // advance past this expert's rows
	}
	smgr.sync();  // wait for every expert's GEMM before returning
}

// Per-expert backward GEMMs. For each expert i with c = expert_count[i] rows,
// computes the gradient w.r.t. the input slice and w.r.t. the expert's weight
// on that expert's own cuBLAS handle, then waits for all of them. Experts
// that received no rows get a zeroed grad_weight slice.
//   grad_output_buf device buffer, rows grouped by expert, row width out_feat
//   input_buf       device buffer, rows grouped by expert, row width in_feat
//   weight          device buffer [num_expert x out_feat x in_feat]
//   expert_count    HOST array [num_expert] (dereferenced on the CPU below)
//   grad_input_buf  device output, row width in_feat
//   grad_weight     device output [num_expert x out_feat x in_feat]
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
		const scalar_t* weight,
		const int* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {
	ENSURE_SMGR(smgr, num_expert);
    scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// No rows for this expert: its weight gradient is exactly zero.
			// checkCudaErrors added — the original silently ignored the
			// cudaMemset return code, unlike every other CUDA call here.
			checkCudaErrors(cudaMemset(grad_weight + i * in_feat * out_feat, 0, 
					sizeof(scalar_t) * in_feat * out_feat));
			continue;
		}
		// cuBLAS is column-major: use T(B) x T(A) = T(C) to produce row-major C

		// Backward input: g_i = w @ g_o
		checkCudaErrors(cublasXgemm(smgr.handles[i],
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: g_w = i @ g_o
		checkCudaErrors(cublasXgemm(smgr.handles[i],
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		ptr += expert_count[i];  // advance past this expert's rows
	}
	smgr.sync();  // wait for every expert's GEMMs before returning
}
// Host wrapper: computes per-expert row counts (CPU int32 tensor) and per-row
// scatter positions (int32 tensor on gate's device). weight is only used for
// its first dimension (num_expert). Returns {expert_count, pos}.
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor weight,
		torch::Tensor gate) {
	const auto num_expert = weight.size(0);
	const auto batch_size = gate.size(0);

	// expert_count stays on the host: the impl fills it with CPU writes.
	auto expert_count = torch::empty(
			num_expert, torch::TensorOptions().dtype(torch::kInt32));

	// pos lives on the same device as gate.
	auto pos = torch::empty(
			batch_size,
			torch::TensorOptions().device(gate.device()).dtype(torch::kInt32));

	moe_cuda_expert_count_impl(
			gate.data_ptr<int>(),
			expert_count.data_ptr<int>(),
			pos.data_ptr<int>(),
			num_expert,
			batch_size);

	return {expert_count, pos};
}

// Host wrapper around moe_cuda_local_scatter_impl: returns a tensor whose row
// pos[i] is row i of `input`, i.e. the rows regrouped into expert order.
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
	torch::Tensor pos) {
	const auto batch_size = input.size(0);
    const auto in_feat = input.size(1);

	// Same shape/dtype/device as the input; rows are filled by the kernel.
	auto input_buf = torch::empty_like(input);

    AT_DISPATCH_FLOATING_TYPES(
			input.scalar_type(), "moe_local_scatter_cuda", ([&] {
		moe_cuda_local_scatter_impl<scalar_t>(
			input.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			input_buf.data_ptr<scalar_t>(),
			batch_size,
			in_feat);
	}));
	return {input_buf,};
}

// Host wrapper around moe_cuda_local_gather_impl: undoes local_scatter by
// copying row pos[i] of output_buf back to row i of the returned tensor.
std::vector<torch::Tensor> moe_cuda_local_gather(
	torch::Tensor output_buf,
	torch::Tensor pos) {
	const auto batch_size = output_buf.size(0);
    const auto out_feat = output_buf.size(1);

	// Same shape/dtype/device as output_buf; rows are filled by the kernel.
	auto output = torch::empty_like(output_buf);

    AT_DISPATCH_FLOATING_TYPES(
			output_buf.scalar_type(), "moe_local_gather_cuda", ([&] {
		moe_cuda_local_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			batch_size,
			out_feat);
	}));
	return {output,};
}

// Host wrapper for the per-expert forward GEMMs.
//   input_buf    [batch_size x in_feat], rows already grouped by expert
//   weight       [num_expert x out_feat x in_feat]
//   expert_count int tensor readable on the host (the impl dereferences it
//                on the CPU) — produced by moe_cuda_expert_count
// Returns {output} with shape [batch_size x out_feat].
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
		torch::Tensor expert_count
		) {
	const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 
			num_expert, in_feat, out_feat);
#endif
	/*
    const int device = device_of(input).value().index();
    if (smgr.streams == NULL) {
        smgr.setup(num_expert, device);
    }
	*/
	// Output lives on the same device, with the same dtype, as the input.
	auto out_options = torch::TensorOptions()
		.device(input_buf.device())
		.dtype(input_buf.dtype());
    auto output = torch::empty({batch_size, out_feat}, out_options);

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda", 
			([&] {
		moe_cuda_forward_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			in_feat,
			out_feat,
			num_expert
		);
    }));
    
    return {output, };           
}

// Host wrapper for the per-expert backward GEMMs.
// Returns {grad_input_buf [batch_size x in_feat],
//          grad_weight    [num_expert x out_feat x in_feat]}.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, // [batch_size x out_feat]
    // Comment fixed: the original said [batch_size x out_feat], but the impl
    // indexes input_buf by in_feat-sized rows.
    torch::Tensor input_buf, // [batch_size x in_feat]
    torch::Tensor weight, // [num_expert x out_feat x in_feat]
	torch::Tensor expert_count // int tensor, dereferenced on the host
) {
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

    // Gradients inherit device/dtype from grad_output_buf via new_empty.
    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat}); 
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input_buf, grad_weight};
}


/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
350
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
351
352
353
354
355
356
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
357
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
358
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
359
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
360

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
361
362
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
363
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
364
365
366
367
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
368

Jiezhong Qiu's avatar
Jiezhong Qiu committed
369
370
371
372
373
374
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
375
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
376
377
378
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
379
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
380
381
382
383
384
385
386
387
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
388
}
Rick Ho's avatar
Rick Ho committed
389
*/