// moe_cuda_kernel.cu
#include <torch/extension.h>
#include <torch/torch.h>
#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
// NCCL primitives (ncclSend/ncclRecv/ncclGroupStart) are used below;
// nccl.h may also be pulled in via comm_manager.h.
#include <nccl.h>
#include <c10/cuda/CUDAGuard.h>

#include <mpi.h>

#include "timer.hh"

#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
#include "comm_manager.h"

#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

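// Fills `ptrs` with per-row device pointers into `base`: ptrs[idx] = base +
// stride * offset[idx]. Presumably meant for building pointer arrays for
// batched GEMM calls; it is not referenced elsewhere in this file.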
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) { 
	size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
	if (idx < n) {
		ptrs[idx] = base + stride * offset[idx];
	}
}

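// Copies each row of `inbuf` (one CUDA block per row, `wid` elements) to the
// slot given by `pos`: oubuf[pos[b]] = inbuf[b]. This is how tokens are
// reordered so that rows belonging to the same expert become contiguous.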
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	inbuf += wid * blockIdx.x;
	oubuf += wid * pos[blockIdx.x];
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

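// Host-side counting sort over the gate assignments: copies d_gate to the
// host, counts rows per expert (expert_count), builds prefix offsets
// (expert_ptr), and writes each row's destination slot back to d_pos on the
// device.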
void moe_cuda_expert_count_impl(
		const int* d_gate,
		int* expert_count,
		int* d_pos,
		const size_t num_expert,
		const size_t batch_size) {
	int *gate = new int[batch_size];
	int *expert_ptr = new int[num_expert];
	memset(expert_count, 0, sizeof(int) * num_expert);

	checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

	for (int i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	int *pos = new int[batch_size];

	for (int i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}
	for (int i = num_expert - 1; i > 0; --i) {
		expert_ptr[i] = expert_ptr[i - 1];
	}
	expert_ptr[0] = 0;
	checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));
	delete [] gate;
	delete [] expert_ptr;
	delete [] pos;
}

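// Work-in-progress sketch of the inter-GPU exchange: for every expert, each
// rank ncclSend's the rows destined for that expert on every other rank and
// ncclRecv's the rows it owns. The symbols used below (cm, h, expert_count,
// all_expert_count, expert_ptr, the buffers and sizes) are not parameters
// yet, so this function does not compile as-is.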
void moe_cuda_global_scatter() {
	if (cm->size > 1) {
		if (expert_sz) {
			checkCudaErrors(cudaMalloc(&input_buf,
						sizeof(scalar_t) * expert_sz * in_feat));
			checkCudaErrors(cudaMalloc(&output_buf,
						sizeof(scalar_t) * expert_sz * out_feat));
		}
		int recv_ptr = 0;
		for (int i = 0; i < num_expert; ++i) {
			NCCL_SAFE_CALL(ncclGroupStart());
			for (int j = 0; j < cm->size; ++j) {
				int idx = i + j * num_expert;
				if (expert_count[idx]) {
					NCCL_SAFE_CALL(ncclSend(
							local_input_buf + expert_ptr[idx] * in_feat,
							expert_count[idx] * in_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0)));
				}
				if (all_expert_count[idx]) {
					NCCL_SAFE_CALL(ncclRecv(
							input_buf + recv_ptr * in_feat,
							all_expert_count[idx] * in_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0)));
					recv_ptr += all_expert_count[idx];
				}
			}
			NCCL_SAFE_CALL(ncclGroupEnd());
		}
	} else {
		input_buf = local_input_buf;
		output_buf = local_output_buf;
	}
}

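// Launches batch_scatter_kernel with one block per row to permute `input`
// into expert order (`input_buf`) according to d_pos, then syncs the stream.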
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
		const scalar_t* input,
		const int* d_pos,
		scalar_t* input_buf,
		const size_t batch_size,
		const size_t in_feat,
		CudaStreamManager* smgr) {
	batch_scatter_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
				input_buf);
	smgr->sync(1);
}

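// Inverse of the scatter pair above: batch_gather_kernel copies row pos[b] of
// the expert-ordered buffer back to row b of the output, restoring the
// original token order after the expert computation.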
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	inbuf += wid * pos[blockIdx.x];
	oubuf += wid * blockIdx.x;
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

template <typename scalar_t>
void moe_cuda_local_gather_impl(
		const scalar_t* output_buf,
		const int* d_pos,
		scalar_t* output,
		const size_t batch_size,
		const size_t out_feat,
		CudaStreamManager* smgr) {
	batch_gather_kernel<scalar_t>
		<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
				output);
	smgr->sync(1);
}

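// Per-expert GEMM over the expert-ordered buffer: for expert i,
//   output rows = input rows x weight[i]^T, weight[i] being [out_feat x in_feat]
// stored row-major. cuBLAS is column-major, so a row-major matrix is what
// cuBLAS sees as its transpose; the operands and transpose flags below are
// arranged (T(B) x T(A) = T(C)) so the column-major result cuBLAS writes is
// exactly the row-major output block. Each expert uses its own handle/stream;
// smgr->sync(num_expert) joins them.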
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
		const int* expert_count,
        scalar_t* output_buf,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
	scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[i], in_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}

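// Per-expert backward GEMMs on the expert-ordered buffers:
//   grad_input  = grad_output x weight[i]   -> [rows_i x in_feat]
//   grad_weight = grad_output^T x input     -> [out_feat x in_feat]
// Experts that received no rows only get their weight gradient zeroed. As in
// the forward pass, operands are laid out for row-major storage on top of
// column-major cuBLAS.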
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
		const scalar_t* weight,
		const int* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
		CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			cudaMemset(grad_weight + i * in_feat * out_feat, 0,
					sizeof(scalar_t) * in_feat * out_feat);
			continue;
		}
		// Use T(B) x T(A) = T(C) to produce row-major C

		// Backward input: g_i = g_o @ w
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_N,
				in_feat, expert_count[i], out_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_input_buf + in_feat * ptr, in_feat
				));

		// Backward weight: g_w = g_o^T @ i
		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_N,
				CUBLAS_OP_T,
				in_feat, out_feat, expert_count[i],
				&alpha,
				input_buf + in_feat * ptr, in_feat,
				grad_output_buf + ptr * out_feat, out_feat,
				&beta,
				grad_weight + i * in_feat * out_feat, in_feat
				));

		ptr += expert_count[i];
	}
	smgr->sync(num_expert);
}


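// Work-in-progress counterpart of moe_cuda_global_scatter: sends the computed
// expert outputs back to the ranks that own the corresponding tokens and
// receives this rank's own results into local_output_buf. Like the scatter,
// it still references free symbols and does not compile as-is.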
void moe_cuda_global_gather() {
	if (cm->size > 1) {
		int send_ptr = 0;
		for (int i = 0; i < num_expert; ++i) {
			NCCL_SAFE_CALL(ncclGroupStart());
			for (int j = 0; j < cm->size; ++j) {
				int idx = i + j * num_expert;
				if (all_expert_count[idx]) {
					NCCL_SAFE_CALL(ncclSend(
							output_buf + send_ptr * out_feat,
							all_expert_count[idx] * out_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0)));
					send_ptr += all_expert_count[idx];
				}
				if (expert_count[idx]) {
					NCCL_SAFE_CALL(ncclRecv(
							local_output_buf + expert_ptr[idx] * out_feat,
							expert_count[idx] * out_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0)));
				}
			}
			NCCL_SAFE_CALL(ncclGroupEnd());
		}
	}
}

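// Python-facing wrapper: returns the per-expert row counts (CPU tensor) and
// each row's destination position (on the gate's device).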
std::vector<torch::Tensor> moe_cuda_expert_count(
		torch::Tensor gate,
		size_t num_expert) {
	const auto batch_size = gate.size(0);

	auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
	auto expert_count = torch::empty(num_expert, ec_options);

	auto pos_options = torch::TensorOptions()
		.device(gate.device())
		.dtype(torch::kInt32);
	auto pos = torch::empty(batch_size, pos_options);
	moe_cuda_expert_count_impl(
			gate.data_ptr<int>(),
			expert_count.data_ptr<int>(),
			pos.data_ptr<int>(),
			num_expert,
			batch_size);

	return {expert_count, pos};
}

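// Python-facing wrapper around moe_cuda_local_scatter_impl; dispatches on the
// input's floating-point dtype.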
std::vector<torch::Tensor> moe_cuda_local_scatter(
		torch::Tensor input,
		torch::Tensor pos) {
	auto smgr = getCudaStreamManager(input.device().index());
	const auto batch_size = input.size(0);
	const auto in_feat = input.size(1);

	auto input_buf = torch::empty_like(input);

	AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda",
			([&] {
		moe_cuda_local_scatter_impl<scalar_t>(
			input.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			input_buf.data_ptr<scalar_t>(),
			batch_size,
			in_feat,
			smgr);
	}));
	return {input_buf,};
}

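// Python-facing wrapper around moe_cuda_local_gather_impl; restores the
// original row order of the expert outputs.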
std::vector<torch::Tensor> moe_cuda_local_gather(
		torch::Tensor output_buf,
		torch::Tensor pos) {
	auto smgr = getCudaStreamManager(output_buf.device().index());
	const auto batch_size = output_buf.size(0);
	const auto out_feat = output_buf.size(1);

	auto output = torch::empty_like(output_buf);

	AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda",
			([&] {
		moe_cuda_local_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			pos.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			batch_size,
			out_feat,
			smgr);
	}));
	return {output,};
}

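// Python-facing forward entry point: expects input_buf already scattered into
// expert order and weight of shape [num_expert, out_feat, in_feat]; returns
// the expert outputs in the same expert-ordered layout.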
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
		torch::Tensor expert_count
		) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
	const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
			num_expert, in_feat, out_feat);
#endif

	auto out_options = torch::TensorOptions()
		.device(input_buf.device())
		.dtype(input_buf.dtype());
    auto output = torch::empty({batch_size, out_feat}, out_options);

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda",
			([&] {
		moe_cuda_forward_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
			output.data_ptr<scalar_t>(),
			in_feat,
			out_feat,
			num_expert,
			smgr
		);
    }));

    return {output, };
}

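// Python-facing backward entry point: allocates grad_input_buf and
// grad_weight and fills them via moe_cuda_backward_impl.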
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf, // [batch_size x out_feat]
    torch::Tensor input_buf, // [batch_size x in_feat]
    torch::Tensor weight, // [num_expert x out_feat x in_feat]
	torch::Tensor expert_count
) {
	auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
			"out_feat (d_ffn)=%ld\n",
			batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
			expert_count.data_ptr<int>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
			smgr
        );
    }));

    return {grad_input_buf, grad_weight};
}


/*
int main() {
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
	data_t *input, *weight;
	data_t *output;
	size_t *gate;

	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));

    size_t nt = 16;
    double tsum = 0, tmax = 0;

    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    }
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);

    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
}
*/