#include "moe_cuda_kernel.h"

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <c10/cuda/CUDAGuard.h>

#include "timer.hh"

#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"

#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
/*
 * Builds an array of row pointers: ptrs[i] = base + stride * offset[i].
 * Launch with gridDim.x * blockDim.x >= n; surplus threads return early.
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const long* offset, const scalar_t** ptrs) { 
	const size_t tid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
	if (tid >= n) {
		return;
	}
	ptrs[tid] = base + offset[tid] * stride;
}

29
30
/*
 * Copies row pos[blockIdx.x] of `inbuf` into row blockIdx.x of `oubuf`.
 * One block per output row of width `wid`; threads stride over columns.
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	const size_t src_row = wid * pos[blockIdx.x];
	const size_t dst_row = wid * blockIdx.x;
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[dst_row + i] = inbuf[src_row + i];
	}
}

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

/*
 * Sums each column of a row-major (m x n) matrix: result[c] = sum_r matrix[r][c].
 * Launch with one block per column (gridDim.x == n) and
 * blockDim.x * sizeof(scalar_t) bytes of dynamic shared memory.
 */
template <typename scalar_t>
__global__ 
void column_reduce(const scalar_t * matrix, scalar_t * result, 
	int m /* lines */, int n /* columns*/) {
    // Declare the dynamic shared buffer as raw bytes and cast: the original
    // `extern __shared__ float sdata[]` was wrong for non-float scalar_t —
    // with half the float array overran the sizeof(scalar_t)*blockDim.x
    // allocation, and with double it silently truncated precision.
    extern __shared__ unsigned char column_reduce_smem[];
    scalar_t* sdata = reinterpret_cast<scalar_t*>(column_reduce_smem);
    unsigned int tid = threadIdx.x; // line
    unsigned int i = blockIdx.x + threadIdx.x * n; // get to idx th line
    unsigned int offset = 0;
    unsigned int it = n * blockDim.x; // advance blockDim.x rows per step

    // Phase 1: each thread accumulates rows tid, tid+blockDim.x, ... of its
    // block's column, so the whole column fits into one shared-memory slot
    // per thread.
    sdata[tid] = 0;
    while (i + offset < n*m) {
        sdata[tid] += matrix[i + offset];
        offset += it;
    }
    __syncthreads();

    // Phase 2: tree reduction in shared memory.  The loop bound must be
    // uniform across the block — the original `tid + s < blockDim.x` bound
    // made high-tid threads exit early while others still hit the
    // __syncthreads() below (divergent barrier => undefined behavior).
    for (unsigned int s = 1; s < blockDim.x; s *= 2) {
        if (tid % (2*s) == 0 && tid + s < blockDim.x) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    if (tid == 0) {result[blockIdx.x] = sdata[0];}
}

Rick Ho's avatar
Rick Ho committed
74
/*
 * Counts rows per expert and assigns every row a slot in the expert-sorted
 * buffer (stable within each expert).
 *
 * d_gate       (device): expert id per row, length batch_size.
 * expert_count (host):   out, number of rows routed to each expert.
 * d_pos        (device): out, destination slot per row, length batch_size.
 *
 * Synchronous: uses blocking cudaMemcpy in both directions.
 */
void moe_cuda_expert_count_impl(
        const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
        const size_t batch_size) {
	// RAII buffers — the original new[]/delete[] version leaked `pos`.
	std::vector<int> gate(batch_size);
	std::vector<int> expert_ptr(num_expert);
	std::vector<int> pos(batch_size);
	memset(expert_count, 0, sizeof(int) * num_expert);
	if (num_expert == 0) {
		return;
	}

	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	// Histogram: rows per expert.
	for (size_t i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	// Exclusive prefix sum: expert_ptr[e] = first slot owned by expert e.
	expert_ptr[0] = 0;
	for (size_t i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	// Hand each row the next free slot of its expert.  (The original also
	// shifted expert_ptr back down afterwards, but that buffer was freed
	// immediately — dead code, removed.)
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}

	checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));
}
109

