moe_compute_kernel.cu 9.85 KB
Newer Older
1
2
#include "moe_cuda_kernel.h"

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
7
8
#include <cuda.h>
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
9
#include <cublas_v2.h>
Jiezhong Qiu's avatar
Jiezhong Qiu committed
10
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
11

Rick Ho's avatar
Rick Ho committed
12
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
13

Rick Ho's avatar
Rick Ho committed
14
15
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
16

Rick Ho's avatar
Rick Ho committed
17
// Integer ceiling of _x_ / _y_, computed as (_x_ - 1) / _y_ + 1.
// NOTE(review): only a true ceiling for _x_ >= 1 and _y_ >= 1. With signed
// operands CEIL(0, y) evaluates to 1, and with unsigned operands (_x_ - 1)
// wraps around, so callers must not pass _x_ == 0.
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
18

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19
20
/*
 * Builds an array of element pointers: ptrs[i] = base + stride * offset[i]
 * for every i in [0, n).
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per output pointer; the
 * grid must supply at least n threads. Threads with an index >= n do nothing.
 *
 * n      - number of pointers to generate
 * base   - start of the buffer the generated pointers point into
 * stride - number of scalar_t elements between two consecutive offsets
 * offset - n offsets, read by the kernel
 * ptrs   - n output pointers, written by the kernel
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const long* offset, const scalar_t** ptrs) {
	// Widen before multiplying: blockDim.x * blockIdx.x would otherwise be
	// evaluated in 32-bit unsigned arithmetic and could wrap for large grids.
	const size_t idx =
		static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
	if (idx < n) {
		ptrs[idx] = base + stride * offset[idx];
	}
}

29
30
/*
 * Row permutation copy: block b copies row pos[b] of inbuf into row b of
 * oubuf, where every row holds wid contiguous elements.
 *
 * Launch layout: one block per output row (gridDim.x rows); the threads of a
 * block stride over the wid elements of the row, so any blockDim.x works.
 *
 * wid   - number of elements per row
 * pos   - per-row source index, one entry per block
 * inbuf - source rows
 * oubuf - destination rows
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	const scalar_t* src = inbuf + wid * pos[blockIdx.x];
	scalar_t* dst = oubuf + wid * blockIdx.x;
	// size_t index: matches the type of wid, so there is no signed/unsigned
	// comparison and no int overflow for very wide rows.
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		dst[i] = src[i];
	}
}

Rick Ho's avatar
Rick Ho committed
40
/*
 * Counts how many tokens are routed to each expert and computes, for every
 * token, its slot in the expert-grouped ordering (a stable counting sort done
 * on the host).
 *
 * d_gate       - device pointer, batch_size expert ids (one per token). Ids
 *                are used as indices and are NOT validated here; they must
 *                lie in [0, num_expert).
 * expert_count - HOST pointer, num_expert entries; overwritten with the
 *                number of tokens per expert.
 * d_pos        - device pointer, batch_size entries; d_pos[i] receives the
 *                slot of token i once tokens are grouped by ascending expert
 *                id, keeping the original order inside one expert.
 * num_expert   - number of experts
 * batch_size   - number of tokens
 */
void moe_cuda_expert_count_impl(
        const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
        const size_t batch_size) {
	memset(expert_count, 0, sizeof(int) * num_expert);

	// Host-side scratch buffers are owned by std::vector so that nothing
	// leaks (the previous raw new[] for `pos` was never freed).
	std::vector<int> gate(batch_size);
	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	for (size_t i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}

	// Exclusive prefix sum: expert_ptr[e] is the first slot of expert e.
	// Zero-initialised, so the num_expert == 0 case touches no memory.
	std::vector<int> expert_ptr(num_expert, 0);
	for (size_t e = 1; e < num_expert; ++e) {
		expert_ptr[e] = expert_ptr[e - 1] + expert_count[e - 1];
	}

	// Hand out consecutive slots inside each expert's range.
	std::vector<int> pos(batch_size);
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}

	checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));
}
75

