moe_cuda_kernel.cu 15.7 KB
Newer Older
1
2
#include "moe_cuda_kernel.h"

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
7
8
#include <cuda.h>
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
9
#include <cublas_v2.h>
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
10
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
Jiezhong Qiu committed
11
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
12

13
#ifdef MOE_USE_NCCL
Rick Ho's avatar
Rick Ho committed
14
#include <mpi.h>
15
16
#include <nccl.h>
#endif
Rick Ho's avatar
Rick Ho committed
17

Rick Ho's avatar
Rick Ho committed
18
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19

20

Rick Ho's avatar
Rick Ho committed
21
22
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
23

Rick Ho's avatar
Rick Ho committed
24
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
25

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
26
27
/*
 * Fill `ptrs` with one device pointer per entry:
 *     ptrs[i] = base + stride * offset[i]
 * Typically used to build pointer arrays for batched cuBLAS calls.
 * Launch with a 1-D grid covering at least n threads; extra threads exit.
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) { 
	const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
	if (i >= n) {
		return;
	}
	ptrs[i] = base + stride * offset[i];
}

36
37
/*
 * Scatter rows of width `wid`: block b copies row b of `inbuf` into row
 * pos[b] of `oubuf`. Threads of the block stride over the columns, so any
 * block size works. Launch with gridDim.x == number of rows.
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	const scalar_t* src = inbuf + wid * blockIdx.x;
	scalar_t* dst = oubuf + wid * pos[blockIdx.x];
	for (int col = threadIdx.x; col < wid; col += blockDim.x) {
		dst[col] = src[col];
	}
}

Rick Ho's avatar
Rick Ho committed
47
/*
 * Count tokens per expert and compute each token's destination slot.
 *
 * d_gate       (device) [batch_size]: expert index chosen for each token.
 * expert_count (host)   [num_expert]: out — tokens routed to each expert.
 * d_pos        (device) [batch_size]: out — pos[i] is the row token i
 *              occupies after grouping rows by expert (stable within expert).
 * Runs the counting on the host: copies the gate D2H, builds a prefix sum
 * over expert counts, assigns slots, and copies the positions H2D.
 */
void moe_cuda_expert_count_impl(
        const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
        const size_t batch_size) {
    int *gate = new int[batch_size];
	int *expert_ptr = new int[num_expert];
	memset(expert_count, 0, sizeof(int) * num_expert);

	checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	// Histogram of expert assignments.
	for (int i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	// Exclusive prefix sum: expert_ptr[e] = first row owned by expert e.
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	// Assign each token the next free slot of its expert (expert_ptr is
	// consumed as a cursor here).
	int *pos = new int[batch_size];
	for (int i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}
	// Restore expert_ptr to the exclusive prefix sum (shift right by one).
	for (int i = num_expert - 1; i > 0; --i) {
		expert_ptr[i] = expert_ptr[i - 1];
	}
	expert_ptr[0] = 0;
	checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));

	delete [] gate;
	delete [] expert_ptr;
	// Bug fix: `pos` was allocated with new[] but never freed (leaked once
	// per call).
	delete [] pos;
}
82

83
84
#ifdef MOE_USE_NCCL

Rick Ho's avatar
Rick Ho committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/*
 * Exchange per-expert token counts across workers.
 *
 * local_expert_count  [num_expert * world_size]: counts this worker sends,
 *     grouped per destination worker.
 * global_expert_count [num_expert * world_size]: out — counts received,
 *     laid out worker-major: entry (i + j * num_expert) is the number of
 *     tokens worker j routes to local expert i.
 * fwd_expert_count    [num_expert]: out — total tokens each local expert
 *     will process. Values are ACCUMULATED (+=), so the caller must pass a
 *     zero-initialized buffer (moe_cuda_expert_exchange uses torch::zeros).
 *
 * NOTE(review): pointers are read/written directly by MPI, so they are
 * presumably host pointers (or require CUDA-aware MPI) — confirm at callers.
 */
