moe_cuda_kernel.cu 14.7 KB
Newer Older
1
2
#include "moe_cuda_kernel.h"

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
7
8
#include <cuda.h>
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
9
#include <cublas_v2.h>
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
10
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
Jiezhong Qiu committed
11
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
12

13
#ifdef MOE_USE_NCCL
Rick Ho's avatar
Rick Ho committed
14
#include <mpi.h>
15
16
#include <nccl.h>
#endif
Rick Ho's avatar
Rick Ho committed
17

Rick Ho's avatar
Rick Ho committed
18
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19

20

Rick Ho's avatar
Rick Ho committed
21
22
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
23

Rick Ho's avatar
Rick Ho committed
24
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
25

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
26
27
// Fills ptrs[i] with the address of the offset[i]-th row (rows are `stride`
// elements apart) inside `base`, for i in [0, n).  One thread per pointer;
// launch with enough threads to cover n.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) { 
	const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
	if (i >= n) {
		return;
	}
	ptrs[i] = base + offset[i] * stride;
}

36
37
// Copies row blockIdx.x of inbuf (wid elements) into row pos[blockIdx.x] of
// oubuf.  Launched with one block per input row; the block's threads stride
// cooperatively over the row's columns.
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	const scalar_t* src = inbuf + wid * blockIdx.x;
	scalar_t* dst = oubuf + wid * pos[blockIdx.x];
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		dst[i] = src[i];
	}
}

Rick Ho's avatar
Rick Ho committed
47
// Counts how many rows the gate routes to each expert and computes, for every
// row, its slot in the expert-sorted layout.
//
// d_gate:       device pointer, batch_size ints; d_gate[i] is the expert id
//               chosen for row i (assumed to lie in [0, num_expert)).
// expert_count: HOST pointer, num_expert ints (output) — filled with plain
//               CPU writes, so it must not be device memory.
// d_pos:        device pointer, batch_size ints (output); d_pos[i] is the
//               destination index of row i after grouping rows by expert.
//
// FIX: the temporary `pos` array was allocated with new[] and never freed
// (memory leak on every call).  All host scratch buffers now use std::vector
// so they are released automatically.  A loop that restored the local
// expert_ptr array after the assignment pass was dead code (the array was
// discarded immediately afterwards) and has been removed.
void moe_cuda_expert_count_impl(
        const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
        const size_t batch_size) {
	std::vector<int> gate(batch_size);
	std::vector<int> expert_ptr(num_expert);
	std::vector<int> pos(batch_size);
	memset(expert_count, 0, sizeof(int) * num_expert);

	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	// Histogram: number of rows per expert.
	for (size_t i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	// Exclusive prefix sum: expert_ptr[e] is the first slot of expert e's
	// segment in the sorted layout.
	expert_ptr[0] = 0;
	for (size_t i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	// Stable assignment: each row takes the next free slot of its expert.
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}

	checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));
}
82

83
84
#ifdef MOE_USE_NCCL

