#include "moe_cuda_kernel.h"

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include "timer.hh"

#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"

#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
// For each of the n entries, store into ptrs[idx] a pointer into `base`
// advanced by stride * offset[idx]. Presumably used to build per-matrix
// pointer arrays for batched BLAS calls — not referenced elsewhere in this
// file. One thread per entry; out-of-range threads are guarded.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
	if (idx < n) {
		ptrs[idx] = base + stride * offset[idx];
	}
}

// Scatter rows into permuted positions: source row blockIdx.x of `inbuf` is
// copied to row pos[blockIdx.x] of `oubuf`. `wid` is the row width in
// elements. Launch with one thread block per row; threads in the block
// stride across the row's elements.
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	inbuf += wid * blockIdx.x;
	oubuf += wid * pos[blockIdx.x];
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

// Count how many samples are routed to each expert and compute, for every
// sample, its destination index in the expert-sorted buffer.
//
// d_gate:       device array [batch_size], expert index chosen per sample
// expert_count: host array [num_expert], filled with per-expert sample counts
// d_pos:        device array [batch_size], filled with each sample's slot in
//               the expert-grouped ordering
//
// The histogram and prefix sums are done on the host after copying the gate
// values back; only the final positions are copied to the device.
void moe_cuda_expert_count_impl(
        const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
        const size_t batch_size) {
	int *gate = new int[batch_size];
	int *expert_ptr = new int[num_expert];
	memset(expert_count, 0, sizeof(int) * num_expert);

	checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	// Histogram: samples per expert.
	for (int i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	// Exclusive prefix sum: start offset of each expert's segment.
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	// Assign each sample a unique slot inside its expert's segment,
	// consuming the running offsets.
	int *pos = new int[batch_size];
	for (int i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}
	// Restore the segment start offsets (undo the consumption above by
	// shifting right one expert).
	for (int i = num_expert - 1; i > 0; --i) {
		expert_ptr[i] = expert_ptr[i - 1];
	}
	expert_ptr[0] = 0;
	checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));

	delete [] gate;
	delete [] expert_ptr;
	delete [] pos;  // fix: `pos` was previously leaked
}

// Reorder `input` rows into expert-grouped order: row b of `input` is written
// to row d_pos[b] of `input_buf`. Launches one 256-thread block per row on
// the manager's stream 0 and waits for completion before returning.
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
		const int* d_pos,
		scalar_t* input_buf,
		const long batch_size,
		const long in_feat,
		CudaStreamManager* smgr) {
	batch_scatter_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
				input_buf);
	smgr->sync(1);
}

// Inverse of batch_scatter_kernel: destination row blockIdx.x of `oubuf` is
// filled from row pos[blockIdx.x] of `inbuf`. `wid` is the row width in
// elements; one thread block per row, threads stride across the row.
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	inbuf += wid * pos[blockIdx.x];
	oubuf += wid * blockIdx.x;
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

// Undo the expert-grouped ordering: row b of `output` is read from row
// d_pos[b] of `output_buf`. Launches one 256-thread block per row on the
// manager's stream 0 and waits for completion before returning.
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
		const int* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat,
		CudaStreamManager* smgr) {
	batch_gather_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
				output);
	smgr->sync(1);
}

// Forward pass of the per-expert linear layer: for each expert i, multiply
// its contiguous slice of `input_buf` (rows grouped by expert, see
// moe_cuda_expert_count_impl / local_scatter) by weight[i].
//
// weight is row-major [num_expert, out_feat, in_feat]; expert_count is a
// host array. Each expert's GEMM is issued on its own cuBLAS handle/stream
// via smgr, then all num_expert streams are synchronized.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
		const int* expert_count,
        scalar_t* output_buf,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
	scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// No rows routed to this expert; nothing to compute.
			continue;
		}
		// cuBLAS is column-major: use T(B) x T(A) = T(C) to produce
		// row-major C, i.e. output = input @ weight^T per expert.
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[i], in_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}

// Backward pass of the per-expert linear layer. For each expert i over its
// contiguous row slice:
//   grad_input = grad_output @ weight[i]
//   grad_weight[i] = grad_output^T-style outer accumulation with input
// Experts with zero routed rows get a zeroed weight gradient. Each expert's
// GEMM pair runs on its own cuBLAS handle/stream; all streams are
// synchronized at the end. expert_count is a host array.
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
		const scalar_t* weight,
		const int* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
	scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// No rows for this expert: its weight gradient is zero.
			// fix: result was previously unchecked.
			checkCudaErrors(cudaMemset(
					grad_weight + i * in_feat * out_feat, 0,
					sizeof(scalar_t) * in_feat * out_feat));
			continue;
		}
		// cuBLAS is column-major: use T(B) x T(A) = T(C) to produce
		// row-major results.

		// Backward input: g_i = w @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: g_w = i @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}