void moe_cuda_expert_exchange_impl(
		const int* local_expert_count, 
		int* global_expert_count, 
		int* fwd_expert_count, 
		int num_expert, int world_size) {
	// Each worker contributes `num_expert` ints to every other worker.
	MPI_Alltoall(local_expert_count, num_expert, MPI_INT, 
			global_expert_count, num_expert, MPI_INT, MPI_COMM_WORLD);
	// Sum over source workers for each local expert.
	for (int i = 0; i < num_expert; ++i) {
		for (int j = 0; j < world_size; ++j) {
			fwd_expert_count[i] += global_expert_count[i + j * num_expert];
		}
	}
}

/*
 * Torch-facing wrapper around moe_cuda_expert_exchange_impl.
 * Returns {global_expert_count, fwd_expert_count}.
 *
 * NOTE(review): fwd_expert_count is allocated with only a dtype (no device),
 * so it lands on the default (CPU) device regardless of where
 * local_expert_count lives — confirm callers pass CPU tensors, since the
 * impl hands raw pointers to MPI_Alltoall.
 */
std::vector<torch::Tensor> moe_cuda_expert_exchange(
		torch::Tensor local_expert_count,
		long num_expert, long n_workers) {
	auto count_options = torch::TensorOptions()
		.dtype(local_expert_count.dtype());
	// The impl accumulates into this buffer, so it must start zeroed.
	auto fwd_expert_count = torch::zeros({num_expert}, count_options);
	auto global_expert_count = torch::empty_like(local_expert_count);
	moe_cuda_expert_exchange_impl(
			local_expert_count.data_ptr<int>(),
			global_expert_count.data_ptr<int>(),
			fwd_expert_count.data_ptr<int>(),
			num_expert, n_workers);
	return {global_expert_count, fwd_expert_count};
}

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
/*
 * All-to-all exchange of feature rows over NCCL so that every token lands on
 * the worker hosting its expert.
 *
 * local_input_buf  (device) [sum(local_expert_count), in_feat]: rows grouped
 *     by (expert, destination worker) bucket.
 * local_expert_count / global_expert_count (host) [num_expert * world_size]:
 *     per-bucket row counts to send / receive (see expert_exchange).
 * input_buf (device): out — received rows, packed densely in arrival order
 *     (worker-major within each expert).
 */
template<typename scalar_t>
void moe_cuda_global_scatter_impl(
	const scalar_t* local_input_buf,
	const int* local_expert_count,
	const int* global_expert_count,
	scalar_t* input_buf,
	size_t in_feat, size_t num_expert, size_t world_size,
	CudaStreamManager* smgr) {
	// assert world_size > 1
	int recv_ptr = 0;
	/* TODO: may save for backward */
	// Exclusive prefix sum: row offset of each (expert, worker) bucket
	// inside local_input_buf.
	int *expert_ptr = new int[num_expert * world_size];
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert * world_size; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
	}

	for (int i = 0; i < num_expert; ++i) {
		// Fuse all sends/recvs for expert i in one NCCL group so the
		// pairwise exchanges make progress together without deadlocking.
		NCCL_SAFE_CALL(ncclGroupStart());
		for (int j = 0; j < world_size; ++j) {
			int idx = i + j * num_expert;
			if (local_expert_count[idx]) {
				// Counts are in bytes (ncclChar), hence * sizeof(scalar_t).
				NCCL_SAFE_CALL(ncclSend(
						local_input_buf + expert_ptr[idx] * in_feat, 
						local_expert_count[idx] * in_feat * sizeof(scalar_t),
						ncclChar, 
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
			}
			if (global_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclRecv(
						input_buf + recv_ptr * in_feat,
						global_expert_count[idx] * in_feat * sizeof(scalar_t),
						ncclChar,
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
				// Received rows are appended densely.
				recv_ptr += global_expert_count[idx];
			}
		}
		NCCL_SAFE_CALL(ncclGroupEnd());
	}
	delete [] expert_ptr;
}
Rick Ho's avatar
Rick Ho committed
159

160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
/*
 * Torch entry point for the cross-worker scatter: ships each token's
 * features to the worker that hosts its expert.
 * batch_size is the number of rows this worker will RECEIVE.
 * Returns {global_input_buf} of shape [batch_size, in_feat].
 */
std::vector<torch::Tensor> moe_cuda_global_scatter(
		torch::Tensor input_buf,
		torch::Tensor local_expert_count,
		torch::Tensor global_expert_count,
		long batch_size, long n_workers) {
	const auto in_feat = input_buf.size(1);
	const auto num_expert = local_expert_count.size(0) / n_workers;
	auto smgr = getCudaStreamManager(input_buf.device().index());
	// Packed destination for the rows arriving from all peers.
	auto global_input_buf = input_buf.new_empty({batch_size, in_feat});

	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
			"moe_cuda_global_scatter", ([&] {
		moe_cuda_global_scatter_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			local_expert_count.data_ptr<int>(),
			global_expert_count.data_ptr<int>(),
			global_input_buf.data_ptr<scalar_t>(),
			in_feat, num_expert, n_workers,
			smgr);
	}));
	return {global_input_buf};
}

