moe_cuda_kernel.cu 13.4 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
10
11
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
12

Rick Ho's avatar
Rick Ho committed
13
14
#include <mpi.h>

Rick Ho's avatar
Rick Ho committed
15
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
16

Rick Ho's avatar
Rick Ho committed
17
18
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Rick Ho's avatar
Rick Ho committed
19
#include "comm_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
20

Rick Ho's avatar
Rick Ho committed
21
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
22

Rick Ho's avatar
Rick Ho committed
23
// #define MOE_DEBUG
Rick Ho's avatar
Rick Ho committed
24
#define MOE_BREAKDOWN
Rick Ho's avatar
Rick Ho committed
25
// #define MOE_DEBUG_SCATTER
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
26

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
27
28
/*
 * Build an array of row pointers: ptrs[k] = base + offset[k] * stride.
 *
 * Launch: 1-D grid with any block size; the grid must provide at least n
 * threads in total (one thread per output entry). Threads whose flat index is
 * >= n do nothing.
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
	if (tid >= n) {
		return;  // tail thread of the last block
	}
	ptrs[tid] = base + offset[tid] * stride;
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
37

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/*
 * Scatter rows of a dense row-major matrix: row r of inbuf (wid elements) is
 * copied to row pos[r] of oubuf.
 *
 * Launch: one block per input row (blockIdx.x is the row index, so gridDim.x
 * must equal the number of rows); any blockDim.x, threads stride over the row.
 * pos must map distinct rows to distinct destinations, otherwise writes race.
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(int wid, const int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	// Widen before multiplying: row * wid can exceed 32 bits for large
	// batch x feature sizes, which previously wrapped silently.
	const size_t row = blockIdx.x;
	const size_t width = static_cast<size_t>(wid);
	inbuf += width * row;
	oubuf += width * static_cast<size_t>(pos[row]);
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

/*
 * Gather rows of a dense row-major matrix: row pos[r] of inbuf (wid elements)
 * is copied to row r of oubuf. Inverse of batch_scatter_kernel for the same pos.
 *
 * Launch: one block per output row (blockIdx.x is the row index, so gridDim.x
 * must equal the number of rows); any blockDim.x, threads stride over the row.
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(int wid, const int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	// Widen before multiplying so large row * wid offsets do not wrap.
	const size_t row = blockIdx.x;
	const size_t width = static_cast<size_t>(wid);
	inbuf += width * static_cast<size_t>(pos[row]);
	oubuf += width * row;
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

Rick Ho's avatar
Rick Ho committed
60
61
62
63
64
65
/*
 * Debug helper: copy the first element of a device buffer to the host and
 * return it (it does not print anything itself).
 *
 * If the copy fails, the CUDA error is reported on stderr and a
 * value-initialized scalar_t is returned instead of an uninitialized one.
 */