Rick Ho's avatar
Rick Ho committed
76
77
78
/*
 * Launches batch_scatter_kernel to build input_buf, whose row b is row
 * d_pos[b] of input.
 *
 * input      - source rows, in_feat elements each (read by the kernel)
 * d_pos      - batch_size row indices (read by the kernel)
 * input_buf  - destination, batch_size rows of in_feat elements
 * batch_size - number of rows to produce; one block is launched per row
 * in_feat    - elements per row
 * smgr       - stream manager; the kernel runs on smgr->stream(0) and
 *              smgr->sync(1) is called before returning
 */
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
		const long* d_pos,
		scalar_t* input_buf,
		const long batch_size,
		const long in_feat,
		CudaStreamManager* smgr) {
	// A grid with zero blocks is an invalid launch configuration, so an
	// empty batch skips the launch altogether.
	if (batch_size > 0) {
		batch_scatter_kernel<scalar_t>
			<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
					input_buf);
		// Kernel launches do not return a status; pick up launch errors here.
		checkCudaErrors(cudaGetLastError());
	}
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
89

Rick Ho's avatar
Rick Ho committed
90
91
/*
 * Inverse row permutation copy: block b copies row b of inbuf into row
 * pos[b] of oubuf, where every row holds wid contiguous elements.
 *
 * Launch layout: one block per input row (gridDim.x rows); the threads of a
 * block stride over the wid elements of the row, so any blockDim.x works.
 *
 * wid   - number of elements per row
 * pos   - per-row destination index, one entry per block
 * inbuf - source rows
 * oubuf - destination rows
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	const scalar_t* src = inbuf + wid * blockIdx.x;
	scalar_t* dst = oubuf + wid * pos[blockIdx.x];
	// size_t index: matches the type of wid, so there is no signed/unsigned
	// comparison and no int overflow for very wide rows.
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		dst[i] = src[i];
	}
}

/*
 * Launches batch_gather_kernel to build output, whose row d_pos[b] is row b
 * of output_buf.
 *
 * output_buf - source rows, out_feat elements each (read by the kernel)
 * d_pos      - batch_size destination row indices (read by the kernel)
 * output     - destination rows of out_feat elements
 * batch_size - number of rows to copy; one block is launched per row
 * out_feat   - elements per row
 * smgr       - stream manager; the kernel runs on smgr->stream(0) and
 *              smgr->sync(1) is called before returning
 */
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
		const long* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat,
		CudaStreamManager* smgr) {
	// A grid with zero blocks is an invalid launch configuration, so an
	// empty batch skips the launch altogether.
	if (batch_size > 0) {
		batch_gather_kernel<scalar_t>
			<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos,
					output_buf, output);
		// Kernel launches do not return a status; pick up launch errors here.
		checkCudaErrors(cudaGetLastError());
	}
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
114

Rick Ho's avatar
Rick Ho committed
115
116
117
118
/*
 * Per-expert forward GEMM over an input buffer whose rows are already
 * grouped by expert.
 *
 * input_buf    - rows of in_feat elements; the first expert_count[0] rows
 *                belong to expert 0, the next expert_count[1] to expert 1, ...
 * weight       - num_expert consecutive matrices of in_feat * out_feat
 *                elements each
 * expert_count - rows per expert; dereferenced by host code in this
 *                function, so it must be host-accessible
 * output_buf   - rows of out_feat elements, same row grouping as input_buf
 * smgr         - stream manager; expert i uses smgr->handle(i) and
 *                smgr->sync(num_expert) is called before returning
 *
 * Experts with no rows are skipped.
 * NOTE(review): assumes cublasXgemm follows the cublas<t>gemm argument
 * order - confirm against cublas_wrapper.h.
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
		const long* expert_count,
        scalar_t* output_buf,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
	scalar_t alpha = 1, beta = 0;

	// Index of the first row that belongs to the current expert.
	int ptr = 0;
	for (int i = 0; i < num_expert; ++i) {
		const long n_rows = expert_count[i];
		if (n_rows != 0) {
			const scalar_t* expert_weight = weight + i * in_feat * out_feat;
			const scalar_t* expert_input = input_buf + ptr * in_feat;
			scalar_t* expert_output = output_buf + out_feat * ptr;

			// Use T(B) x T(A) = T(C) to produce row-major C
			checkCudaErrors(cublasXgemm(
					smgr->handle(i),
					CUBLAS_OP_T,
					CUBLAS_OP_N,
					out_feat, n_rows, in_feat,
					&alpha,
					expert_weight, in_feat,
					expert_input, in_feat,
					&beta,
					expert_output, out_feat
					));

			ptr += n_rows;
		}
	}
	smgr->sync(num_expert);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
149
/*
 * Per-expert backward pass: computes the gradient w.r.t. the (expert-grouped)
 * input rows and w.r.t. each expert's weight matrix.
 *
 * grad_output_buf - rows of out_feat elements, grouped by expert
 * input_buf       - rows of in_feat elements, same grouping
 * weight          - num_expert consecutive matrices of in_feat * out_feat
 *                   elements each
 * expert_count    - rows per expert; dereferenced by host code in this
 *                   function, so it must be host-accessible
 * grad_input_buf  - output, rows of in_feat elements, same grouping
 * grad_weight     - output, same layout as weight; the slice of an expert
 *                   that received no rows is zero-filled
 * batch_size      - not used by this function
 * smgr            - stream manager; expert i uses smgr->handle(i) and
 *                   smgr->sync(num_expert) is called before returning
 *
 * NOTE(review): assumes cublasXgemm follows the cublas<t>gemm argument
 * order - confirm against cublas_wrapper.h.
 */
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
		const scalar_t* weight,
		const long* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// No rows for this expert: its weight gradient is all zeros.
			// The status is checked like every other CUDA call in this file
			// instead of being silently dropped.
			checkCudaErrors(cudaMemset(grad_weight + i * in_feat * out_feat, 0,
					sizeof(scalar_t) * in_feat * out_feat));
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C

		// Backward input: g_i = w @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: g_w = i @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		// Advance to the first row of the next expert.
		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}
