moe_compute_kernel.cu 9.76 KB
Newer Older
1
2
#include "moe_cuda_kernel.h"

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <vector>

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
7
8
#include <cuda.h>
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
9
#include <cublas_v2.h>
Jiezhong Qiu's avatar
Jiezhong Qiu committed
10
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
11

Rick Ho's avatar
Rick Ho committed
12
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
13

Rick Ho's avatar
Rick Ho committed
14
15
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
16

Rick Ho's avatar
Rick Ho committed
17
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
18

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19
20
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const long* offset, const scalar_t** ptrs) {
	// Builds a pointer table: ptrs[k] = base + stride * offset[k] for k < n.
	// One thread per entry; threads past the end of the table do nothing.
	const size_t k = blockIdx.x * blockDim.x + threadIdx.x;
	if (k >= n) {
		return;
	}
	ptrs[k] = base + offset[k] * stride;
}

29
30
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos, 
		const scalar_t* __restrict__ inbuf, scalar_t* __restrict__ oubuf) { 
	// Row copy: oubuf row blockIdx.x <- inbuf row pos[blockIdx.x], each row
	// `wid` elements wide. Expects one block per output row; the threads of
	// a block stride across the row's columns (any blockDim.x works).
	const scalar_t* src = inbuf + wid * pos[blockIdx.x];
	scalar_t* dst = oubuf + wid * blockIdx.x;
	// size_t index: matches the type of `wid` and cannot overflow for wide rows.
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		dst[i] = src[i];
	}
}

Rick Ho's avatar
Rick Ho committed
40
// Host-side counting sort of the gate assignments.
//   d_gate       : device array of batch_size expert ids, each in [0, num_expert)
//   expert_count : host array of num_expert ints, filled with rows per expert
//   d_pos        : device array of batch_size ints; d_pos[i] receives the index
//                  of row i in the expert-sorted order (rows of expert 0 first,
//                  then expert 1, ...; original order kept within an expert)
// Throws std::out_of_range if any gate id is outside [0, num_expert).
void moe_cuda_expert_count_impl(
        const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
        const size_t batch_size) {
	memset(expert_count, 0, sizeof(int) * num_expert);
	if (batch_size == 0) {
		return;
	}

	// std::vector owns the host scratch buffers so nothing leaks, even when
	// the range check below throws.
	std::vector<int> gate(batch_size);
	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	for (size_t i = 0; i < batch_size; ++i) {
		const int g = gate[i];
		// An invalid id would index outside expert_count / expert_ptr.
		if (g < 0 || static_cast<size_t>(g) >= num_expert) {
			throw std::out_of_range(
					"moe_cuda_expert_count: gate index out of range");
		}
		++expert_count[g];
	}

	// Exclusive prefix sum: the first sorted slot owned by each expert.
	std::vector<int> expert_ptr(num_expert, 0);
	for (size_t i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	std::vector<int> pos(batch_size);
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}

	checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));
}
75

Rick Ho's avatar
Rick Ho committed
76
77
78
// Copies input row d_pos[r] into input_buf row r for every r < batch_size,
// launching one block per row on smgr->stream(0) and then calling
// smgr->sync(1) before returning.
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
		const long* d_pos,
		scalar_t* input_buf,
		const long batch_size,
		const long in_feat, 
		CudaStreamManager* smgr) {
	// A zero-sized grid is an invalid launch configuration; an empty batch
	// has nothing to copy.
	if (batch_size <= 0) {
		return;
	}
	batch_scatter_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
				input_buf); 
	// Kernel launches report configuration errors only through this query.
	checkCudaErrors(cudaGetLastError());
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
89

Rick Ho's avatar
Rick Ho committed
90
91
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos, 
		const scalar_t* __restrict__ inbuf, scalar_t* __restrict__ oubuf) { 
	// Row copy: oubuf row pos[blockIdx.x] <- inbuf row blockIdx.x, each row
	// `wid` elements wide. Expects one block per input row; the threads of
	// a block stride across the row's columns (any blockDim.x works).
	const scalar_t* src = inbuf + wid * blockIdx.x;
	scalar_t* dst = oubuf + wid * pos[blockIdx.x];
	// size_t index: matches the type of `wid` and cannot overflow for wide rows.
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		dst[i] = src[i];
	}
}

// Copies output_buf row r into output row d_pos[r] for every r < batch_size,
// launching one block per row on smgr->stream(0) and then calling
// smgr->sync(1) before returning.
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
		const long* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat,
		CudaStreamManager* smgr) {
	// A zero-sized grid is an invalid launch configuration; an empty batch
	// has nothing to copy.
	if (batch_size == 0) {
		return;
	}
	batch_gather_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
				output); 
	// Kernel launches report configuration errors only through this query.
	checkCudaErrors(cudaGetLastError());
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
114