template <typename scalar_t>
scalar_t print_first_float(scalar_t* d_ptr) {
	scalar_t v{};
	const cudaError_t err = cudaMemcpy(&v, d_ptr, sizeof(scalar_t),
			cudaMemcpyDeviceToHost);
	if (err != cudaSuccess) {
		fprintf(stderr, "print_first_float: cudaMemcpy failed: %s\n",
				cudaGetErrorString(err));
		return scalar_t{};
	}
	return v;
}
66

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
67
/*
 * Forward pass of a two-layer mixture-of-experts block with experts sharded
 * across ranks.
 *
 * Steps:
 *   1. Copy the per-row expert ids in d_gate to the host and count rows per
 *      global expert. Global expert g is local expert (g % num_expert) on rank
 *      (g / num_expert); see idx = i + j * num_expert in the NCCL loops.
 *   2. batch_scatter_kernel reorders input rows so rows of one global expert
 *      are contiguous (local_input_buf).
 *   3. MPI_Alltoall exchanges per-expert row counts. With more than one rank,
 *      NCCL send/recv then moves each row to the rank owning its expert.
 *   4. For each local expert: hidden = rows * weight1[e]^T and
 *      out = hidden * weight2[e]^T via cuBLAS. NOTE(review): no activation is
 *      applied between the two GEMMs in this function; confirm that is intended.
 *   5. Results are sent back with NCCL, and batch_gather_kernel restores the
 *      original row order into output.
 *
 * Layouts (from the strides used below): input is row-major
 * [batch_size x in_feat], output is row-major [batch_size x out_feat]; weight1
 * holds num_expert row-major [hidden_feat x in_feat] matrices and weight2 holds
 * num_expert row-major [out_feat x hidden_feat] matrices. d_gate is a device
 * array of batch_size ints.
 *
 * Throws std::out_of_range, before any device memory is allocated, if a gate
 * value is outside [0, num_expert * world_size). In a multi-rank run the other
 * ranks will then block in MPI_Alltoall. That is still preferable to the host
 * heap overflow that an unchecked id would cause.
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* d_gate,
        const scalar_t* weight1,
        const scalar_t* weight2,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t hidden_feat,
        const size_t out_feat,
        const size_t num_expert) {
    auto h = getCudaStreamManager(num_expert);
	auto cm = getCommManager();
	int tot_expert = num_expert * cm->size;

#ifdef MOE_BREAKDOWN
	timestamp(t_init);
#endif

	// Host-side arrays are std::vector so that none of them leak; the previous
	// new[] arrays expert_count/expert_ptr/all_expert_count/expert_n were never freed.
	std::vector<int> gate(batch_size);
	if (batch_size) {
		checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
					cudaMemcpyDeviceToHost));
	}
	// Validate first: an out-of-range id would index past expert_count.
	for (size_t i = 0; i < batch_size; ++i) {
		if (gate[i] < 0 || gate[i] >= tot_expert) {
			throw std::out_of_range("moe_cuda_forward_impl: gate[" +
					std::to_string(i) + "] = " + std::to_string(gate[i]) +
					" is outside [0, " + std::to_string(tot_expert) + ")");
		}
	}

	scalar_t *local_input_buf = nullptr, *local_output_buf = nullptr;
	checkCudaErrors(cudaMalloc(&local_input_buf, 
				sizeof(scalar_t) * batch_size * in_feat));
	checkCudaErrors(cudaMalloc(&local_output_buf, 
				sizeof(scalar_t) * batch_size * out_feat));

	// expert_count[g]: local rows routed to global expert g.
	// expert_ptr[g]: exclusive prefix sum, i.e. where g's segment starts in the
	// expert-sorted local buffers.
	std::vector<int> expert_count(tot_expert, 0), expert_ptr(tot_expert, 0);
	for (size_t i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	for (int i = 1; i < tot_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	// pos[i]: destination row of input row i in the expert-sorted order
	// (rows keep their relative order within an expert). A separate cursor
	// leaves expert_ptr untouched as segment start offsets.
	std::vector<int> pos(batch_size);
	{
		std::vector<int> cursor(expert_ptr);
		for (size_t i = 0; i < batch_size; ++i) {
			pos[i] = cursor[gate[i]]++;
		}
	}
	int *d_pos = nullptr;
	checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));
	if (batch_size) {
		checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
					cudaMemcpyHostToDevice));
	}

	// all_expert_count[j * num_expert + i]: rows rank j routes to this rank's
	// local expert i.
	std::vector<int> all_expert_count(tot_expert);
	MPI_Alltoall(expert_count.data(), num_expert, MPI_INT, 
			all_expert_count.data(), num_expert, MPI_INT, MPI_COMM_WORLD);

	// expert_n[i]: total rows (from all ranks) local expert i processes.
	std::vector<int> expert_n(num_expert, 0);
	int expert_sz = 0;
	for (int i = 0; i < num_expert; ++i) {
		for (int j = 0; j < cm->size; ++j) {
			expert_n[i] += all_expert_count[j * num_expert + i];
		}
		expert_sz += expert_n[i];
	}

	scalar_t *input_buf = nullptr, *hidden_buf = nullptr, *output_buf = nullptr;
	if (expert_sz) {
		checkCudaErrors(cudaMalloc(&hidden_buf, 
					sizeof(scalar_t) * expert_sz * hidden_feat));
	}

#ifdef MOE_BREAKDOWN
	timestamp(t_expert);
	fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_init, t_expert) *
			1e6);
#endif

	// A zero-sized grid is an invalid launch configuration.
	if (batch_size) {
		batch_scatter_kernel<scalar_t>
			<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
					local_input_buf); 
	}
	h->sync(0);
	// fprintf(stderr, "First %d lin %.3f\n", cm->rank, print_first_float(local_input_buf));

	if (cm->size > 1) {
		if (expert_sz) {
			checkCudaErrors(cudaMalloc(&input_buf, 
						sizeof(scalar_t) * expert_sz * in_feat));
			checkCudaErrors(cudaMalloc(&output_buf, 
						sizeof(scalar_t) * expert_sz * out_feat));
		}
		// For each local expert i, exchange with every rank j: send our rows
		// for (rank j, expert i) and receive rank j's rows for our expert i.
		// Received rows are packed expert-major, then by source rank.
		int recv_ptr = 0;
		for (int i = 0; i < num_expert; ++i) {
			NCCL_SAFE_CALL(ncclGroupStart());
			for (int j = 0; j < cm->size; ++j) {
				int idx = i + j * num_expert;
				if (expert_count[idx]) {
					NCCL_SAFE_CALL(ncclSend(
							local_input_buf + expert_ptr[idx] * in_feat, 
							expert_count[idx] * in_feat * sizeof(scalar_t),
							ncclChar, 
							j,
							cm->ncclcomm,
							h->getStream(0)));
				}
				if (all_expert_count[idx]) {
					NCCL_SAFE_CALL(ncclRecv(
							input_buf + recv_ptr * in_feat,
							all_expert_count[idx] * in_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0)));
					recv_ptr += all_expert_count[idx];
				}
			}
			NCCL_SAFE_CALL(ncclGroupEnd());
		}
	} else {
		// Single rank: the locally sorted buffers are already the expert inputs.
		input_buf = local_input_buf;
		output_buf = local_output_buf;
	}

	h->sync(0);

#ifdef MOE_BREAKDOWN
	timestamp(t_scatter);
	fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
			1e6);
#endif

	scalar_t alpha = 1, beta = 0; 

	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_n[i] == 0) {
			continue;
		}
#ifdef MOE_DEBUG_SCATTER
		fprintf(stderr, "worker %d gemm %d sz %d offset %d\n", cm->rank, i, expert_n[i], ptr);
		// fprintf(stderr, "worker %d GeMM %d x %d x %d\n", cm->rank, out_feat, expert_n[i], in_feat);
#endif
		// Use T(B) x T(A) = T(C) to produce row-major C
		checkCudaErrors(cublasXgemm(h->getHandle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				hidden_feat, expert_n[i], in_feat,
				&alpha,
				weight1 + i * in_feat * hidden_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				hidden_buf + hidden_feat * ptr, hidden_feat
				));

		checkCudaErrors(cublasXgemm(h->getHandle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_n[i], hidden_feat,
				&alpha,
				weight2 + i * hidden_feat * out_feat, hidden_feat,
				hidden_buf + hidden_feat * ptr, hidden_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_n[i];
	}
	h->sync();

#ifdef MOE_BREAKDOWN
	timestamp(t_mm);
	fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
			1e6);
#endif

	if (cm->size > 1) {
		// Mirror of the scatter exchange: send expert outputs back to the
		// ranks that own the rows, receive our rows into their sorted slots.
		int send_ptr = 0;
		for (int i = 0; i < num_expert; ++i) {
			NCCL_SAFE_CALL(ncclGroupStart());
			for (int j = 0; j < cm->size; ++j) {
				int idx = i + j * num_expert;
				if (all_expert_count[idx]) {
					NCCL_SAFE_CALL(ncclSend(
							output_buf + send_ptr * out_feat,
							all_expert_count[idx] * out_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0)));
					send_ptr += all_expert_count[idx];
				}
				if (expert_count[idx]) {
					NCCL_SAFE_CALL(ncclRecv(
							local_output_buf + expert_ptr[idx] * out_feat, 
							expert_count[idx] * out_feat * sizeof(scalar_t),
							ncclChar, 
							j,
							cm->ncclcomm,
							h->getStream(0)));
				}
			}
			NCCL_SAFE_CALL(ncclGroupEnd());
		}
	}

#ifdef MOE_BREAKDOWN
	h->sync(0);
	timestamp(t_gather);
	fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
			1e6);
#endif
	if (batch_size) {
		batch_gather_kernel<scalar_t>
			<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, 
					local_output_buf, output); 
	}
	h->sync(0);

#ifdef MOE_BREAKDOWN
	timestamp(t_end);
	fprintf(stderr, "Local gather %.3lf us\n", getDuration(t_gather, t_end) *
			1e6);
	fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_end) *
			1e6);
#endif

	if (expert_sz) {
		cudaFree(hidden_buf);
		if (cm->size > 1) {
			cudaFree(input_buf);
			cudaFree(output_buf);
		}
	}
	cudaFree(local_input_buf);
	cudaFree(local_output_buf);
	cudaFree(d_pos);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
309
310
311
312
313
314
315
316
317
/*
 * Accumulate per-expert weight gradients one row at a time:
 *     grad_weight[gate[b]][r][c] += grad_output[b][r] * input[b][c]
 * i.e. the outer product grad_output[b] x input[b] added into the expert's slice.
 *
 * Layouts: input row-major [batch_size x in_feat], grad_output row-major
 * [batch_size x out_feat], grad_weight row-major [num_expert x out_feat x in_feat];
 * gate is a device array of batch_size expert ids.
 *
 * grad_weight is accumulated into (beta = 1), so the caller must zero it first.
 * Each row's GEMM is issued on its expert's stream; all expert streams are
 * synchronized before returning.
 *
 * Throws std::out_of_range, before launching any work, if a gate id is outside
 * [0, num_expert).
 */
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {
    auto h = getCudaStreamManager(num_expert);

    std::vector<int> gate_host(batch_size);
    if (batch_size) {
        checkCudaErrors(cudaMemcpy(gate_host.data(), gate, batch_size * sizeof(int),
                    cudaMemcpyDeviceToHost));
    }
    // Validate every id up front so a bad gate can neither index past
    // grad_weight nor leave it half-updated.
    for (size_t i = 0; i < batch_size; ++i) {
        if (gate_host[i] < 0 || static_cast<size_t>(gate_host[i]) >= num_expert) {
            throw std::out_of_range("moe_cuda_grad_weight: gate[" +
                    std::to_string(i) + "] = " + std::to_string(gate_host[i]) +
                    " is outside [0, " + std::to_string(num_expert) + ")");
        }
    }

    scalar_t alpha = 1, beta = 1;  // beta = 1: accumulate into grad_weight
    for (size_t i = 0; i < batch_size; ++i) {
        const size_t expert = static_cast<size_t>(gate_host[i]);
        // Rows of the same expert go to the same stream, so their
        // read-modify-write updates of one slice execute in issue order.
        checkCudaErrors(cublasSetStream(h->handles[0], *(h->streams + expert)));
        // cuBLAS is column-major. The row-major [out_feat x in_feat] slice is,
        // column-major, an in_feat x out_feat matrix with ld = in_feat. Computing
        // C(in x out) += x (in x 1) * g^T (1 x out) stores x[c] * g[r] at
        // c + r * in_feat, which is grad_weight[expert][r][c]. (The previous
        // m = out_feat, ldc = out_feat form wrote the transposed gradient.)
        checkCudaErrors(cublasXgemm(h->handles[0],
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            in_feat,
            out_feat,
            1,
            &alpha,
            input + i * in_feat,
            in_feat,
            grad_output + i * out_feat,
            out_feat,
            &beta,
            grad_weight + expert * out_feat * in_feat,
            in_feat));
    }
    for (size_t i = 0; i < num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
    }
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
347

Jiezhong Qiu's avatar
Jiezhong Qiu committed
348
/*
 * PyTorch entry point for the MoE forward pass.
 *
 * input:   [batch_size x in_feat] on CUDA
 * gate:    batch_size int32 global expert ids
 * weight1: [num_expert x hidden_feat x in_feat]
 * weight2: [num_expert x out_feat x hidden_feat]
 * Returns {output} with output of shape [batch_size x out_feat].
 *
 * moe_cuda_forward_impl walks raw pointers as dense row-major buffers, so the
 * inputs are made contiguous here and their shapes are checked. A mismatch would
 * otherwise turn into out-of-bounds device memory accesses.
 */
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight1,
        torch::Tensor weight2
		) {
    TORCH_CHECK(input.is_cuda() && gate.is_cuda() && weight1.is_cuda() &&
            weight2.is_cuda(), "moe_cuda_forward: all tensors must be CUDA tensors");
    TORCH_CHECK(weight1.dim() == 3 && weight2.dim() == 3,
            "moe_cuda_forward: weight1 and weight2 must be 3-D");

    input = input.contiguous();
    gate = gate.contiguous();
    weight1 = weight1.contiguous();
    weight2 = weight2.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight1.size(0);
    const auto out_feat = weight2.size(1);
	const auto hidden_feat = weight1.size(1);
    const auto in_feat = weight1.size(2);

    TORCH_CHECK(input.numel() == batch_size * in_feat,
            "moe_cuda_forward: input must hold batch_size x in_feat (", in_feat,
            ") elements");
    TORCH_CHECK(gate.numel() == batch_size,
            "moe_cuda_forward: gate must have one entry per input row");
    TORCH_CHECK(weight2.size(0) == num_expert && weight2.size(2) == hidden_feat,
            "moe_cuda_forward: weight2 must be [num_expert x out_feat x hidden_feat]");

#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, hidden_feat = %ld,out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, hidden_feat, out_feat);
#endif
    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight1.data_ptr<scalar_t>(),
                    weight2.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
                    hidden_feat,
                    out_feat,
                    num_expert
                );
    }));

    return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
