moe_cuda_kernel.cu 13.2 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
10
11
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
12

Rick Ho's avatar
Rick Ho committed
13
14
#include <mpi.h>

Rick Ho's avatar
Rick Ho committed
15
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
16

Rick Ho's avatar
Rick Ho committed
17
18
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Rick Ho's avatar
Rick Ho committed
19
#include "comm_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
20

Rick Ho's avatar
Rick Ho committed
21
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
22

Rick Ho's avatar
Rick Ho committed
23
// #define MOE_BREAKDOWN
Rick Ho's avatar
Rick Ho committed
24
// #define MOE_DEBUG_SCATTER
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
25

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
26
27
/*
 * Fill a pointer table from per-entry row indices:
 *     ptrs[k] = base + stride * offset[k]   for every k in [0, n)
 * Expects a 1-D launch with at least n threads in total; threads whose flat
 * index is >= n return without writing. Uses no shared memory.
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	const size_t k = blockDim.x * blockIdx.x + threadIdx.x;
	if (k >= n) {
		return;
	}
	const scalar_t* const target = base + stride * offset[k];
	ptrs[k] = target;
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
36

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/*
 * Scatter rows: copy row blockIdx.x of inbuf to row pos[blockIdx.x] of oubuf.
 * Every row holds wid elements.
 * Launch: 1-D grid with one block per source row; the block's threads stride
 * over the row's columns. No shared memory. pos is a device array with one
 * destination row per block.
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(int wid, const int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	// Compute row offsets in size_t: wid * row overflows 32-bit arithmetic
	// once a buffer holds more than 2^31 elements.
	const size_t width = static_cast<size_t>(wid);
	inbuf += width * blockIdx.x;
	oubuf += width * static_cast<size_t>(pos[blockIdx.x]);
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

/*
 * Gather rows: copy row pos[blockIdx.x] of inbuf to row blockIdx.x of oubuf
 * (the inverse of batch_scatter_kernel for the same pos). Every row holds wid
 * elements.
 * Launch: 1-D grid with one block per destination row; the block's threads
 * stride over the row's columns. No shared memory. pos is a device array with
 * one source row per block.
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(int wid, const int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	// Compute row offsets in size_t: wid * row overflows 32-bit arithmetic
	// once a buffer holds more than 2^31 elements.
	const size_t width = static_cast<size_t>(wid);
	inbuf += width * static_cast<size_t>(pos[blockIdx.x]);
	oubuf += width * blockIdx.x;
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
60
/*
 * Forward pass of a two-layer expert FFN (in_feat -> hidden_feat -> out_feat)
 * with expert parallelism: each worker in the communicator owns num_expert
 * experts, for num_expert * cm->size experts in total.
 *
 * input    [batch_size x in_feat], row-major device buffer.
 * d_gate   [batch_size] device buffer; d_gate[t] is the global expert id of
 *          token t. Global id j * num_expert + i is local expert i of
 *          worker j. Out-of-range ids raise std::out_of_range.
 * weight1  [num_expert x hidden_feat x in_feat], row-major device buffer.
 * weight2  [num_expert x out_feat x hidden_feat], row-major device buffer.
 * output   [batch_size x out_feat] device buffer, written in token order.
 *
 * Pipeline:
 *   1. Bucket tokens by expert on the host and scatter their rows into
 *      local_input_buf so that each expert's rows are contiguous.
 *   2. Exchange per-expert counts with MPI_Alltoall. With more than one
 *      worker, ship rows to the owning worker over NCCL (stream 0).
 *   3. Run two GEMMs per local expert. No activation is applied in between.
 *   4. Ship results back over NCCL and gather them into token order.
 *
 * All temporary device buffers are allocated and released within this call.
 * Every rank must reach the MPI/NCCL collectives. The gate range check throws
 * before them, so a bad gate on one rank can stall the other ranks.
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* d_gate,
        const scalar_t* weight1,
        const scalar_t* weight2,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t hidden_feat,
        const size_t out_feat,
        const size_t num_expert) {

	auto h = getCudaStreamManager(num_expert);
	auto cm = getCommManager();
	const int tot_expert = num_expert * cm->size;
	// Branch on world size, not rank: all ranks must take the same path so
	// that the NCCL send/recv pairs below match up.
	const bool distributed = cm->size > 1;

#ifdef MOE_BREAKDOWN
	timestamp(t_init);
#endif

	scalar_t *local_input_buf, *local_output_buf;
	checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size *
				in_feat));

#ifdef MOE_BREAKDOWN
	timestamp(t_malloc);
	fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_init, t_malloc) *
			1e6);
#endif

	// Host-side bookkeeping; std::vector releases it on every exit path.
	std::vector<int> gate(batch_size);
	std::vector<int> expert_count(tot_expert, 0);
	std::vector<int> expert_ptr(tot_expert);

	if (batch_size > 0) {
		checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
					cudaMemcpyDeviceToHost));
	}

#ifdef MOE_BREAKDOWN
	timestamp(t_cpy);
	fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_malloc, t_cpy) *
			1e6);
#endif

	for (size_t i = 0; i < batch_size; ++i) {
		if (gate[i] < 0 || gate[i] >= tot_expert) {
			// Refuse to write outside the host-side histogram.
			cudaFree(local_input_buf);
			throw std::out_of_range("moe_cuda_forward_impl: gate value " +
					std::to_string(gate[i]) + " is outside [0, " +
					std::to_string(tot_expert) + ")");
		}
		++expert_count[gate[i]];
	}
	// Exclusive prefix sum: expert_ptr[e] is the first row of expert e in
	// local_input_buf. It must keep these start offsets for the NCCL
	// exchanges further down.
	for (int e = 0, acc = 0; e < tot_expert; ++e) {
		expert_ptr[e] = acc;
		acc += expert_count[e];
	}

	// pos[t] is the row of local_input_buf that token t is scattered to. A
	// separate cursor is advanced so expert_ptr is left unchanged.
	std::vector<int> pos(batch_size);
	std::vector<int> cursor(expert_ptr);
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = cursor[gate[i]]++;
	}
	int *d_pos;
	checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));
	if (batch_size > 0) {
		checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
					cudaMemcpyHostToDevice));
	}

	// After the all-to-all, all_expert_count[j * num_expert + i] is the number
	// of rows worker j sends to local expert i of this worker.
	std::vector<int> all_expert_count(tot_expert);
	MPI_Alltoall(expert_count.data(), num_expert, MPI_INT,
			all_expert_count.data(), num_expert, MPI_INT, MPI_COMM_WORLD);

	// expert_n[i]: rows local expert i processes, summed over all workers.
	std::vector<int> expert_n(num_expert, 0);
	size_t expert_sz = 0;
	for (int i = 0; i < num_expert; ++i) {
		for (int j = 0; j < cm->size; ++j) {
			expert_n[i] += all_expert_count[j * num_expert + i];
		}
		expert_sz += expert_n[i];
	}
	scalar_t *input_buf, *hidden_buf, *output_buf;
	checkCudaErrors(cudaMalloc(&hidden_buf,
				sizeof(scalar_t) * expert_sz * hidden_feat));
	// The second GEMM writes output_buf in both the local and the
	// distributed path, so it is always allocated.
	checkCudaErrors(cudaMalloc(&output_buf,
				sizeof(scalar_t) * expert_sz * out_feat));

#ifdef MOE_DEBUG
	for (int i = 0; i < tot_expert; ++i) {
		fprintf(stderr, "%d %d %d\n", cm->rank, i, expert_count[i]);
	}
	if (cm->rank == 0) {
		for (int i = 0; i < tot_expert; ++i) {
			fprintf(stderr, "%d ",all_expert_count[i]);
		}
		fprintf(stderr, "\n");
	}
#endif

#ifdef MOE_BREAKDOWN
	timestamp(t_expert);
	fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
			1e6);
#endif

	// One block per token; a launch with zero blocks is an invalid config.
	if (batch_size > 0) {
		batch_scatter_kernel<scalar_t>
			<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
					local_input_buf);
		checkCudaErrors(cudaGetLastError());
	}
	h->sync(0);

	if (distributed) {
		checkCudaErrors(cudaMalloc(&input_buf,
					sizeof(scalar_t) * expert_sz * in_feat));
		ncclGroupStart();
		int recv_ptr = 0;
		for (int i = 0; i < num_expert; ++i) {
			for (int j = 0; j < cm->size; ++j) {
				// idx names two things with the same layout:
				//   expert_count[idx]     - our rows for expert i of worker j
				//   all_expert_count[idx] - worker j's rows for our expert i
				const int idx = j * num_expert + i;
				if (expert_count[idx]) {
					ncclSend(local_input_buf + expert_ptr[idx] * in_feat,
							expert_count[idx] * in_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0));
				}
				if (all_expert_count[idx]) {
					ncclRecv(input_buf + recv_ptr * in_feat,
							all_expert_count[idx] * in_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0));
					recv_ptr += all_expert_count[idx];
				}
			}
		}
		ncclGroupEnd();
		// The exchange was queued on stream 0; the GEMMs below use per-expert
		// handles whose streams may differ, so wait for it to finish.
		h->sync(0);
	} else {
		input_buf = local_input_buf;
	}

#ifdef MOE_BREAKDOWN
	h->sync();
	timestamp(t_scatter);
	fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
			1e6);
#endif

	scalar_t alpha = 1, beta = 0; 

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_n[i] == 0) {
			continue;
		}
#ifdef MOE_DEBUG_SCATTER
		fprintf(stderr, "gemm %d sz %d\n", i, expert_n[i]);
		fprintf(stderr, "GeMM %d x %d x %d\n", (int)out_feat, expert_n[i],
				(int)in_feat);
#endif
		// Use T(B) x T(A) = T(C) to produce row-major C.
		// hidden rows = input rows * weight1[i]^T
		checkCudaErrors(cublasXgemm(h->getHandle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				hidden_feat, expert_n[i], in_feat,
				&alpha,
				weight1 + i * in_feat * hidden_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				hidden_buf + hidden_feat * ptr, hidden_feat
				));

		// output rows = hidden rows * weight2[i]^T
		checkCudaErrors(cublasXgemm(h->getHandle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_n[i], hidden_feat,
				&alpha,
				weight2 + i * hidden_feat * out_feat, hidden_feat,
				hidden_buf + hidden_feat * ptr, hidden_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_n[i];
	}
	h->sync();

#ifdef MOE_BREAKDOWN
	timestamp(t_mm);
	fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
			1e6);
#endif

	if (distributed) {
		checkCudaErrors(cudaMalloc(&local_output_buf, 
					sizeof(scalar_t) * batch_size * out_feat));
		ncclGroupStart();
		int send_ptr = 0;
		for (int i = 0; i < num_expert; ++i) {
			for (int j = 0; j < cm->size; ++j) {
				// This mirrors the first exchange. Rows are out_feat wide here.
				const int idx = j * num_expert + i;
				if (expert_count[idx]) {
					ncclRecv(local_output_buf + expert_ptr[idx] * out_feat,
							expert_count[idx] * out_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0));
				}
				if (all_expert_count[idx]) {
					ncclSend(output_buf + send_ptr * out_feat,
							all_expert_count[idx] * out_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0));
					send_ptr += all_expert_count[idx];
				}
			}
		}
		ncclGroupEnd();
	} else {
		local_output_buf = output_buf;
	}

	// Same stream as the NCCL exchange above, so the gather runs after it.
	if (batch_size > 0) {
		batch_gather_kernel<scalar_t>
			<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos,
					local_output_buf, output);
		checkCudaErrors(cudaGetLastError());
	}
	h->sync(0);

#ifdef MOE_BREAKDOWN
	timestamp(t_gather);
	fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
			1e6);
	fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_gather) *
			1e6);
#endif

	if (distributed) {
		// Allocated only in the distributed path. Otherwise these alias
		// local_input_buf and output_buf, which are freed below.
		cudaFree(input_buf);
		cudaFree(local_output_buf);
	}
	cudaFree(local_input_buf);
	cudaFree(hidden_buf);
	cudaFree(output_buf);
	cudaFree(d_pos);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
307
308
309
310
311
312
313
314
315
/*
 * Accumulate the weight gradient of a single-layer expert (y = W_e x).
 *
 * For every token t with expert e = gate[t]:
 *     grad_weight[e] += outer(grad_output[t], input[t])
 * grad_weight[e] is stored row-major as [out_feat x in_feat], i.e.
 * grad_weight[e][o][k] += grad_output[t][o] * input[t][k].
 *
 * input        [batch_size x in_feat], row-major device buffer.
 * gate         [batch_size] device buffer of expert ids. Each id selects a
 *              stream and a weight slice, so it is assumed to lie in
 *              [0, num_expert). NOTE(review): this is not checked here.
 * grad_output  [batch_size x out_feat], row-major device buffer.
 * grad_weight  [num_expert x out_feat x in_feat]. It is accumulated into
 *              (beta = 1), so the caller must zero it first.
 *
 * Each update is queued on the stream of its expert. All expert streams are
 * synchronized before returning.
 */
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {

    auto h = getCudaStreamManager(num_expert);

    std::vector<int> gate_host(batch_size);
    scalar_t alpha = 1, beta = 1;
    if (batch_size > 0) {
        checkCudaErrors(cudaMemcpy(gate_host.data(), gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    }
    for (size_t i = 0; i < batch_size; ++i) {
        const int e = gate_host[i];
        checkCudaErrors(cublasSetStream(h->handles[0], *(h->streams + e)));
        // Row-major dW_e[o][k] += g[o] * x[k] is, in cuBLAS column-major
        // terms, the in_feat x out_feat matrix C += x * g^T with ldc = in_feat.
        checkCudaErrors(cublasXgemm(h->handles[0],
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            in_feat,
            out_feat,
            1,
            &alpha,
            input + i * in_feat,
            in_feat,
            grad_output + i * out_feat,
            out_feat,
            &beta,
            grad_weight + e * out_feat * in_feat,
            in_feat));
    }
    for (size_t i = 0; i < num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
    }
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
345

Jiezhong Qiu's avatar
Jiezhong Qiu committed
346
/*
 * PyTorch entry point for the two-layer MoE forward pass.
 *
 * input    [batch_size x in_feat]
 * gate     [batch_size] int32 global expert ids
 * weight1  [num_expert x hidden_feat x in_feat]
 * weight2  [num_expert x out_feat x hidden_feat]
 * Returns  {output}, where output is [batch_size x out_feat].
 *
 * Shapes are validated up front because the CUDA code indexes raw pointers
 * and would otherwise read out of bounds. Non-contiguous tensors are made
 * contiguous, since the kernels assume dense row-major storage.
 */
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight1,
        torch::Tensor weight2
		) {
    TORCH_CHECK(input.dim() == 2,
            "moe_cuda_forward: input must be 2-D [batch_size x in_feat]");
    TORCH_CHECK(weight1.dim() == 3 && weight2.dim() == 3,
            "moe_cuda_forward: weight1 and weight2 must be 3-D");

    const auto batch_size = input.size(0);
    const auto num_expert = weight1.size(0);
    const auto out_feat = weight2.size(1);
    const auto hidden_feat = weight1.size(1);
    const auto in_feat = weight1.size(2);

    TORCH_CHECK(input.size(1) == in_feat,
            "moe_cuda_forward: input.size(1) must equal weight1.size(2)");
    TORCH_CHECK(gate.numel() == batch_size,
            "moe_cuda_forward: gate must have one entry per input row");
    TORCH_CHECK(weight2.size(0) == num_expert && weight2.size(2) == hidden_feat,
            "moe_cuda_forward: weight2 must be [num_expert x out_feat x hidden_feat]");

    // The CUDA code walks raw pointers assuming dense row-major storage.
    input = input.contiguous();
    gate = gate.contiguous();
    weight1 = weight1.contiguous();
    weight2 = weight2.contiguous();

#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, hidden_feat = %ld,out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, hidden_feat, out_feat);
#endif
    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight1.data_ptr<scalar_t>(),
                    weight2.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
                    hidden_feat,
                    out_feat,
                    num_expert
                );
    }));

    return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