/*
 * Inverse of moe_cuda_global_scatter_impl: return expert outputs over NCCL
 * to the workers that own the corresponding tokens.
 *
 * output_buf (device): rows produced by local experts, packed in the same
 *     order they were received during the scatter (tracked by send_ptr).
 * local_output_buf (device): out — rows placed back at this worker's
 *     (expert, worker) bucket offsets (prefix sums of local_expert_count).
 * Counts arrays are host pointers with the same layout as in the scatter.
 */
template<typename scalar_t>
void moe_cuda_global_gather_impl(
	const scalar_t* output_buf,
	const int* local_expert_count,
	const int* global_expert_count,
	scalar_t* local_output_buf,
	size_t out_feat, size_t num_expert, size_t world_size,
	CudaStreamManager* smgr) {
	int send_ptr = 0;
	/* TODO: may save for backward */
	// Exclusive prefix sum: destination row offset of each bucket.
	int *expert_ptr = new int[num_expert * world_size];
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert * world_size; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
	}

	for (int i = 0; i < num_expert; ++i) {
		// Group all transfers for expert i (mirror of the scatter loop).
		NCCL_SAFE_CALL(ncclGroupStart());
		for (int j = 0; j < world_size; ++j) {
			int idx = i + j * num_expert;
			if (global_expert_count[idx]) {
				// Counts are in bytes (ncclChar), hence * sizeof(scalar_t).
				NCCL_SAFE_CALL(ncclSend(
						output_buf + send_ptr * out_feat,
						global_expert_count[idx] * out_feat * sizeof(scalar_t),
						ncclChar,
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
				send_ptr += global_expert_count[idx];
			}
			if (local_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclRecv(
						local_output_buf + expert_ptr[idx] * out_feat, 
						local_expert_count[idx] * out_feat * sizeof(scalar_t),
						ncclChar, 
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
			}
		}
		NCCL_SAFE_CALL(ncclGroupEnd());
	}
	delete [] expert_ptr;
}

/*
 * Torch entry point for the cross-worker gather: returns expert outputs to
 * the workers that own the tokens (inverse of moe_cuda_global_scatter).
 * batch_size is the number of rows this worker gets back.
 * Returns {local_output_buf} of shape [batch_size, out_feat].
 */
std::vector<torch::Tensor> moe_cuda_global_gather(
		torch::Tensor output_buf,
		torch::Tensor local_expert_count,
		torch::Tensor global_expert_count,
		long batch_size, long n_workers) {
	auto num_expert = local_expert_count.size(0) / n_workers;
	auto out_feat = output_buf.size(1);
    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
	auto smgr = getCudaStreamManager(output_buf.device().index());

    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), 
			"moe_cuda_global_gather", ([&] {
		// Bug fix: this previously dispatched to moe_cuda_global_scatter_impl
		// (copy-paste error), which exchanges data in the wrong direction —
		// the gather must use the gather impl.
		moe_cuda_global_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			local_expert_count.data_ptr<int>(),
			global_expert_count.data_ptr<int>(),
			local_output_buf.data_ptr<scalar_t>(),
			out_feat, num_expert, n_workers,
			smgr
		);
	}));
	return {local_output_buf,};
}

253

254
255
#endif  // MOE_USE_NCCL