Rick Ho's avatar
Rick Ho committed
110
111
112
/*
 * Reorders rows so experts become contiguous: input_buf[i] = input[d_pos[i]].
 * Launches one block per row on the manager's stream 0 and blocks until the
 * copy is done (smgr->sync(1)).
 */
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
		const long* d_pos,
		scalar_t* input_buf,
		const long batch_size,
		const long in_feat, 
		CudaStreamManager* smgr) {
	constexpr int kThreadsPerRow = 256;
	batch_scatter_kernel<scalar_t>
		<<<batch_size, kThreadsPerRow, 0, smgr->stream(0)>>>(
				in_feat, d_pos, input, input_buf);
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
123

Rick Ho's avatar
Rick Ho committed
124
125
/*
 * Inverse of batch_scatter_kernel: copies row blockIdx.x of `inbuf` into
 * row pos[blockIdx.x] of `oubuf`.  One block per input row of width `wid`.
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	const size_t src_row = wid * blockIdx.x;
	const size_t dst_row = wid * pos[blockIdx.x];
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[dst_row + i] = inbuf[src_row + i];
	}
}

/*
 * Restores the original row order: output[d_pos[i]] = output_buf[i].
 * Inverse of moe_cuda_local_scatter_impl; runs on stream 0 and blocks
 * until the copy has finished.
 */
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
		const long* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat,
		CudaStreamManager* smgr) {
	constexpr int kThreadsPerRow = 256;
	batch_gather_kernel<scalar_t>
		<<<batch_size, kThreadsPerRow, 0, smgr->stream(0)>>>(
				out_feat, d_pos, output_buf, output);
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
148

Rick Ho's avatar
Rick Ho committed
149
150
151
152
/*
 * Per-expert batched GEMM: for each expert i with c = expert_count[i] rows,
 * computes output_buf_segment = input_segment x W_i^T, where segments are the
 * contiguous row ranges produced by local_scatter.
 *
 * expert_count must be host-readable: it is dereferenced on the CPU below.
 * When has_bias is true, beta == 1 so the GEMM accumulates onto output_buf,
 * which the caller (moe_cuda_forward) has pre-filled with each row's bias.
 * Each expert runs on its own cuBLAS handle/stream; sync(num_expert) waits
 * for all of them.
 *
 * NOTE(review): weight is indexed as i * in_feat * out_feat, matching a
 * [num_expert, out_feat, in_feat] row-major layout — confirm against callers.
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
		const long* expert_count,
        scalar_t* output_buf,
		const bool has_bias,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
	scalar_t alpha = 1, beta = has_bias ? 1 : 0; 

	// `ptr` tracks the first row of expert i's segment in the sorted buffers.
	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// Empty expert: nothing to compute, segment has zero rows.
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C
		// (cuBLAS is column-major, so row-major A x B^T is expressed as
		// op(W)=T with W stored [out_feat x in_feat], op(X)=N).
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[i], in_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
184
/*
 * Backward pass of the per-expert GEMM.  For each expert segment:
 *   grad_input  = grad_output x W_i          (data gradient)
 *   grad_weight = X^T x grad_output          (weight gradient)
 *   grad_bias   = column-sum of grad_output  (only when has_bias)
 * expert_count must be host-readable (dereferenced on the CPU).
 * Experts with zero rows get their gradients zeroed explicitly, since the
 * output buffers are allocated with new_empty (uninitialized).
 * NOTE(review): when has_bias is false, grad_bias of non-empty experts is
 * never written — callers presumably discard it; confirm.
 */
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
		const scalar_t* weight,
		const long* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
		scalar_t* grad_bias,
		const bool has_bias,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// Zero this expert's gradients; check the returns like every
			// other CUDA call in this file (they were unchecked before).
			checkCudaErrors(cudaMemset(grad_weight + i * in_feat * out_feat, 0, 
					sizeof(scalar_t) * in_feat * out_feat));
			checkCudaErrors(cudaMemset(grad_bias + i * out_feat, 0,
					sizeof(scalar_t) * out_feat));
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C

		// Backward input: g_i = w @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: g_w = i @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		if (has_bias) {
			// One block per output feature; column_reduce expects
			// blockDim.x * sizeof(scalar_t) bytes of dynamic shared memory.
			column_reduce
			<<<out_feat, 1024, sizeof(scalar_t)*1024, smgr->stream(0)>>>
			(
				grad_output_buf + ptr * out_feat,
				grad_bias + i * out_feat,
				expert_count[i],
				out_feat
			);
			// Kernel launches don't return errors; surface bad configs here.
			checkCudaErrors(cudaGetLastError());
		}

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}
Rick Ho's avatar
Rick Ho committed
251
252


Rick Ho's avatar
Rick Ho committed
253
/*
 * Torch-facing wrapper around moe_cuda_expert_count_impl.
 * Returns {expert_count (int32, CPU), pos (int32, gate's device)}.
 */
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor gate, 
		size_t num_expert) {
	const auto batch_size = gate.size(0);

	// Counts are consumed on the host; positions stay next to the gate.
	auto expert_count = torch::empty(
			num_expert, torch::TensorOptions().dtype(torch::kInt32));
	auto pos = torch::empty(
			batch_size,
			torch::TensorOptions().device(gate.device()).dtype(torch::kInt32));

	moe_cuda_expert_count_impl(
			gate.data_ptr<int>(),
			expert_count.data_ptr<int>(),
			pos.data_ptr<int>(),
			num_expert,
			batch_size);

	return {expert_count, pos};
}

