// moe_cuda_kernel.cu
#include <torch/extension.h>
#include <torch/torch.h>

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
// Keep cublas_v2.h ahead of helper_cuda.h: checkCudaErrors is applied to
// cuBLAS status codes below (NOTE(review): helper_cuda.h appears to enable its
// cuBLAS overloads only when cuBLAS was included first — confirm).
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include <mpi.h>

#include "timer.hh"

#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
#include "comm_manager.h"

// Integer ceiling division for non-negative _x_ and positive _y_: the smallest
// k with k * _y_ >= _x_ (e.g. blocks of size _y_ needed to cover _x_ items).
// Written as (x + y - 1) / y so that CEIL(0, y) == 0; the former
// ((x) - 1) / (y) + 1 form yielded 1 for a signed zero and wrapped around for
// an unsigned zero.
// _y_ is expanded twice, so pass side-effect-free expressions, and
// _x_ + _y_ - 1 must not overflow the argument type.
#define CEIL(_x_,_y_) (((_x_) + (_y_) - 1) / (_y_))

// Fills ptrs[i] = base + stride * offset[i] for every i in [0, n).
//
// Works with any 1-D launch configuration: the global index is computed in
// size_t (blockDim.x * blockIdx.x would otherwise be evaluated in 32-bit
// unsigned arithmetic and could wrap before widening), and a grid-stride loop
// covers elements beyond the number of launched threads.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	const size_t grid_stride = static_cast<size_t>(blockDim.x) * gridDim.x;
	for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
			idx < n; idx += grid_stride) {
		ptrs[idx] = base + stride * offset[idx];
	}
}

template <typename scalar_t>
__global__
// Row scatter: copies row blockIdx.x of `inbuf` to row pos[blockIdx.x] of
// `oubuf`, both row-major with `wid` elements per row.
//
// Launch with exactly one block per input row — blockIdx.x is not
// bounds-checked. Threads of a block stride across the row, so any
// blockDim.x works. The element index is size_t to match `wid`: an int index
// compares signed against unsigned and overflows for rows wider than INT_MAX.
void batch_scatter_kernel(size_t wid, const int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	inbuf += wid * blockIdx.x;
	oubuf += wid * pos[blockIdx.x];
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

// Counts, on the host, how many batch rows are routed to each expert and
// computes where each row lands once the batch is grouped by expert.
//
// d_gate       : device array of batch_size expert indices (copied to host).
// expert_count : host array of num_expert ints, overwritten with per-expert
//                row counts.
// d_pos        : device array of batch_size ints; d_pos[i] receives the
//                destination row of input row i in the expert-grouped layout
//                (rows of expert 0 first, then expert 1, ...; the original
//                row order is kept within an expert).
// num_expert   : number of experts; valid gate values are [0, num_expert).
// batch_size   : number of rows.
//
// Throws std::out_of_range if a gate value is outside [0, num_expert)
// instead of writing out of bounds. Host scratch buffers are std::vectors so
// nothing leaks, including when that check throws.
void moe_cuda_expert_count_impl(
		const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
		const size_t batch_size) {
	std::fill(expert_count, expert_count + num_expert, 0);
	if (batch_size == 0) {
		return;
	}

	std::vector<int> gate(batch_size);
	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	for (size_t i = 0; i < batch_size; ++i) {
		const int e = gate[i];
		if (e < 0 || static_cast<size_t>(e) >= num_expert) {
			throw std::out_of_range(
					"moe_cuda_expert_count: gate value out of range [0, num_expert)");
		}
		++expert_count[e];
	}

	// Exclusive prefix sum: expert_ptr[e] is the first grouped row of expert e.
	std::vector<int> expert_ptr(num_expert);
	int running = 0;
	for (size_t e = 0; e < num_expert; ++e) {
		expert_ptr[e] = running;
		running += expert_count[e];
	}

	// Assign each row the next free slot of its expert (stable order).
	std::vector<int> pos(batch_size);
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}

	checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));
}

#ifdef MOE_USE_NCCL

// NOTE(review): work in progress. cm, h, num_expert, expert_count, expert_ptr,
// all_expert_count, expert_sz, in_feat, out_feat, local_input_buf,
// local_output_buf, input_buf, output_buf and scalar_t are not declared in
// this file, so this section cannot compile with MOE_USE_NCCL defined as it
// stands. TODO: pass this state in explicitly.