381
/*
 * PyTorch entry point for the single-layer MoE backward pass.
 *
 * Returns {grad_input, grad_weight}:
 *   grad_input   [batch_size x in_feat]. Currently always zeros: the
 *                input-gradient call below is disabled because it targets an
 *                older moe_cuda_forward_impl signature.
 *   grad_weight  [num_expert x out_feat x in_feat], accumulated on zeros.
 */
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    const int64_t n_tokens = input.size(0);
    const int64_t n_expert = weight.size(0);
    const int64_t d_out = weight.size(1);
    const int64_t d_in = weight.size(2);
#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", n_tokens, n_expert, d_in, d_out);
#endif

    // Both gradients start from zero; grad_weight is accumulated into.
    auto grad_input = grad_output.new_zeros({n_tokens, d_in});
    auto grad_weight = grad_output.new_zeros({n_expert, d_out, d_in});

    // TODO: the input gradient has the same structure as the forward pass,
    // but this call predates the current moe_cuda_forward_impl signature and
    // is broken, so grad_input stays zero for now.
    /*
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N
        );
    }));
    */

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            n_tokens,
            d_in,
            d_out,
            n_expert
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
431
432

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
433
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
434
435
436
437
438
439
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
440
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
441
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
442
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
443

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
444
445
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
446
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
447
448
449
450
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
451

Jiezhong Qiu's avatar
Jiezhong Qiu committed
452
453
454
455
456
457
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
458
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
459
460
461
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
462
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
463
464
465
466
467
468
469
470
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
471
}
Rick Ho's avatar
Rick Ho committed
472
*/