85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Exchanges input rows between workers with NCCL so that each expert's rows,
// wherever they were produced, end up on the worker that hosts that expert.
//
// local_expert_count[i + j*num_expert]:  rows this worker sends toward
//     expert i via worker j (indexed worker-major).
// global_expert_count[idx]: rows this worker receives for its experts.
// local_input_buf: rows already grouped by (expert, worker) segment;
//     input_buf: receive buffer, filled densely in loop order.
// All transfers are byte-counted (ncclChar) and issued on smgr->stream(0);
// caller is responsible for synchronizing that stream before reading.
template<typename scalar_t>
void moe_cuda_global_scatter_impl(
	const scalar_t* local_input_buf,
	const int* local_expert_count,
	const int* global_expert_count,
	scalar_t* input_buf,
	size_t in_feat, size_t num_expert, size_t world_size,
	CudaStreamManager* smgr) {
	// assert world_size > 1
	int recv_ptr = 0;
	/* TODO: may save for backward */
	// Exclusive prefix sum over local_expert_count: start offset (in rows)
	// of each (expert, worker) segment inside local_input_buf.
	int *expert_ptr = new int[num_expert * world_size];
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert * world_size; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
	}

	for (int i = 0; i < num_expert; ++i) {
		// Group the sends/recvs of one expert so NCCL can pair them across
		// ranks without deadlocking.
		NCCL_SAFE_CALL(ncclGroupStart());
		for (int j = 0; j < world_size; ++j) {
			int idx = i + j * num_expert;
			if (local_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclSend(
						local_input_buf + expert_ptr[idx] * in_feat, 
						local_expert_count[idx] * in_feat * sizeof(scalar_t),
						ncclChar, 
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
			}
			if (global_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclRecv(
						input_buf + recv_ptr * in_feat,
						global_expert_count[idx] * in_feat * sizeof(scalar_t),
						ncclChar,
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
				// Received rows are packed back-to-back in loop order.
				recv_ptr += global_expert_count[idx];
			}
		}
		NCCL_SAFE_CALL(ncclGroupEnd());
	}
	delete [] expert_ptr;
}
Rick Ho's avatar
Rick Ho committed
130

131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
// Python-facing wrapper: exchanges input rows across n_workers with NCCL so
// every expert receives all of its rows.  Returns the gathered buffer of
// shape [batch_size, in_feat] on input_buf's device and dtype.
std::vector<torch::Tensor> moe_cuda_global_scatter(
		torch::Tensor input_buf,
		torch::Tensor local_expert_count,
		torch::Tensor global_expert_count,
		long batch_size, long n_workers) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	// local_expert_count holds one entry per (expert, worker) pair.
	const auto experts_per_worker = local_expert_count.size(0) / n_workers;
	const auto feat_dim = input_buf.size(1);
	auto global_input_buf = input_buf.new_empty({batch_size, feat_dim});

	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), 
			"moe_cuda_global_scatter", ([&] {
		moe_cuda_global_scatter_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			local_expert_count.data_ptr<int>(),
			global_expert_count.data_ptr<int>(),
			global_input_buf.data_ptr<scalar_t>(),
			feat_dim, experts_per_worker, n_workers,
			smgr
		);
	}));
	return {global_input_buf,};
}

// Inverse of moe_cuda_global_scatter_impl: returns expert outputs to the
// workers that originally owned the corresponding rows.
//
// output_buf: densely packed rows in the same order global_scatter received
//     them (walked with send_ptr); local_output_buf: destination, indexed by
//     the same (expert, worker) segment offsets the scatter used, so each
//     row lands back at its pre-scatter slot.
// All transfers are byte-counted (ncclChar) on smgr->stream(0); caller must
// synchronize that stream before reading local_output_buf.
template<typename scalar_t>
void moe_cuda_global_gather_impl(
	const scalar_t* output_buf,
	const int* local_expert_count,
	const int* global_expert_count,
	scalar_t* local_output_buf,
	size_t out_feat, size_t num_expert, size_t world_size,
	CudaStreamManager* smgr) {
	int send_ptr = 0;
	/* TODO: may save for backward */
	// Exclusive prefix sum over local_expert_count: row offset of each
	// (expert, worker) segment inside local_output_buf.
	int *expert_ptr = new int[num_expert * world_size];
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert * world_size; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
	}

	for (int i = 0; i < num_expert; ++i) {
		// One group per expert so matching sends/recvs pair up across ranks.
		NCCL_SAFE_CALL(ncclGroupStart());
		for (int j = 0; j < world_size; ++j) {
			int idx = i + j * num_expert;
			if (global_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclSend(
						output_buf + send_ptr * out_feat,
						global_expert_count[idx] * out_feat * sizeof(scalar_t),
						ncclChar,
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
				send_ptr += global_expert_count[idx];
			}
			if (local_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclRecv(
						local_output_buf + expert_ptr[idx] * out_feat, 
						local_expert_count[idx] * out_feat * sizeof(scalar_t),
						ncclChar, 
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
			}
		}
		NCCL_SAFE_CALL(ncclGroupEnd());
	}
	delete [] expert_ptr;
}

// Python-facing wrapper: returns expert outputs to their originating workers,
// inverting moe_cuda_global_scatter.  Returns a [batch_size, out_feat]
// tensor on output_buf's device and dtype.
std::vector<torch::Tensor> moe_cuda_global_gather(
		torch::Tensor output_buf,
		torch::Tensor local_expert_count,
		torch::Tensor global_expert_count,
		long batch_size, long n_workers) {
	auto num_expert = local_expert_count.size(0) / n_workers;
	auto out_feat = output_buf.size(1);
    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
	auto smgr = getCudaStreamManager(output_buf.device().index());

    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), 
			"moe_cuda_global_gather", ([&] {
		// BUG FIX: this previously dispatched to moe_cuda_global_scatter_impl,
		// which interprets local/global counts with swapped roles and would
		// exchange the wrong segments during the gather phase.
		moe_cuda_global_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			local_expert_count.data_ptr<int>(),
			global_expert_count.data_ptr<int>(),
			local_output_buf.data_ptr<scalar_t>(),
			out_feat, num_expert, n_workers,
			smgr
		);
	}));
	return {local_output_buf,};
}