Rick Ho's avatar
Rick Ho committed
202
203


Rick Ho's avatar
Rick Ho committed
204
/*
 * Tensor-level entry point for moe_cuda_expert_count_impl.
 *
 * gate       - int32 tensor of expert ids, one per token (dimension 0 is the
 *              token count); its data pointer is handed to the impl as the
 *              device-side gate array
 * num_expert - number of experts
 *
 * Returns {expert_count, pos}:
 *   expert_count - int32 tensor of num_expert entries, created without an
 *                  explicit device (the impl fills it from host code)
 *   pos          - int32 tensor of one entry per token on gate's device
 */
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor gate,
		size_t num_expert) {
	const auto n_tokens = gate.size(0);

	auto expert_count = torch::empty(num_expert,
			torch::TensorOptions().dtype(torch::kInt32));
	auto pos = torch::empty(n_tokens,
			torch::TensorOptions()
				.device(gate.device())
				.dtype(torch::kInt32));

	int* gate_ptr = gate.data_ptr<int>();
	int* count_ptr = expert_count.data_ptr<int>();
	int* pos_ptr = pos.data_ptr<int>();
	moe_cuda_expert_count_impl(gate_ptr, count_ptr, pos_ptr,
			num_expert, n_tokens);

	return {expert_count, pos};
}

Rick Ho's avatar
Rick Ho committed
226
227
228
/*
 * Tensor-level entry point for moe_cuda_local_scatter_impl.
 *
 * input - 2-D floating point (or half) tensor; size(1) is the feature width
 * pos   - tensor accessed as `long`, one row index per output row
 *
 * Returns a single-element vector holding input_buf, a new
 * [pos.size(0) x input.size(1)] tensor with input's dtype and device.
 */
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
	torch::Tensor pos) {
	auto smgr = getCudaStreamManager(input.device().index());

	const auto n_rows = pos.size(0);
	const auto width = input.size(1);

	auto input_buf = torch::empty({n_rows, width},
			torch::TensorOptions()
				.device(input.device())
				.dtype(input.dtype()));

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda",
			([&] {
		const scalar_t* src = input.data_ptr<scalar_t>();
		const long* row_index = pos.data_ptr<long>();
		scalar_t* dst = input_buf.data_ptr<scalar_t>();
		moe_cuda_local_scatter_impl<scalar_t>(
			src, row_index, dst, n_rows, width, smgr);
	}));

	return {input_buf};
}
Jiezhong Qiu's avatar
Jiezhong Qiu committed
250

Rick Ho's avatar
Rick Ho committed
251
252
253
/*
 * Tensor-level entry point for moe_cuda_local_gather_impl.
 *
 * output_buf - 2-D floating point (or half) tensor; size(1) is the feature
 *              width
 * pos        - tensor accessed as `long`, one destination row index per row
 *
 * Returns a single-element vector holding output, a new
 * [pos.size(0) x output_buf.size(1)] tensor with output_buf's dtype and
 * device.
 */