Rick Ho's avatar
Rick Ho committed
115
116
117
118
// Per-expert linear layer over rows that are already grouped by expert.
// For expert e owning the next expert_count[e] rows of input_buf, computes
// those rows of output_buf as rows @ weight[e]^T (weight[e] stored row-major
// as out_feat x in_feat). expert_count is read on the host. Experts with no
// rows are skipped. Finishes with smgr->sync(num_expert).
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
		const long* expert_count,
        scalar_t* output_buf,
		const bool has_bias,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
	// With a bias, output_buf arrives pre-filled and the GEMM accumulates onto
	// it (beta = 1); without one, the GEMM overwrites it (beta = 0).
	scalar_t alpha = 1;
	scalar_t beta = has_bias ? 1 : 0;

	int row = 0;  // first row of the current expert's slice
	for (int e = 0; e < num_expert; ++e) {
		const long rows = expert_count[e];
		if (rows != 0) {
			const scalar_t* w = weight + e * in_feat * out_feat;
			const scalar_t* x = input_buf + row * in_feat;
			scalar_t* y = output_buf + out_feat * row;
			// cuBLAS is column-major: computing T(B) x T(A) = T(C) leaves
			// C in row-major layout.
			checkCudaErrors(cublasXgemm(
					smgr->handle(e),
					CUBLAS_OP_T, CUBLAS_OP_N,
					out_feat, rows, in_feat,
					&alpha,
					w, in_feat,
					x, in_feat,
					&beta,
					y, out_feat));
		}
		row += rows;
	}
	smgr->sync(num_expert);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
150
// Bias-gradient helper: out[c] = sum over r < rows of in[r * cols + c].
// Expects a 1-D launch with at least `cols` threads in total, one per column.
// Adjacent threads read adjacent columns of each row.
template <typename scalar_t>
__global__
void moe_bias_grad_kernel(size_t rows, size_t cols,
		const scalar_t* __restrict__ in, scalar_t* __restrict__ out) {
	const size_t c = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
	if (c >= cols) {
		return;
	}
	// Accumulate in double so half/float inputs do not lose precision over
	// long expert slices.
	double acc = 0;
	for (size_t r = 0; r < rows; ++r) {
		acc += static_cast<double>(in[r * cols + c]);
	}
	out[c] = static_cast<scalar_t>(acc);
}

// Backward pass of the per-expert linear layer. For expert i owning the next
// expert_count[i] rows (expert_count is read on the host), this fills:
//   grad_input_buf rows : g_o @ weight[i]
//   grad_weight[i]      : g_o^T @ x   (zeroed when the expert has no rows)
//   grad_bias[i]        : column sums of g_o when has_bias (zero otherwise)
// Finishes with smgr->sync(num_expert). batch_size is currently unused.
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
		const scalar_t* weight,
		const long* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
		scalar_t* grad_bias,
		const bool has_bias,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

	// grad_bias is always handed back to the caller; without a bias, return
	// zeros rather than uninitialized memory.
	if (!has_bias) {
		checkCudaErrors(cudaMemset(grad_bias, 0,
					sizeof(scalar_t) * num_expert * out_feat));
	}

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// An expert that received no rows has zero parameter gradients.
			checkCudaErrors(cudaMemset(grad_weight + i * in_feat * out_feat, 0, 
					sizeof(scalar_t) * in_feat * out_feat));
			if (has_bias) {
				checkCudaErrors(cudaMemset(grad_bias + i * out_feat, 0,
							sizeof(scalar_t) * out_feat));
			}
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C

		// Backward input: g_i = w @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: g_w = i @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		// Backward bias: g_b = column sums of this expert's g_o rows.
		// NOTE(review): assumes smgr->stream(i) is valid for every expert
		// index (as smgr->handle(i) is) and is covered by sync(num_expert).
		if (has_bias && out_feat > 0) {
			moe_bias_grad_kernel<scalar_t>
				<<<CEIL(out_feat, 256), 256, 0, smgr->stream(i)>>>(
						expert_count[i], out_feat,
						grad_output_buf + ptr * out_feat,
						grad_bias + i * out_feat);
			checkCudaErrors(cudaGetLastError());
		}

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}
Rick Ho's avatar
Rick Ho committed
209
210


Rick Ho's avatar
Rick Ho committed
211
// Wraps moe_cuda_expert_count_impl for a 1-D int32 gate tensor on the device.
// Returns {expert_count, pos}: expert_count is an int32 tensor of num_expert
// counts allocated with default (CPU) options, and pos is an int32 tensor of
// gate.size(0) entries on the gate's device.
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor gate, 
		size_t num_expert) {
	const auto batch_size = gate.size(0);

	auto expert_count = torch::empty(num_expert,
			torch::TensorOptions().dtype(torch::kInt32));
	auto pos = torch::empty(batch_size,
			torch::TensorOptions()
				.device(gate.device())
				.dtype(torch::kInt32));

	moe_cuda_expert_count_impl(gate.data_ptr<int>(),
			expert_count.data_ptr<int>(), pos.data_ptr<int>(),
			num_expert, batch_size);
	return {expert_count, pos};
}