224

225
226
#endif  // MOE_USE_NCCL

Rick Ho's avatar
Rick Ho committed
227
228
229
230
231
// Reorders `input` rows into `input_buf` so that rows belonging to the same
// expert become contiguous: row i lands at slot d_pos[i] (as computed by
// moe_cuda_expert_count_impl).  Blocks on completion via smgr->sync(1).
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
		const int* d_pos,
		scalar_t* input_buf,
		const long batch_size,
		const long in_feat, 
		CudaStreamManager* smgr) {
	// One block per row; the block's threads stride over the row's columns.
	constexpr int n_threads = 256;
	batch_scatter_kernel<scalar_t><<<batch_size, n_threads, 0, smgr->stream(0)>>>(
			in_feat, d_pos, input, input_buf);
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
240

Rick Ho's avatar
Rick Ho committed
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
// Copies row pos[blockIdx.x] of inbuf (wid elements) into row blockIdx.x of
// oubuf — the inverse permutation of batch_scatter_kernel.  One block per
// output row; threads stride cooperatively over the columns.
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	const scalar_t* src = inbuf + wid * pos[blockIdx.x];
	scalar_t* dst = oubuf + wid * blockIdx.x;
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		dst[i] = src[i];
	}
}

// Undoes moe_cuda_local_scatter_impl: writes row d_pos[i] of the
// expert-sorted `output_buf` back to row i of `output`.  Blocks on
// completion via smgr->sync(1).
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
		const int* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat,
		CudaStreamManager* smgr) {
	// One block per row; the block's threads stride over the row's columns.
	constexpr int n_threads = 256;
	batch_gather_kernel<scalar_t><<<batch_size, n_threads, 0, smgr->stream(0)>>>(
			out_feat, d_pos, output_buf, output);
	smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
265

Rick Ho's avatar
Rick Ho committed
266
267
268
269
270
271
272
273
// Applies each expert's linear layer to its contiguous slice of rows:
// for expert i, output[rows] = input[rows] @ weight[i]^T, where the rows are
// the expert_count[i] rows starting at offset `ptr` (expert-sorted layout).
//
// input_buf:  [sum(expert_count), in_feat], row-major, grouped by expert.
// weight:     num_expert stacked [out_feat, in_feat] matrices (row-major).
// expert_count: HOST pointer with per-expert row counts.
// output_buf: [sum(expert_count), out_feat], row-major.
// Each expert's GEMM is issued on its own handle/stream (smgr->handle(i));
// smgr->sync(num_expert) waits for all of them before returning.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
		const int* expert_count,
        scalar_t* output_buf,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
	scalar_t alpha = 1, beta = 0; 

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// No rows for this expert; nothing to compute.
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C
		// (cuBLAS is column-major: computing C^T = B^T * A^T with the
		// row-major buffers reinterpreted as their own transposes yields
		// row-major C without any explicit transposition.)
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[i], in_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
300
// Backward pass of the per-expert linear layer: for each expert's row slice
// computes grad_input = grad_output @ weight and
// grad_weight[i] = input^T-style accumulation over that slice.
//
// grad_output_buf: [sum(expert_count), out_feat], expert-sorted, row-major.
// input_buf:       [sum(expert_count), in_feat], expert-sorted, row-major.
// weight:          num_expert stacked [out_feat, in_feat] matrices.
// expert_count:    HOST pointer with per-expert row counts.
// grad_input_buf / grad_weight: outputs, same layouts as their forward
// counterparts.  GEMMs run on per-expert handles; smgr->sync(num_expert)
// waits for all before returning.
//
// FIX: the cudaMemset zeroing an idle expert's weight gradient previously
// had its return value ignored, unlike every other CUDA call in this file;
// it is now wrapped in checkCudaErrors so failures are not silently lost.
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
		const scalar_t* weight,
		const int* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// Expert received no rows: its weight gradient is exactly zero.
			checkCudaErrors(cudaMemset(grad_weight + i * in_feat * out_feat, 0, 
					sizeof(scalar_t) * in_feat * out_feat));
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C

		// Backward input: g_i = w @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: g_w = i @ g_o
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}
Rick Ho's avatar
Rick Ho committed
353
354