// Cross-rank exchange of rows before the expert computation.
//
// For each expert e in [0, num_expert) one NCCL group is issued. Within it,
// for every peer rank p in [0, cm->size) with slot = e + p * num_expert:
//   * expert_count[slot] rows of in_feat elements, starting at row
//     expert_ptr[slot] of local_input_buf, are sent to p;
//   * all_expert_count[slot] rows are received from p into input_buf, packed
//     back to back in the order the receives are issued.
// Sizes are passed in bytes (ncclChar); all traffic uses h->getStream(0).
// With cm->size <= 1 nothing is communicated and input_buf / output_buf
// simply alias local_input_buf / local_output_buf.
void moe_cuda_global_scatter() {
	if (cm->size <= 1) {
		input_buf = local_input_buf;
		output_buf = local_output_buf;
		return;
	}

	if (expert_sz) {
		// NOTE(review): nothing in this file frees these two buffers — TODO
		// confirm who owns them.
		checkCudaErrors(cudaMalloc(&input_buf,
					sizeof(scalar_t) * expert_sz * in_feat));
		checkCudaErrors(cudaMalloc(&output_buf,
					sizeof(scalar_t) * expert_sz * out_feat));
	}

	int recv_row = 0;
	for (int e = 0; e < num_expert; ++e) {
		NCCL_SAFE_CALL(ncclGroupStart());
		for (int peer = 0; peer < cm->size; ++peer) {
			const int slot = e + peer * num_expert;
			if (expert_count[slot]) {
				NCCL_SAFE_CALL(ncclSend(
						local_input_buf + expert_ptr[slot] * in_feat,
						expert_count[slot] * in_feat * sizeof(scalar_t),
						ncclChar, peer, cm->ncclcomm, h->getStream(0)));
			}
			if (all_expert_count[slot]) {
				NCCL_SAFE_CALL(ncclRecv(
						input_buf + recv_row * in_feat,
						all_expert_count[slot] * in_feat * sizeof(scalar_t),
						ncclChar, peer, cm->ncclcomm, h->getStream(0)));
				recv_row += all_expert_count[slot];
			}
		}
		NCCL_SAFE_CALL(ncclGroupEnd());
	}
}

// Mirror image of moe_cuda_global_scatter. For each expert e, in one NCCL
// group and for every peer p (slot = e + p * num_expert):
//   * all_expert_count[slot] rows of out_feat elements are sent to p from
//     output_buf, walked back to back in issue order;
//   * expert_count[slot] rows are received from p into local_output_buf at
//     row expert_ptr[slot].
// Does nothing when cm->size <= 1.
void moe_cuda_global_gather() {
	if (cm->size <= 1) {
		return;
	}

	int send_row = 0;
	for (int e = 0; e < num_expert; ++e) {
		NCCL_SAFE_CALL(ncclGroupStart());
		for (int peer = 0; peer < cm->size; ++peer) {
			const int slot = e + peer * num_expert;
			if (all_expert_count[slot]) {
				NCCL_SAFE_CALL(ncclSend(
						output_buf + send_row * out_feat,
						all_expert_count[slot] * out_feat * sizeof(scalar_t),
						ncclChar, peer, cm->ncclcomm, h->getStream(0)));
				send_row += all_expert_count[slot];
			}
			if (expert_count[slot]) {
				NCCL_SAFE_CALL(ncclRecv(
						local_output_buf + expert_ptr[slot] * out_feat,
						expert_count[slot] * out_feat * sizeof(scalar_t),
						ncclChar, peer, cm->ncclcomm, h->getStream(0)));
			}
		}
		NCCL_SAFE_CALL(ncclGroupEnd());
	}
}

#endif  // MOE_USE_NCCL

// Permutes the rows of `input` (batch_size rows of in_feat elements,
// row-major) into `input_buf` so that input row i lands at row d_pos[i].
//
// Launches batch_scatter_kernel with one block per row on smgr->stream(0),
// checks the launch, then calls smgr->sync(1) (presumably waits on the first
// managed stream — TODO confirm). An empty batch launches nothing, because a
// 0-block grid is an invalid launch configuration.
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
		const scalar_t* input,
		const int* d_pos,
		scalar_t* input_buf,
		const size_t batch_size,
		const size_t in_feat,
		CudaStreamManager* smgr) {
	if (batch_size > 0) {
		batch_scatter_kernel<scalar_t>
			<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
					input_buf);
		// Launches return no status; surface configuration errors here.
		checkCudaErrors(cudaGetLastError());
	}
	smgr->sync(1);
}

// Row gather (inverse of batch_scatter_kernel): copies row pos[blockIdx.x] of
// `inbuf` to row blockIdx.x of `oubuf`, both row-major with `wid` elements
// per row.
//
// Launch with exactly one block per output row — blockIdx.x is not
// bounds-checked. Threads of a block stride across the row, so any
// blockDim.x works. The element index is size_t to match `wid`.
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	inbuf += wid * pos[blockIdx.x];
	oubuf += wid * blockIdx.x;
	for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