Rick Ho's avatar
Rick Ho committed
275
276
277
/*
 * Permutes rows of `input` into expert-contiguous order:
 * returns {input_buf} with input_buf[i] = input[pos[i]].
 * NOTE(review): assumes `pos` is an int64 tensor on input's device — confirm
 * against callers (moe_cuda_expert_count produces int32 positions).
 */
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
	torch::Tensor pos) {
	// Pin all CUDA work to the device owning `input`; without a guard the
	// kernel would launch on whatever device is current (multi-GPU bug).
	const c10::cuda::CUDAGuard device_guard(input.device());
	auto smgr = getCudaStreamManager(input.device().index());
	const auto batch_size = pos.size(0);
    const auto in_feat = input.size(1);

	auto opt = torch::TensorOptions()
		.dtype(input.dtype())
		.device(input.device());
	auto input_buf = torch::empty({batch_size, in_feat}, opt);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda", 
			([&] {
		moe_cuda_local_scatter_impl<scalar_t>(
			input.data_ptr<scalar_t>(),
			pos.data_ptr<long>(),
			input_buf.data_ptr<scalar_t>(),
			batch_size,
			in_feat,
			smgr);
	}));
	return {input_buf,};
}
Jiezhong Qiu's avatar
Jiezhong Qiu committed
299

Rick Ho's avatar
Rick Ho committed
300
301
302
/*
 * Undoes moe_cuda_local_scatter: returns {output} with
 * output[pos[i]] = output_buf[i].
 * NOTE(review): assumes `pos` is an int64 tensor on output_buf's device —
 * confirm against callers.
 */
std::vector<torch::Tensor> moe_cuda_local_gather(
	torch::Tensor output_buf,
	torch::Tensor pos) {
	// Pin all CUDA work to the device owning the buffer (multi-GPU safety).
	const c10::cuda::CUDAGuard device_guard(output_buf.device());
	auto smgr = getCudaStreamManager(output_buf.device().index());
	const auto batch_size = pos.size(0);
    const auto out_feat = output_buf.size(1);

	auto opt = torch::TensorOptions()
		.dtype(output_buf.dtype())
		.device(output_buf.device());
	auto output = torch::empty({batch_size, out_feat}, opt);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda", 
			([&] {
		moe_cuda_local_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			pos.data_ptr<long>(),
			output.data_ptr<scalar_t>(),
			batch_size,
			out_feat,
			smgr);
	}));
	return {output,};
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
324

Jiezhong Qiu's avatar
Jiezhong Qiu committed
325
/*
 * Expert-parallel forward: output[seg_i] = input_buf[seg_i] x W_i^T (+ b_i),
 * where seg_i is expert i's contiguous row range.
 * weight is [num_expert, out_feat, in_feat] (sizes read below).
 * NOTE(review): expert_count is dereferenced on the host inside
 * moe_cuda_forward_impl, so it must be a CPU int64 tensor — confirm.
 */
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
		torch::Tensor expert_count,
        torch::Tensor weight,
		at::optional<torch::Tensor> bias
		) {
	// Pin all CUDA work to the device owning the input (multi-GPU safety).
	const c10::cuda::CUDAGuard device_guard(input_buf.device());
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 
			num_expert, in_feat, out_feat);
#endif

    torch::Tensor output;
	if (bias.has_value()) {
		// Pre-fill each row with its expert's bias; the impl then runs the
		// GEMM with beta == 1 so the product accumulates on top of it.
		output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0);
	} else {
		auto out_options = torch::TensorOptions()
			.device(input_buf.device())
			.dtype(input_buf.dtype());
		output = torch::empty({batch_size, out_feat}, out_options);
	}

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda", 
			([&] {
		moe_cuda_forward_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<long>(),
			output.data_ptr<scalar_t>(),
			bias.has_value(),
			in_feat,
			out_feat,
			num_expert,
			smgr
		);
    }));

    return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
371
/*
 * Expert-parallel backward: returns {grad_input_buf, grad_weight, grad_bias}.
 * grad_bias is always allocated and returned; it is only fully populated when
 * `bias` has a value (see moe_cuda_backward_impl).
 * NOTE(review): expert_count is dereferenced on the host inside the impl, so
 * it must be a CPU int64 tensor — confirm.
 */
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, 	// [batch_size x out_feat]
    torch::Tensor input_buf, 		// [batch_size x in_feat]
	torch::Tensor expert_count,
    torch::Tensor weight, 			// [num_expert x out_feat x in_feat]
	at::optional<torch::Tensor> bias
) {
	// Pin all CUDA work to the device owning the tensors (multi-GPU safety).
	const c10::cuda::CUDAGuard device_guard(input_buf.device());
	auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

	// new_empty: uninitialized — the impl zeroes segments of empty experts.
    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
	auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<long>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
			grad_bias.data_ptr<scalar_t>(),
			bias.has_value(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
			smgr
        );
    }));

	return {grad_input_buf, grad_weight, grad_bias};
}