// Torch-facing wrapper: given the per-sample expert assignment `gate`
// (int32, [batch_size], on device), return
//   {expert_count (int32, CPU, [num_expert]),
//    pos          (int32, same device as gate, [batch_size])}.
// expert_count is intentionally a CPU tensor — the GEMM loops read it from
// host memory.
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor gate, 
		size_t num_expert) {
	const auto batch_size = gate.size(0);

	auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
	auto expert_count = torch::empty(num_expert, ec_options);

	auto pos_options = torch::TensorOptions()
		.device(gate.device())
		.dtype(torch::kInt32);
	auto pos = torch::empty(batch_size, pos_options);
	moe_cuda_expert_count_impl(
			gate.data_ptr<int>(),
			expert_count.data_ptr<int>(),
			pos.data_ptr<int>(),
			num_expert,
			batch_size);

	return {expert_count, pos};
}


// Torch-facing wrapper: reorder `input` ([batch_size, in_feat]) rows into
// expert-grouped order using positions `pos` (int32, [batch_size]).
// Returns {input_buf} with the same shape/dtype/device as `input`.
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
	torch::Tensor pos) {
	auto smgr = getCudaStreamManager(input.device().index());
	const auto batch_size = input.size(0);
	const auto in_feat = input.size(1);

	auto input_buf = torch::empty_like(input);

	AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda",
			([&] {
		moe_cuda_local_scatter_impl<scalar_t>(
			input.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			input_buf.data_ptr<scalar_t>(),
			batch_size,
			in_feat,
			smgr);
	}));
	return {input_buf,};
}

// Torch-facing wrapper: inverse of moe_cuda_local_scatter. Restores the
// original row order of `output_buf` ([batch_size, out_feat]) using the
// positions `pos` (int32, [batch_size]). Returns {output}.
std::vector<torch::Tensor> moe_cuda_local_gather(
	torch::Tensor output_buf,
	torch::Tensor pos) {
	auto smgr = getCudaStreamManager(output_buf.device().index());
	const auto batch_size = output_buf.size(0);
	const auto out_feat = output_buf.size(1);

	auto output = torch::empty_like(output_buf);

	AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda",
			([&] {
		moe_cuda_local_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			batch_size,
			out_feat,
			smgr);
	}));
	return {output,};
}

// Torch-facing forward: per-expert linear layer.
//   input_buf:    [batch_size, in_feat], rows grouped by expert
//   weight:       [num_expert, out_feat, in_feat]
//   expert_count: int32 CPU tensor [num_expert]
// Returns {output} of shape [batch_size, out_feat] on input_buf's device.
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
		torch::Tensor expert_count
		) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
	printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
			num_expert, in_feat, out_feat);
#endif
	auto out_options = torch::TensorOptions()
		.device(input_buf.device())
		.dtype(input_buf.dtype());
	auto output = torch::empty({batch_size, out_feat}, out_options);

	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda",
			([&] {
		moe_cuda_forward_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			in_feat,
			out_feat,
			num_expert,
			smgr
		);
	}));

	return {output, };
}

// Torch-facing backward: gradients of the per-expert linear layer.
//   grad_output_buf: [batch_size, out_feat], rows grouped by expert
//   input_buf:       [batch_size, in_feat]  (comment fixed: was wrongly
//                    documented as out_feat)
//   weight:          [num_expert, out_feat, in_feat]
//   expert_count:    int32 CPU tensor [num_expert]
// Returns {grad_input_buf [batch_size, in_feat],
//          grad_weight   [num_expert, out_feat, in_feat]}.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, // [batch_size x out_feat]
    torch::Tensor input_buf, // [batch_size x in_feat]
    torch::Tensor weight, // [num_expert x out_feat x in_feat]
	torch::Tensor expert_count
) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
	printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

	auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
	auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
		moe_cuda_backward_impl<scalar_t>(
			grad_output_buf.data_ptr<scalar_t>(),
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
			grad_input_buf.data_ptr<scalar_t>(),
			grad_weight.data_ptr<scalar_t>(),
			batch_size,
			in_feat,
			out_feat,
			num_expert,
			smgr
		);
	}));

	return {grad_input_buf, grad_weight};
}

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
350
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
351
352
353
354
355
356
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
357
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
358
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
359
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
360

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
361
362
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
363
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
364
365
366
367
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
368

Jiezhong Qiu's avatar
Jiezhong Qiu committed
369
370
371
372
373
374
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
375
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
376
377
378
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
379
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
380
381
382
383
384
385
386
387
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
388
}
Rick Ho's avatar
Rick Ho committed
389
*/