Rick Ho's avatar
Rick Ho committed
233
234
235
// Returns {input_buf}, where input_buf has pos.size(0) rows of input.size(1)
// features and row r is a copy of input row pos[r]. pos is read as int64
// (data_ptr<long>).
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
	torch::Tensor pos) {
	auto smgr = getCudaStreamManager(input.device().index());
	const auto batch_size = pos.size(0);
	const auto in_feat = input.size(1);

	auto input_buf = torch::empty({batch_size, in_feat},
			torch::TensorOptions().dtype(input.dtype()).device(input.device()));

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda", ([&] {
		moe_cuda_local_scatter_impl<scalar_t>(input.data_ptr<scalar_t>(),
				pos.data_ptr<long>(), input_buf.data_ptr<scalar_t>(),
				batch_size, in_feat, smgr);
	}));
	return {input_buf,};
}
Jiezhong Qiu's avatar
Jiezhong Qiu committed
257

Rick Ho's avatar
Rick Ho committed
258
259
260
// Returns {output}, where output has pos.size(0) rows of output_buf.size(1)
// features and row pos[r] is a copy of output_buf row r. pos is read as int64
// (data_ptr<long>).
std::vector<torch::Tensor> moe_cuda_local_gather(
	torch::Tensor output_buf,
	torch::Tensor pos) {
	auto smgr = getCudaStreamManager(output_buf.device().index());
	const auto batch_size = pos.size(0);
	const auto out_feat = output_buf.size(1);

	auto output = torch::empty({batch_size, out_feat},
			torch::TensorOptions().dtype(output_buf.dtype()).device(output_buf.device()));

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda", ([&] {
		moe_cuda_local_gather_impl<scalar_t>(output_buf.data_ptr<scalar_t>(),
				pos.data_ptr<long>(), output.data_ptr<scalar_t>(),
				batch_size, out_feat, smgr);
	}));
	return {output,};
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
282

Jiezhong Qiu's avatar
Jiezhong Qiu committed
283
// Forward pass of the per-expert linear layer.
//   input_buf    : [batch_size x in_feat], rows grouped by expert
//   expert_count : int64 counts per expert, read on the host by the impl
//   weight       : [num_expert x out_feat x in_feat]
//   bias         : optional per-expert bias, expanded to one row per input
//                  row with repeat_interleave over expert_count
// Returns {output} with batch_size rows of out_feat features.
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
		torch::Tensor expert_count,
        torch::Tensor weight,
		at::optional<torch::Tensor> bias
		) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 
			num_expert, in_feat, out_feat);
#endif

	torch::Tensor output;
	if (!bias.has_value()) {
		output = torch::empty({batch_size, out_feat},
				torch::TensorOptions()
					.device(input_buf.device())
					.dtype(input_buf.dtype()));
	} else {
		// Seed each output row with its expert's bias; the GEMM in the impl
		// then accumulates onto it.
		const auto& b = bias.value();
		output = b.repeat_interleave(expert_count.to(b.device()), 0);
	}

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda", ([&] {
		moe_cuda_forward_impl<scalar_t>(
				input_buf.data_ptr<scalar_t>(),
				weight.data_ptr<scalar_t>(),
				expert_count.data_ptr<long>(),
				output.data_ptr<scalar_t>(),
				bias.has_value(),
				in_feat, out_feat, num_expert,
				smgr);
	}));
	return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
329
// Backward pass of the per-expert linear layer.
// Returns {grad_input_buf [batch_size x in_feat],
//          grad_weight    [num_expert x out_feat x in_feat],
//          grad_bias      [num_expert x out_feat]}.
// expert_count is read as int64 (data_ptr<long>) and dereferenced on the host
// by the impl, so it must be a CPU tensor.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, 	// [batch_size x out_feat]
    torch::Tensor input_buf, 		// [batch_size x in_feat]
	torch::Tensor expert_count,		// [num_expert] rows per expert
    torch::Tensor weight, 			// [num_expert x out_feat x in_feat]
	at::optional<torch::Tensor> bias
) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

	// Outputs share grad_output_buf's dtype and device; the impl fills every
	// element it is responsible for.
    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat}); 
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
	auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<long>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
			grad_bias.data_ptr<scalar_t>(),
			bias.has_value(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
			smgr
        );
    }));

	return {grad_input_buf, grad_weight, grad_bias};
}