383
/*
 * PyTorch entry point for the MoE backward pass (single-weight layout).
 *
 * grad_output: [batch_size x out_feat]
 * input:       [batch_size x in_feat]
 * gate:        batch_size int32 expert ids
 * weight:      [num_expert x out_feat x in_feat] (only its sizes are read here)
 * Returns {grad_input, grad_weight}. grad_weight is computed by
 * moe_cuda_grad_weight. grad_input is currently returned as zeros (see TODO).
 */
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    // moe_cuda_grad_weight walks raw pointers as dense row-major buffers.
    grad_output = grad_output.contiguous();
    input = input.contiguous();
    gate = gate.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

    TORCH_CHECK(input.numel() == batch_size * in_feat,
            "moe_cuda_backward: input must hold batch_size x in_feat elements");
    TORCH_CHECK(grad_output.numel() == batch_size * out_feat,
            "moe_cuda_backward: grad_output must hold batch_size x out_feat elements");
    TORCH_CHECK(gate.numel() == batch_size,
            "moe_cuda_backward: gate must have one entry per input row");

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    // grad_input is easy to compute, exactly the same as forward
	/* TODO: Backward is currently broken: the call below targets an older
	 * moe_cuda_forward_impl signature (single weight + transb argument), so
	 * grad_input is left as zeros.
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N
        );
    }));
	*/

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
433
434

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
435
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
436
437
438
439
440
441
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
442
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
443
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
444
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
445

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
446
447
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
448
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
449
450
451
452
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
453

Jiezhong Qiu's avatar
Jiezhong Qiu committed
454
455
456
457
458
459
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
460
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
461
462
463
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
464
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
465
466
467
468
469
470
471
472
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
473
}
Rick Ho's avatar
Rick Ho committed
474
*/