moe_compute_kernel.cu 10.7 KB
Newer Older
1
2
#include "moe_cuda_kernel.h"

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <vector>

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
7
8
#include <cuda.h>
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
9
#include <cublas_v2.h>
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
10
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
Jiezhong Qiu committed
11
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
12

Rick Ho's avatar
Rick Ho committed
13
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
14

Rick Ho's avatar
Rick Ho committed
15
16
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
17

Rick Ho's avatar
Rick Ho committed
18
// Ceiling division for non-negative integers: the smallest q with
// q * _y_ >= _x_ (e.g. number of blocks needed to cover _x_ items).
// Unlike the (_x_ - 1) / _y_ + 1 form, this yields 0 for _x_ == 0 instead of
// 1 (signed operands) or a wrapped-around huge value (unsigned operands).
// _y_ must be positive; arguments may be evaluated more than once.
#define CEIL(_x_,_y_) ((_x_) / (_y_) + (((_x_) % (_y_)) != 0))
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
20
21
/*
 * Pointer-table builder: ptrs[k] = base + stride * offset[k] for k in [0, n).
 *
 * Launch: 1-D grid, one thread per entry; threads whose flat index is >= n
 * return immediately, so the grid may be rounded up.
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const long* offset, const scalar_t** ptrs) {
	const size_t k = blockIdx.x * blockDim.x + threadIdx.x;
	if (k >= n) {
		return;
	}
	ptrs[k] = base + stride * offset[k];
}

30
31
/*
 * Row scatter: copies row blockIdx.x of `inbuf` into row pos[blockIdx.x] of
 * `oubuf`; every row is `wid` contiguous elements. pos is not bounds-checked.
 *
 * Launch: one block per source row; the block's threads stride across the
 * row's columns, so any blockDim.x is valid.
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	const scalar_t* src = inbuf + wid * blockIdx.x;
	scalar_t* dst = oubuf + wid * pos[blockIdx.x];
	for (int col = threadIdx.x; col < wid; col += blockDim.x) {
		dst[col] = src[col];
	}
}

Rick Ho's avatar
Rick Ho committed
41
/*
 * Count the samples routed to each expert, and compute for every sample its
 * row in a buffer where samples are grouped by expert (all of expert 0 first,
 * then expert 1, ...; samples of the same expert keep their original order).
 *
 * d_gate       : device array of batch_size expert indices, one per sample.
 * expert_count : host array of num_expert ints; overwritten with the
 *                per-expert sample counts.
 * d_pos        : device array of batch_size ints; d_pos[i] receives the
 *                grouped row of sample i.
 *
 * Both transfers use blocking cudaMemcpy, so results are ready on return.
 * Throws std::out_of_range if a gate value lies outside [0, num_expert);
 * such a value would otherwise index outside expert_count.
 */
void moe_cuda_expert_count_impl(
		const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
		const size_t batch_size) {
	// fill_n is well-defined for num_expert == 0 (unlike memset on a null ptr).
	std::fill_n(expert_count, num_expert, 0);
	if (batch_size == 0) {
		return;
	}

	// The bucketing below is serial, so bring the routing decisions to the host.
	std::vector<int> gate(batch_size);
	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	// Histogram of samples per expert, validating each index first.
	for (size_t i = 0; i < batch_size; ++i) {
		const int g = gate[i];
		if (g < 0 || static_cast<size_t>(g) >= num_expert) {
			throw std::out_of_range(
					"moe_cuda_expert_count: gate value out of range");
		}
		++expert_count[g];
	}

	// Exclusive prefix sum: next_slot[e] is the next free row for expert e.
	std::vector<int> next_slot(num_expert, 0);
	for (size_t e = 1; e < num_expert; ++e) {
		next_slot[e] = next_slot[e - 1] + expert_count[e - 1];
	}

	// Stable counting-sort placement of each sample.
	std::vector<int> pos(batch_size);
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = next_slot[gate[i]]++;
	}

	checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));
}
76

Rick Ho's avatar
Rick Ho committed
77
78
79
/*
 * Scatter rows of `input` into `input_buf` according to `d_pos`:
 * input_buf[d_pos[r]] = input[r] for r in [0, batch_size), each row holding
 * `in_feat` contiguous elements.
 *
 * Launches batch_scatter_kernel (one block per row, 256 threads) on
 * smgr->stream(0), then calls smgr->sync(1) -- presumably waiting on that
 * stream; confirm against CudaStreamManager.
 */
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
		const scalar_t* input,
		const long* d_pos,
		scalar_t* input_buf,
		const long batch_size,
		const long in_feat,
		CudaStreamManager* smgr) {
	// A grid with zero blocks is an invalid launch configuration, so an empty
	// batch simply skips the kernel.
	if (batch_size > 0) {
		batch_scatter_kernel<scalar_t>
			<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
					input_buf);
		// Kernel launches return no status; surface launch errors here.
		checkCudaErrors(cudaGetLastError());
	}
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
90

Rick Ho's avatar
Rick Ho committed
91
92
/*
 * Row gather: copies row pos[blockIdx.x] of `inbuf` into row blockIdx.x of
 * `oubuf`; every row is `wid` contiguous elements. pos is not bounds-checked.
 *
 * Launch: one block per destination row; the block's threads stride across
 * the row's columns, so any blockDim.x is valid.
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	const scalar_t* src = inbuf + wid * pos[blockIdx.x];
	scalar_t* dst = oubuf + wid * blockIdx.x;
	for (int col = threadIdx.x; col < wid; col += blockDim.x) {
		dst[col] = src[col];
	}
}

/*
 * Gather rows of `output_buf` back into `output` according to `d_pos`:
 * output[r] = output_buf[d_pos[r]] for r in [0, batch_size), each row holding
 * `out_feat` contiguous elements.
 *
 * Launches batch_gather_kernel (one block per row, 256 threads) on
 * smgr->stream(0), then calls smgr->sync(1) -- presumably waiting on that
 * stream; confirm against CudaStreamManager.
 */
template <typename scalar_t>
void moe_cuda_local_gather_impl(
		const scalar_t* output_buf,
		const long* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat,
		CudaStreamManager* smgr) {
	// A grid with zero blocks is an invalid launch configuration, so an empty
	// batch simply skips the kernel.
	if (batch_size > 0) {
		batch_gather_kernel<scalar_t>
			<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
					output);
		// Kernel launches return no status; surface launch errors here.
		checkCudaErrors(cudaGetLastError());
	}
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
115

Rick Ho's avatar
Rick Ho committed
116
117
118
119
/*
 * Expert-wise linear forward on an expert-grouped buffer.
 *
 * input_buf    : rows grouped by expert (expert_count[0] rows for expert 0,
 *                then expert_count[1] rows for expert 1, ...), in_feat
 *                values per row.
 * weight       : num_expert consecutive blocks of in_feat * out_feat values;
 *                block e is used as a row-major out_feat x in_feat matrix.
 * expert_count : dereferenced on the host here, so it must be host memory.
 * output_buf   : receives out_feat values per row: for each expert e,
 *                output rows = input rows @ weight[e]^T.
 *
 * Experts with no rows are skipped. The GEMM for expert e is issued on
 * smgr->handle(e), and smgr->sync(num_expert) is called before returning.
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
		const scalar_t* input_buf,
		const scalar_t* weight,
		const long* expert_count,
		scalar_t* output_buf,
		const size_t in_feat,
		const size_t out_feat,
		const size_t num_expert,
		CudaStreamManager* smgr) {
	scalar_t one = 1, zero = 0;
	int row = 0;  // first row of the current expert's group

	for (int e = 0; e < num_expert; ++e) {
		const long rows = expert_count[e];
		if (rows != 0) {
			// cuBLAS is column-major: computing T(B) x T(A) = T(C) leaves the
			// row-major product C = rows @ weight[e]^T in output_buf.
			checkCudaErrors(cublasXgemm(
					smgr->handle(e),
					CUBLAS_OP_T, CUBLAS_OP_N,
					out_feat, rows, in_feat,
					&one,
					weight + e * in_feat * out_feat, in_feat,
					input_buf + row * in_feat, in_feat,
					&zero,
					output_buf + out_feat * row, out_feat));
		}
		row += rows;
	}
	smgr->sync(num_expert);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
150
/*
 * Backward pass of moe_cuda_forward_impl for the same expert-grouped layout.
 *
 * For each expert i whose group occupies rows [ptr, ptr + expert_count[i]):
 *   grad_input_buf rows = grad_output_buf rows @ weight[i]
 *   grad_weight[i]      = (grad_output_buf rows)^T @ (input_buf rows)
 * where weight[i] / grad_weight[i] are row-major out_feat x in_feat blocks.
 * Experts with no rows get grad_weight[i] zero-filled via cudaMemset.
 *
 * expert_count is dereferenced on the host, so it must be host memory.
 * batch_size is accepted but not used. GEMMs for expert i are issued on
 * smgr->handle(i); smgr->sync(num_expert) is called before returning.
 */
template <typename scalar_t>
void moe_cuda_backward_impl(
		const scalar_t* grad_output_buf,
		const scalar_t* input_buf,
		const scalar_t* weight,
		const long* expert_count,
		scalar_t* grad_input_buf,
		scalar_t* grad_weight,
		const size_t batch_size,
		const size_t in_feat,
		const size_t out_feat,
		const size_t num_expert,
		CudaStreamManager* smgr) {
	scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// No rows routed to this expert: its weight gradient is zero.
			// Check the result like every other CUDA call in this file.
			checkCudaErrors(cudaMemset(grad_weight + i * in_feat * out_feat, 0,
						sizeof(scalar_t) * in_feat * out_feat));
			continue;
		}
		// cuBLAS is column-major: use T(B) x T(A) = T(C) to produce
		// row-major C.

		// Backward input: grad_input rows = grad_output rows @ weight[i]
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: grad_weight[i] = grad_output rows^T @ input rows
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}
Rick Ho's avatar
Rick Ho committed
203
204


Rick Ho's avatar
Rick Ho committed
205
/*
 * Python-facing entry: per-expert sample counts plus each sample's row in
 * the expert-grouped buffer (computed by moe_cuda_expert_count_impl).
 *
 * gate       : 1-D tensor of batch_size expert indices, read through
 *              data_ptr<int>() (so it must be int32).
 * num_expert : number of experts.
 * Returns {expert_count, pos}:
 *   expert_count -- int32, num_expert entries, allocated without an explicit
 *                   device (default device, normally CPU);
 *   pos          -- int32, batch_size entries, on gate's device.
 */
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor gate,
		size_t num_expert) {
	const auto batch_size = gate.size(0);

	auto expert_count = torch::empty(num_expert,
			torch::TensorOptions().dtype(torch::kInt32));
	auto pos = torch::empty(batch_size,
			torch::TensorOptions()
				.device(gate.device())
				.dtype(torch::kInt32));

	moe_cuda_expert_count_impl(
			gate.data_ptr<int>(),
			expert_count.data_ptr<int>(),
			pos.data_ptr<int>(),
			num_expert,
			batch_size);
	return {expert_count, pos};
}

Rick Ho's avatar
Rick Ho committed
227
228
229
/*
 * Python-facing entry: scatter the rows of `input` so that row r lands at
 * row pos[r] of a freshly allocated buffer.
 *
 * input : 2-D tensor; size(1) is taken as in_feat.
 * pos   : 1-D tensor read through data_ptr<long>() (so it must be int64).
 * Returns {input_buf}: pos.size(0) x in_feat, with input's dtype and device.
 */
std::vector<torch::Tensor> moe_cuda_local_scatter(
		torch::Tensor input,
		torch::Tensor pos) {
	auto smgr = getCudaStreamManager(input.device().index());
	const auto n_rows = pos.size(0);
	const auto in_feat = input.size(1);

	auto input_buf = torch::empty({n_rows, in_feat},
			torch::TensorOptions()
				.dtype(input.dtype())
				.device(input.device()));

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
			"moe_local_scatter_cuda", ([&] {
		moe_cuda_local_scatter_impl<scalar_t>(
				input.data_ptr<scalar_t>(),
				pos.data_ptr<long>(),
				input_buf.data_ptr<scalar_t>(),
				n_rows,
				in_feat,
				smgr);
	}));
	return {input_buf};
}
Jiezhong Qiu's avatar
Jiezhong Qiu committed
251

Rick Ho's avatar
Rick Ho committed
252
253
254
/*
 * Python-facing entry: gather rows back so that row r of the result is row
 * pos[r] of `output_buf`.
 *
 * output_buf : 2-D tensor; size(1) is taken as out_feat.
 * pos        : 1-D tensor read through data_ptr<long>() (so it must be int64).
 * Returns {output}: pos.size(0) x out_feat, with output_buf's dtype/device.
 */
std::vector<torch::Tensor> moe_cuda_local_gather(
		torch::Tensor output_buf,
		torch::Tensor pos) {
	auto smgr = getCudaStreamManager(output_buf.device().index());
	const auto n_rows = pos.size(0);
	const auto out_feat = output_buf.size(1);

	auto output = torch::empty({n_rows, out_feat},
			torch::TensorOptions()
				.dtype(output_buf.dtype())
				.device(output_buf.device()));

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
			"moe_local_gather_cuda", ([&] {
		moe_cuda_local_gather_impl<scalar_t>(
				output_buf.data_ptr<scalar_t>(),
				pos.data_ptr<long>(),
				output.data_ptr<scalar_t>(),
				n_rows,
				out_feat,
				smgr);
	}));
	return {output};
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
276

Jiezhong Qiu's avatar
Jiezhong Qiu committed
277
/*
 * Python-facing entry for the expert-wise linear forward pass.
 *
 * input_buf    : expert-grouped rows; size(0) rows are produced.
 * weight       : sizes read as [num_expert, out_feat, in_feat].
 * expert_count : per-expert row counts, read through data_ptr<long>() and
 *                dereferenced on the host by moe_cuda_forward_impl.
 * Returns {output}: input_buf.size(0) x out_feat, with input_buf's dtype and
 * device.
 */
std::vector<torch::Tensor> moe_cuda_forward(
		torch::Tensor input_buf,
		torch::Tensor weight,
		torch::Tensor expert_count
		) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
	printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
			num_expert, in_feat, out_feat);
#endif

	const auto opts = torch::TensorOptions()
			.dtype(input_buf.dtype())
			.device(input_buf.device());
	auto output = torch::empty({batch_size, out_feat}, opts);

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
			"moe_forward_cuda", ([&] {
		moe_cuda_forward_impl<scalar_t>(
				input_buf.data_ptr<scalar_t>(),
				weight.data_ptr<scalar_t>(),
				expert_count.data_ptr<long>(),
				output.data_ptr<scalar_t>(),
				in_feat,
				out_feat,
				num_expert,
				smgr);
	}));
	return {output};
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
314
/*
 * Python-facing entry for the backward pass of moe_cuda_forward.
 *
 * grad_output_buf : [batch_size x out_feat], expert-grouped like input_buf.
 * input_buf       : [batch_size x in_feat], the forward input.
 * weight          : [num_expert x out_feat x in_feat].
 * expert_count    : per-expert row counts, read through data_ptr<long>() and
 *                   dereferenced on the host by moe_cuda_backward_impl.
 * Returns {grad_input_buf [batch_size x in_feat],
 *          grad_weight    [num_expert x out_feat x in_feat]}, both allocated
 * with new_empty on grad_output_buf (its dtype and device).
 */
std::vector<torch::Tensor> moe_cuda_backward(
	torch::Tensor grad_output_buf, // [batch_size x out_feat]
	torch::Tensor input_buf, // [batch_size x in_feat]
	torch::Tensor weight, // [num_expert x out_feat x in_feat]
	torch::Tensor expert_count
) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
	printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

	auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
	auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
		moe_cuda_backward_impl<scalar_t>(
			grad_output_buf.data_ptr<scalar_t>(),
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<long>(),
			grad_input_buf.data_ptr<scalar_t>(),
			grad_weight.data_ptr<scalar_t>(),
			batch_size,
			in_feat,
			out_feat,
			num_expert,
			smgr
		);
	}));

	return {grad_input_buf, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
354
355

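/*
 * NOTE(review): stale standalone benchmark, kept commented out for reference.
 * It calls moe_first_linear_cuda_forward, which is not defined in this file,
 * so it would not compile if re-enabled as-is.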
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
356
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
357
358
359
360
361
362
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
363
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
364
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
365
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
366

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
367
368
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
369
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
370
371
372
373
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
374

Jiezhong Qiu's avatar
Jiezhong Qiu committed
375
376
377
378
379
380
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
381
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
382
383
384
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
385
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
386
387
388
389
390
391
392
393
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
394
}
Rick Ho's avatar
Rick Ho committed
395
*/