// Undoes moe_cuda_local_scatter_impl: output row i is copied from row
// d_pos[i] of `output_buf` (batch_size rows of out_feat elements, row-major).
//
// Launches batch_gather_kernel with one block per row on smgr->stream(0),
// checks the launch, then calls smgr->sync(1). An empty batch launches
// nothing, because a 0-block grid is an invalid launch configuration.
template <typename scalar_t>
void moe_cuda_local_gather_impl(
		const scalar_t* output_buf,
		const int* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat,
		CudaStreamManager* smgr) {
	if (batch_size > 0) {
		batch_gather_kernel<scalar_t>
			<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
					output);
		// Launches return no status; surface configuration errors here.
		checkCudaErrors(cudaGetLastError());
	}
	smgr->sync(1);
}

// Applies each expert's linear layer to its contiguous slice of rows.
//
// input_buf    : rows grouped by expert (expert 0's rows first), in_feat wide.
// weight       : num_expert consecutive out_feat x in_feat row-major matrices.
// expert_count : host array; expert e owns the next expert_count[e] rows.
// output_buf   : receives the matching rows, out_feat wide.
//
// cuBLAS is column-major, so a row-major matrix is seen as its transpose.
// Computing C^T = op(W) * X^T with op = transpose yields row-major
// C = X * W^T. Experts with no rows are skipped. Expert e's gemm is issued on
// smgr->handle(e); smgr->sync(num_expert) is called once at the end.
template <typename scalar_t>
void moe_cuda_forward_impl(
		const scalar_t* input_buf,
		const scalar_t* weight,
		const int* expert_count,
		scalar_t* output_buf,
		const size_t in_feat,
		const size_t out_feat,
		const size_t num_expert,
		CudaStreamManager* smgr) {
	scalar_t one = 1;
	scalar_t zero = 0;
	const size_t matrix_elems = in_feat * out_feat;

	size_t row = 0;
	for (int e = 0; e < num_expert; ++e) {
		const int rows = expert_count[e];
		if (rows != 0) {
			checkCudaErrors(cublasXgemm(
					smgr->handle(e),
					CUBLAS_OP_T, CUBLAS_OP_N,
					out_feat, rows, in_feat,
					&one,
					weight + e * matrix_elems, in_feat,
					input_buf + row * in_feat, in_feat,
					&zero,
					output_buf + row * out_feat, out_feat));
			row += rows;
		}
	}
	smgr->sync(num_expert);
}

// Backward pass of moe_cuda_forward_impl, one expert at a time.
//
// For expert i owning rows [ptr, ptr + n) of the expert-grouped buffers
// (n = expert_count[i], read on the host), with W_i the out_feat x in_feat
// row-major weight:
//   grad_input rows  = grad_output rows * W_i           (n x in_feat)
//   grad_weight[i]   = grad_output rows^T * input rows  (out_feat x in_feat)
// Experts with no rows get their grad_weight slice zeroed with cudaMemset
// (default stream) and issue no gemm. Gemms run on smgr->handle(i);
// smgr->sync(num_expert) is called once at the end. batch_size is unused.
template <typename scalar_t>
void moe_cuda_backward_impl(
		const scalar_t* grad_output_buf,
		const scalar_t* input_buf,
		const scalar_t* weight,
		const int* expert_count,
		scalar_t* grad_input_buf,
		scalar_t* grad_weight,
		const size_t batch_size,
		const size_t in_feat,
		const size_t out_feat,
		const size_t num_expert,
		CudaStreamManager* smgr) {
	(void)batch_size;
	scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			// No rows reached this expert, so its weight gradient is exactly zero.
			checkCudaErrors(cudaMemset(grad_weight + i * in_feat * out_feat, 0,
					sizeof(scalar_t) * in_feat * out_feat));
			continue;
		}
		// cuBLAS is column-major: each row-major operand is seen transposed, so
		// compute T(C) = op(T(B)) x op(T(A)) to obtain row-major C.

		// Backward input: g_x = g_o · W_i  (n x in_feat).
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: g_W_i = g_o^T · x  (out_feat x in_feat, same layout
		// as W_i).
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}


// Python-facing routing step.
//
// gate       : int32 expert index per row (read via data_ptr<int>); the impl
//              copies it device-to-host, so presumably a CUDA tensor.
// num_expert : number of experts.
// Returns {expert_count, pos}:
//   expert_count : int32, length num_expert; its options set only the dtype
//                  (so the default device), and the impl fills it from host.
//   pos          : int32, length gate.size(0), on gate's device.
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor gate,
		size_t num_expert) {
	const auto num_rows = gate.size(0);

	auto counts = torch::empty(num_expert,
			torch::TensorOptions().dtype(torch::kInt32));
	auto positions = torch::empty(num_rows,
			torch::TensorOptions()
				.device(gate.device())
				.dtype(torch::kInt32));

	moe_cuda_expert_count_impl(
			gate.data_ptr<int>(),
			counts.data_ptr<int>(),
			positions.data_ptr<int>(),
			num_expert,
			num_rows);

	return {counts, positions};
}