std::vector<torch::Tensor> moe_cuda_local_gather(
	torch::Tensor output_buf,
	torch::Tensor pos) {
	auto smgr = getCudaStreamManager(output_buf.device().index());

	const auto n_rows = pos.size(0);
	const auto width = output_buf.size(1);

	auto output = torch::empty({n_rows, width},
			torch::TensorOptions()
				.device(output_buf.device())
				.dtype(output_buf.dtype()));

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda",
			([&] {
		const scalar_t* src = output_buf.data_ptr<scalar_t>();
		const long* row_index = pos.data_ptr<long>();
		scalar_t* dst = output.data_ptr<scalar_t>();
		moe_cuda_local_gather_impl<scalar_t>(
			src, row_index, dst, n_rows, width, smgr);
	}));

	return {output};
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
275

Jiezhong Qiu's avatar
Jiezhong Qiu committed
276
/*
 * Tensor-level entry point for moe_cuda_forward_impl.
 *
 * input_buf    - expert-grouped input rows; size(0) is the number of rows
 * expert_count - tensor accessed as `long`, rows per expert
 * weight       - [num_expert x out_feat x in_feat]
 *
 * Returns a single-element vector holding a new
 * [input_buf.size(0) x out_feat] tensor with input_buf's dtype and device.
 */
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
		torch::Tensor expert_count,
        torch::Tensor weight
		) {
	auto smgr = getCudaStreamManager(input_buf.device().index());

	const auto batch_size = input_buf.size(0);
	// weight is laid out as [num_expert x out_feat x in_feat].
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 
			num_expert, in_feat, out_feat);
#endif

	auto output = torch::empty({batch_size, out_feat},
			torch::TensorOptions()
				.dtype(input_buf.dtype())
				.device(input_buf.device()));

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
			([&] {
		const scalar_t* in_ptr = input_buf.data_ptr<scalar_t>();
		const scalar_t* weight_ptr = weight.data_ptr<scalar_t>();
		const long* count_ptr = expert_count.data_ptr<long>();
		scalar_t* out_ptr = output.data_ptr<scalar_t>();
		moe_cuda_forward_impl<scalar_t>(
			in_ptr, weight_ptr, count_ptr, out_ptr,
			in_feat, out_feat, num_expert, smgr);
	}));

	return {output};
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
313
/*
 * Tensor-level entry point for moe_cuda_backward_impl.
 *
 * Returns {grad_input, grad_weight, grad_bias}.
 *
 * Without bias, the third entry is only a placeholder: an uninitialised
 * [num_expert x out_feat] tensor created with default tensor options.
 *
 * With bias, the last slot of the in_feat dimension is treated as the bias
 * column: it is dropped from grad_input and grad_weight, and the matching
 * slice of grad_weight is returned as grad_bias ([num_expert x out_feat]).
 */
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, 	// [batch_size x out_feat]
    torch::Tensor input_buf, 		// [batch_size x in_feat]
	torch::Tensor expert_count,
    torch::Tensor weight, 			// [num_expert x out_feat x in_feat]
	bool has_bias
) {
	auto smgr = getCudaStreamManager(input_buf.device().index());

	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

	// Gradients inherit dtype and device from grad_output_buf.
	auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
	auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
		const scalar_t* grad_out_ptr = grad_output_buf.data_ptr<scalar_t>();
		const scalar_t* in_ptr = input_buf.data_ptr<scalar_t>();
		const scalar_t* weight_ptr = weight.data_ptr<scalar_t>();
		const long* count_ptr = expert_count.data_ptr<long>();
		scalar_t* grad_in_ptr = grad_input_buf.data_ptr<scalar_t>();
		scalar_t* grad_weight_ptr = grad_weight.data_ptr<scalar_t>();
		moe_cuda_backward_impl<scalar_t>(
			grad_out_ptr, in_ptr, weight_ptr, count_ptr,
			grad_in_ptr, grad_weight_ptr,
			batch_size, in_feat, out_feat, num_expert, smgr);
	}));

	if (!has_bias) {
		return {grad_input_buf, grad_weight, torch::empty({num_expert,out_feat})};
	}

	// Weight and input were concatenated with the bias column, so the
	// gradients are split back into input, weight and bias parts.
	const auto orig_in_feat = in_feat - 1;
	torch::Tensor grad_orig_input_buf =
		at::narrow(grad_input_buf, -1, 0, orig_in_feat).contiguous();

	// The bias occupies the last slot of the extended dimension; squeeze
	// that dimension away.
	torch::Tensor grad_orig_bias =
		at::narrow(grad_weight, -1, orig_in_feat, 1).squeeze(2).contiguous();
	torch::Tensor grad_orig_weight =
		at::narrow(grad_weight, -1, 0, orig_in_feat).contiguous();

	return {grad_orig_input_buf, grad_orig_weight, grad_orig_bias};
}