Rick Ho's avatar
Rick Ho committed
256
257
258
259
260
/*
 * Launch batch_scatter_kernel to reorder rows on-device:
 *     input_buf[d_pos[i]] = input[i]   (rows of width in_feat)
 * One block per row, 256 threads per block, on stream 0; blocks until the
 * copy completes (smgr->sync).
 */
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
		const int* d_pos,
		scalar_t* input_buf,
		const long batch_size,
		const long in_feat, 
		CudaStreamManager* smgr) {
	const dim3 grid(batch_size);
	const dim3 block(256);
	batch_scatter_kernel<scalar_t><<<grid, block, 0, smgr->stream(0)>>>(
			in_feat, d_pos, input, input_buf);
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
269

Rick Ho's avatar
Rick Ho committed
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
/*
 * Inverse of batch_scatter_kernel: block b copies row pos[b] of `inbuf`
 * into row b of `oubuf`. Threads of the block stride over the columns.
 * Launch with gridDim.x == number of rows.
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	const scalar_t* src = inbuf + wid * pos[blockIdx.x];
	scalar_t* dst = oubuf + wid * blockIdx.x;
	for (int col = threadIdx.x; col < wid; col += blockDim.x) {
		dst[col] = src[col];
	}
}

/*
 * Launch batch_gather_kernel to undo the local scatter:
 *     output[i] = output_buf[d_pos[i]]   (rows of width out_feat)
 * One block per row, 256 threads per block, on stream 0; blocks until the
 * copy completes (smgr->sync).
 */
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
		const int* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat,
		CudaStreamManager* smgr) {
	const dim3 grid(batch_size);
	const dim3 block(256);
	batch_gather_kernel<scalar_t><<<grid, block, 0, smgr->stream(0)>>>(
			out_feat, d_pos, output_buf, output);
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
294

Rick Ho's avatar
Rick Ho committed
295
296
297
298
299
300
301
302
/*
 * Per-expert GEMM forward pass.
 *
 * input_buf  (device) [sum(expert_count), in_feat], row-major, rows grouped
 *            by expert (as produced by local/global scatter).
 * weight     (device) [num_expert, out_feat, in_feat], row-major.
 * expert_count (host) [num_expert]: rows assigned to each expert.
 * output_buf (device) [sum(expert_count), out_feat]: out.
 * Each expert's GEMM is issued on its own stream/handle (smgr->handle(i));
 * all streams are synchronized before returning.
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
		const int* expert_count,
        scalar_t* output_buf,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
	scalar_t alpha = 1, beta = 0; 

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		// Skip experts with no assigned rows.
		if (expert_count[i] == 0) {
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C
		// (cuBLAS is column-major; the row-major buffers are passed as their
		// transposes, with leading dimensions in_feat / out_feat.)
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[i], in_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		// Advance past this expert's block of rows.
		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
329
/*
 * Per-expert GEMM backward pass: computes both the input gradient and the
 * weight gradient for each expert on its own stream/handle.
 *
 * grad_output_buf (device) [sum(expert_count), out_feat], row-major,
 *                 rows grouped by expert.
 * input_buf       (device) [sum(expert_count), in_feat], same grouping.
 * weight          (device) [num_expert, out_feat, in_feat], row-major.
 * expert_count    (host)   [num_expert].
 * grad_input_buf  (device) out, same shape/grouping as input_buf.
 * grad_weight     (device) out, same shape as weight; zeroed for experts
 *                 with no rows (cuBLAS is never called for them).
 * NOTE(review): batch_size is not referenced in this body — kept for
 * signature symmetry with the forward pass, presumably.
 */
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
		const scalar_t* weight,
		const int* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// No rows means no gradient contribution; make sure the weight
			// gradient is defined rather than left uninitialized.
			cudaMemset(grad_weight + i * in_feat * out_feat, 0, 
					sizeof(scalar_t) * in_feat * out_feat);
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C

		// Backward input: g_i = w @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: g_w = i @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		// Advance past this expert's block of rows.
		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}
Rick Ho's avatar
Rick Ho committed
382
383


Rick Ho's avatar
Rick Ho committed
384
/*
 * Torch entry point for expert counting.
 * gate [batch_size] (int32, on GPU): expert chosen per token.
 * Returns {expert_count, pos}:
 *   expert_count [num_expert] int32 on CPU (the impl fills it host-side);
 *   pos          [batch_size] int32 on gate's device (scatter positions).
 */
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor gate, 
		size_t num_expert) {
	const auto batch_size = gate.size(0);

	// Filled on the host by the impl, so no device is set here.
	auto expert_count = torch::empty(num_expert,
			torch::TensorOptions().dtype(torch::kInt32));
	// Written via cudaMemcpy, so it lives on gate's device.
	auto pos = torch::empty(batch_size,
			torch::TensorOptions()
				.device(gate.device())
				.dtype(torch::kInt32));

	moe_cuda_expert_count_impl(
			gate.data_ptr<int>(),
			expert_count.data_ptr<int>(),
			pos.data_ptr<int>(),
			num_expert,
			batch_size);

	return {expert_count, pos};
}