// Python-facing wrapper around moe_cuda_local_scatter_impl.
//
// Returns a one-element vector holding a tensor created with
// empty_like(input) whose row pos[i] is input row i. Assumes `input` is a
// contiguous 2-D (rows x features) tensor and `pos` holds int32 row indices
// readable by the kernel (presumably on input's device — TODO confirm).
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
	torch::Tensor pos) {
	auto stream_mgr = getCudaStreamManager(input.device().index());

	const auto num_rows = input.size(0);
	const auto row_width = input.size(1);
	auto scattered = torch::empty_like(input);

	AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda", ([&] {
		moe_cuda_local_scatter_impl<scalar_t>(
				input.data_ptr<scalar_t>(),
				pos.data_ptr<int>(),
				scattered.data_ptr<scalar_t>(),
				num_rows,
				row_width,
				stream_mgr);
	}));

	std::vector<torch::Tensor> result;
	result.push_back(scattered);
	return result;
}

// Python-facing wrapper around moe_cuda_local_gather_impl.
//
// Returns a one-element vector holding a tensor created with
// empty_like(output_buf) whose row i is row pos[i] of output_buf. Assumes
// `output_buf` is a contiguous 2-D (rows x features) tensor and `pos` holds
// int32 row indices readable by the kernel.
std::vector<torch::Tensor> moe_cuda_local_gather(
	torch::Tensor output_buf,
	torch::Tensor pos) {
	auto stream_mgr = getCudaStreamManager(output_buf.device().index());

	const auto num_rows = output_buf.size(0);
	const auto row_width = output_buf.size(1);
	auto gathered = torch::empty_like(output_buf);

	AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda", ([&] {
		moe_cuda_local_gather_impl<scalar_t>(
				output_buf.data_ptr<scalar_t>(),
				pos.data_ptr<int>(),
				gathered.data_ptr<scalar_t>(),
				num_rows,
				row_width,
				stream_mgr);
	}));

	std::vector<torch::Tensor> result;
	result.push_back(gathered);
	return result;
}

// Python-facing forward of the per-expert linear layers.
//
// input_buf    : [batch_size x in_feat], rows grouped by expert.
// weight       : [num_expert x out_feat x in_feat].
// expert_count : per-expert row counts as int32, read on the host by
//                moe_cuda_forward_impl.
// Returns {output} with output shaped [batch_size x out_feat], same device and
// dtype as input_buf.
std::vector<torch::Tensor> moe_cuda_forward(
		torch::Tensor input_buf,
		torch::Tensor weight,
		torch::Tensor expert_count
		) {
	auto stream_mgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
	const auto num_expert = weight.size(0);
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
	printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
			num_expert, in_feat, out_feat);
#endif

	auto result = torch::empty({batch_size, out_feat},
			torch::TensorOptions()
				.device(input_buf.device())
				.dtype(input_buf.dtype()));

	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda", ([&] {
		moe_cuda_forward_impl<scalar_t>(
				input_buf.data_ptr<scalar_t>(),
				weight.data_ptr<scalar_t>(),
				expert_count.data_ptr<int>(),
				result.data_ptr<scalar_t>(),
				in_feat,
				out_feat,
				num_expert,
				stream_mgr);
	}));

	return {result};
}

// Python-facing backward of moe_cuda_forward.
//
// expert_count holds int32 per-expert row counts, read on the host by
// moe_cuda_backward_impl.
// Returns {grad_input_buf [batch_size x in_feat],
//          grad_weight    [num_expert x out_feat x in_feat]}, both created
// with grad_output_buf.new_empty (same device and dtype).
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, // [batch_size x out_feat]
    torch::Tensor input_buf, // [batch_size x in_feat]
    torch::Tensor weight, // [num_expert x out_feat x in_feat]
	torch::Tensor expert_count
) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
			smgr
        );
    }));

    return {grad_input_buf, grad_weight};
}

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
430
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
431
432
433
434
435
436
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
437
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
438
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
439
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
440

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
441
442
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
443
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
444
445
446
447
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
448

Jiezhong Qiu's avatar
Jiezhong Qiu committed
449
450
451
452
453
454
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
455
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
456
457
458
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
459
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
460
461
462
463
464
465
466
467
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
468
}
Rick Ho's avatar
Rick Ho committed
469
*/