Rick Ho's avatar
Rick Ho committed
355
// Python-facing wrapper around moe_cuda_expert_count_impl.
// Returns {expert_count, pos}: per-expert row counts (CPU int32 tensor —
// the impl fills it with host writes) and each row's slot in the
// expert-sorted layout (int32 tensor on gate's device).
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor gate, 
		size_t num_expert) {
	const auto batch_size = gate.size(0);

	// Host-side output: the impl writes it with plain CPU stores.
	auto expert_count = torch::empty(num_expert,
			torch::TensorOptions().dtype(torch::kInt32));
	// Device-side output, co-located with the gate tensor.
	auto pos = torch::empty(batch_size,
			torch::TensorOptions()
				.device(gate.device())
				.dtype(torch::kInt32));

	moe_cuda_expert_count_impl(
			gate.data_ptr<int>(),
			expert_count.data_ptr<int>(),
			pos.data_ptr<int>(),
			num_expert,
			batch_size);

	return {expert_count, pos};
}

Rick Ho's avatar
Rick Ho committed
377
378
379
// Python-facing wrapper: permutes input rows into expert-sorted order using
// the positions from moe_cuda_expert_count.  Returns the permuted buffer,
// same shape/dtype/device as `input`.
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
	torch::Tensor pos) {
	auto smgr = getCudaStreamManager(input.device().index());
	const auto n_rows = input.size(0);
	const auto row_len = input.size(1);
	auto input_buf = torch::empty_like(input);

	AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda", 
			([&] {
		moe_cuda_local_scatter_impl<scalar_t>(
			input.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			input_buf.data_ptr<scalar_t>(),
			n_rows,
			row_len,
			smgr);
	}));
	return {input_buf,};
}
Jiezhong Qiu's avatar
Jiezhong Qiu committed
398

Rick Ho's avatar
Rick Ho committed
399
400
401
// Python-facing wrapper: restores the original row order from the
// expert-sorted output buffer (inverse of moe_cuda_local_scatter).
// Returns a tensor with the same shape/dtype/device as `output_buf`.
std::vector<torch::Tensor> moe_cuda_local_gather(
	torch::Tensor output_buf,
	torch::Tensor pos) {
	auto smgr = getCudaStreamManager(output_buf.device().index());
	const auto n_rows = output_buf.size(0);
	const auto row_len = output_buf.size(1);
	auto output = torch::empty_like(output_buf);

	AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda", 
			([&] {
		moe_cuda_local_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			n_rows,
			row_len,
			smgr);
	}));
	return {output,};
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
420

Jiezhong Qiu's avatar
Jiezhong Qiu committed
421
// Python-facing forward: applies each expert's weight matrix to its slice of
// the expert-sorted input.
// input_buf:    [batch_size, in_feat], rows grouped by expert.
// weight:       [num_expert, out_feat, in_feat].
// expert_count: per-expert row counts (host-readable int32 tensor).
// Returns {output} of shape [batch_size, out_feat].
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
		torch::Tensor expert_count
		) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 
			num_expert, in_feat, out_feat);
#endif

	// Same device and dtype as the input.
	auto output = input_buf.new_empty({batch_size, out_feat});

	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda", 
			([&] {
		moe_cuda_forward_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			in_feat,
			out_feat,
			num_expert,
			smgr
		);
	}));

	return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
458
// Python-facing backward: computes gradients w.r.t. the expert-sorted input
// and the stacked expert weights.  Returns {grad_input_buf, grad_weight}.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, // [batch_size x out_feat]
    torch::Tensor input_buf, // [batch_size x in_feat]
    torch::Tensor weight, // [num_expert x out_feat x in_feat]
	torch::Tensor expert_count
) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

	// Gradients inherit device/dtype from the incoming gradient tensor.
	auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
	auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
		moe_cuda_backward_impl<scalar_t>(
			grad_output_buf.data_ptr<scalar_t>(),
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
			grad_input_buf.data_ptr<scalar_t>(),
			grad_weight.data_ptr<scalar_t>(),
			batch_size,
			in_feat,
			out_feat,
			num_expert,
			smgr
		);
	}));

	return {grad_input_buf, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
498
499

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
500
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
501
502
503
504
505
506
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
507
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
508
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
509
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
510

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
511
512
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
513
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
514
515
516
517
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
518

Jiezhong Qiu's avatar
Jiezhong Qiu committed
519
520
521
522
523
524
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
525
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
526
527
528
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
529
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
530
531
532
533
534
535
536
537
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
538
}
Rick Ho's avatar
Rick Ho committed
539
*/