Rick Ho's avatar
Rick Ho committed
406
407
408
/*
 * Torch entry point for the local scatter: reorders rows so that tokens of
 * the same expert become contiguous (input_buf[pos[i]] = input[i]).
 * Returns {input_buf} with the same shape and dtype as input.
 */
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
	torch::Tensor pos) {
	auto smgr = getCudaStreamManager(input.device().index());
	const auto batch_size = input.size(0);
	const auto in_feat = input.size(1);
	auto input_buf = torch::empty_like(input);

	AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda",
			([&] {
		moe_cuda_local_scatter_impl<scalar_t>(
			input.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			input_buf.data_ptr<scalar_t>(),
			batch_size,
			in_feat,
			smgr);
	}));
	return {input_buf};
}
Jiezhong Qiu's avatar
Jiezhong Qiu committed
427

Rick Ho's avatar
Rick Ho committed
428
429
430
/*
 * Torch entry point for the local gather: inverse of moe_cuda_local_scatter,
 * restoring the original token order (output[i] = output_buf[pos[i]]).
 * Returns {output} with the same shape and dtype as output_buf.
 */
std::vector<torch::Tensor> moe_cuda_local_gather(
	torch::Tensor output_buf,
	torch::Tensor pos) {
	auto smgr = getCudaStreamManager(output_buf.device().index());
	const auto batch_size = output_buf.size(0);
	const auto out_feat = output_buf.size(1);
	auto output = torch::empty_like(output_buf);

	AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda",
			([&] {
		moe_cuda_local_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			batch_size,
			out_feat,
			smgr);
	}));
	return {output};
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
449

Jiezhong Qiu's avatar
Jiezhong Qiu committed
450
/*
 * Torch entry point for the per-expert forward GEMM.
 * input_buf    [batch_size, in_feat], rows grouped by expert.
 * weight       [num_expert, out_feat, in_feat].
 * expert_count [num_expert] int32 (host-readable; the impl indexes it on CPU).
 * Returns {output} of shape [batch_size, out_feat].
 */
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
		torch::Tensor expert_count
		) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 
			num_expert, in_feat, out_feat);
#endif

	auto output = torch::empty({batch_size, out_feat},
			torch::TensorOptions()
				.device(input_buf.device())
				.dtype(input_buf.dtype()));

	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda",
			([&] {
		moe_cuda_forward_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			in_feat,
			out_feat,
			num_expert,
			smgr);
	}));

	return {output};
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
487
/*
 * Torch entry point for the per-expert backward GEMMs.
 * Shapes follow the forward pass; expert_count is int32 and indexed on the
 * host by the impl.
 * Returns {grad_input_buf [batch_size, in_feat],
 *          grad_weight    [num_expert, out_feat, in_feat]}.
 */
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, // [batch_size x out_feat]
    torch::Tensor input_buf, // [batch_size x out_feat]
    torch::Tensor weight, // [num_expert x out_feat x in_feat]
	torch::Tensor expert_count
) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

	auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
	auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
		moe_cuda_backward_impl<scalar_t>(
			grad_output_buf.data_ptr<scalar_t>(),
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
			grad_input_buf.data_ptr<scalar_t>(),
			grad_weight.data_ptr<scalar_t>(),
			batch_size,
			in_feat,
			out_feat,
			num_expert,
			smgr);
	}));

	return {grad_input_buf, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
527
528

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
529
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
530
531
532
533
534
535
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
536
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
537
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
538
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
539

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
540
541
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
542
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
543
544
545
546
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
547

Jiezhong Qiu's avatar
Jiezhong Qiu committed
548
549
550
551
552
553
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
554
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
555
556
557
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
558
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
559
560
561
562
563
564
565
566
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
567
}
Rick Ho's avatar
Rick Ho committed
568
*/