moe_cuda_kernel.cu 13.3 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
#include <cuda.h>
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
10
#include <cublas_v2.h>
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
11
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
12

Rick Ho's avatar
Rick Ho committed
13
14
#include <mpi.h>

Rick Ho's avatar
Rick Ho committed
15
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
16

Rick Ho's avatar
Rick Ho committed
17
18
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Rick Ho's avatar
Rick Ho committed
19
#include "comm_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
20

Rick Ho's avatar
Rick Ho committed
21
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
22

Rick Ho's avatar
Rick Ho committed
23
// #define MOE_DEBUG
Rick Ho's avatar
Rick Ho committed
24
#define MOE_BREAKDOWN
Rick Ho's avatar
Rick Ho committed
25
// #define MOE_DEBUG_SCATTER
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
26

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
27
28
/*
 * For every k in [0, n), store the address (base + stride * offset[k]) into
 * ptrs[k]. Each thread handles the single entry given by its flat 1-D global
 * index; threads whose index is >= n exit without touching memory.
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	const size_t tid = blockDim.x * blockIdx.x + threadIdx.x;
	if (tid >= n) {
		return;
	}
	const scalar_t* target = base + stride * offset[tid];
	ptrs[tid] = target;
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
37

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/*
 * Scatter rows of `inbuf` into `oubuf`: row b of the input (wid contiguous
 * elements) is copied to row pos[b] of the output.
 *
 * Launch layout: block b handles row b (gridDim.x = number of rows); the
 * threads of a block step across the row with a stride of blockDim.x, so any
 * block size works.
 *
 * Row offsets are computed in 64-bit arithmetic so that buffers holding
 * 2^31 or more elements are addressed correctly (the 32-bit products
 * wid * blockIdx.x and wid * pos[...] would otherwise overflow).
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(int wid, int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	const long long row_len = wid;
	const scalar_t* src = inbuf + row_len * blockIdx.x;
	scalar_t* dst = oubuf + row_len * pos[blockIdx.x];
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		dst[i] = src[i];
	}
}

/*
 * Gather rows of `inbuf` into `oubuf`: row pos[b] of the input (wid
 * contiguous elements) is copied to row b of the output. This is the inverse
 * of batch_scatter_kernel when called with the same `pos` array.
 *
 * Launch layout: block b handles output row b (gridDim.x = number of rows);
 * the threads of a block step across the row with a stride of blockDim.x.
 *
 * Row offsets are computed in 64-bit arithmetic so that buffers holding
 * 2^31 or more elements are addressed correctly (the 32-bit products
 * wid * blockIdx.x and wid * pos[...] would otherwise overflow).
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(int wid, int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	const long long row_len = wid;
	const scalar_t* src = inbuf + row_len * pos[blockIdx.x];
	scalar_t* dst = oubuf + row_len * blockIdx.x;
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		dst[i] = src[i];
	}
}

Rick Ho's avatar
Rick Ho committed
60
61
62
63
64
65
/*
 * Debug helper: copy the first element behind device pointer `d_ptr` back to
 * the host and return it. Despite its name the function prints nothing on
 * success; the caller is expected to format the returned value.
 *
 * The copy is a blocking cudaMemcpy. If it fails, the error is reported on
 * stderr and a value-initialised scalar_t (zero for arithmetic types) is
 * returned instead of an indeterminate value.
 */
template <typename scalar_t>
scalar_t print_first_float(scalar_t* d_ptr) {
	scalar_t v = scalar_t();
	cudaError_t err = cudaMemcpy(&v, d_ptr, sizeof(scalar_t),
			cudaMemcpyDeviceToHost);
	if (err != cudaSuccess) {
		fprintf(stderr, "print_first_float: cudaMemcpy failed: %s\n",
				cudaGetErrorString(err));
	}
	return v;
}
66

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
67
/*
 * Forward pass of a two-layer mixture-of-experts block, optionally
 * distributed over several workers.
 *
 * Steps:
 *   1. Copy the gate (expert index per row) to the host, count rows per
 *      expert and compute for every row its slot `pos` in an expert-sorted
 *      buffer.
 *   2. Exchange per-expert row counts between workers with MPI_Alltoall.
 *   3. Scatter the input rows into expert-sorted order on the device.
 *   4. If more than one worker is present, ship the rows to the workers that
 *      own the experts with NCCL send/recv; otherwise use the local buffers
 *      directly.
 *   5. For every local expert run two GEMMs: in_feat -> hidden_feat with
 *      weight1, then hidden_feat -> out_feat with weight2.
 *   6. Send results back (multi-worker case) and gather rows into the
 *      original order in `output`.
 *
 * Parameters (all pointers except the returned host-side sizes are device
 * pointers):
 *   input    rows of in_feat elements, batch_size rows
 *   d_gate   batch_size expert indices; indices address the combined expert
 *            list of all workers (num_expert * cm->size entries)
 *   weight1  per expert an in_feat * hidden_feat block, addressed as
 *            weight1 + i * in_feat * hidden_feat with leading dimension
 *            in_feat
 *   weight2  per expert a hidden_feat * out_feat block with leading
 *            dimension hidden_feat
 *   output   rows of out_feat elements, batch_size rows
 *
 * NOTE(review): gate values are used unchecked as indices into the host
 * count arrays; a value outside [0, num_expert * cm->size) indexes out of
 * bounds. Validation is left to the caller because aborting on one worker
 * before the MPI exchange would stall the others.
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* d_gate,
        const scalar_t* weight1,
        const scalar_t* weight2,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t hidden_feat,
        const size_t out_feat,
        const size_t num_expert) {

	auto h = getCudaStreamManager(num_expert);
	auto cm = getCommManager();
	int tot_expert = num_expert * cm->size;
	const int n_local_expert = static_cast<int>(num_expert);

#ifdef MOE_BREAKDOWN
	timestamp(t_init);
#endif

	// Expert-sorted staging buffers for the rows owned by this worker.
	scalar_t *local_input_buf = nullptr, *local_output_buf = nullptr;
	checkCudaErrors(cudaMalloc(&local_input_buf, 
				sizeof(scalar_t) * batch_size * in_feat));
	checkCudaErrors(cudaMalloc(&local_output_buf, 
				sizeof(scalar_t) * batch_size * out_feat));

	// Host-side bookkeeping lives in vectors so nothing leaks on return.
	std::vector<int> gate(batch_size);
	std::vector<int> expert_count(tot_expert, 0);
	std::vector<int> expert_ptr(tot_expert, 0);

	if (batch_size > 0) {
		checkCudaErrors(cudaMemcpy(gate.data(), d_gate,
					sizeof(int) * batch_size, cudaMemcpyDeviceToHost));
	}

	for (size_t i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	// Exclusive prefix sum: expert_ptr[e] = first slot of expert e.
	for (int i = 1; i < tot_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	std::vector<int> pos(batch_size);
	int *d_pos = nullptr;
	checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));

	// Hand out slots; this advances expert_ptr[e] to the start of e + 1 ...
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}
	// ... so shift it back by one entry to restore the start offsets.
	for (int i = tot_expert - 1; i > 0; --i) {
		expert_ptr[i] = expert_ptr[i - 1];
	}
	if (tot_expert > 0) {
		expert_ptr[0] = 0;
	}
	if (batch_size > 0) {
		checkCudaErrors(cudaMemcpy(d_pos, pos.data(),
					sizeof(int) * batch_size, cudaMemcpyHostToDevice));
	}

	// all_expert_count[j * num_expert + i]: rows worker j sends to local
	// expert i.
	std::vector<int> all_expert_count(tot_expert, 0);
	MPI_Alltoall(expert_count.data(), num_expert, MPI_INT, 
			all_expert_count.data(), num_expert, MPI_INT, MPI_COMM_WORLD);

	std::vector<int> expert_n(num_expert, 0);
	int expert_sz = 0;
	for (int i = 0; i < n_local_expert; ++i) {
		for (int j = 0; j < cm->size; ++j) {
			expert_n[i] += all_expert_count[j * num_expert + i];
		}
		expert_sz += expert_n[i];
	}

	scalar_t *input_buf = nullptr, *hidden_buf = nullptr,
			 *output_buf = nullptr;
	if (expert_sz) {
		checkCudaErrors(cudaMalloc(&hidden_buf, 
					sizeof(scalar_t) * expert_sz * hidden_feat));
	}

#ifdef MOE_BREAKDOWN
	timestamp(t_expert);
	fprintf(stderr, "Expert asn %d time %.3lf us\n", 
			expert_sz,
			getDuration(t_init, t_expert) * 1e6);
#endif

	// A grid of zero blocks is an invalid launch configuration, so the
	// scatter is skipped when this worker holds no rows.
	if (batch_size > 0) {
		batch_scatter_kernel<scalar_t>
			<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
					local_input_buf); 
		checkCudaErrors(cudaGetLastError());
	}
	h->sync(0);
	// fprintf(stderr, "First %d lin %.3f\n", cm->rank, print_first_float(local_input_buf));

	if (cm->size > 1) {
		if (expert_sz) {
			checkCudaErrors(cudaMalloc(&input_buf, 
						sizeof(scalar_t) * expert_sz * in_feat));
			checkCudaErrors(cudaMalloc(&output_buf, 
						sizeof(scalar_t) * expert_sz * out_feat));
		}
		int recv_ptr = 0;
		for (int i = 0; i < n_local_expert; ++i) {
			NCCL_SAFE_CALL(ncclGroupStart());
			for (int j = 0; j < cm->size; ++j) {
				int idx = i + j * num_expert;
				if (expert_count[idx]) {
					NCCL_SAFE_CALL(ncclSend(
							local_input_buf + expert_ptr[idx] * in_feat, 
							expert_count[idx] * in_feat * sizeof(scalar_t),
							ncclChar, 
							j,
							cm->ncclcomm,
							h->getStream(0)));
				}
				if (all_expert_count[idx]) {
					NCCL_SAFE_CALL(ncclRecv(
							input_buf + recv_ptr * in_feat,
							all_expert_count[idx] * in_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0)));
					recv_ptr += all_expert_count[idx];
				}
			}
			NCCL_SAFE_CALL(ncclGroupEnd());
		}
	} else {
		// Single worker: the expert-sorted local buffers are the expert
		// buffers; no copy and no extra allocation.
		input_buf = local_input_buf;
		output_buf = local_output_buf;
	}

	h->sync(0);

#ifdef MOE_BREAKDOWN
	timestamp(t_scatter);
	fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
			1e6);
#endif

	scalar_t alpha = 1, beta = 0; 

	for (int i = 0, ptr = 0; i < n_local_expert; ++i) {
		if (expert_n[i] == 0) {
			continue;
		}
#ifdef MOE_DEBUG_SCATTER
		fprintf(stderr, "worker %d gemm %d sz %d offset %d\n", cm->rank, i, expert_n[i], ptr);
		// fprintf(stderr, "worker %d GeMM %d x %d x %d\n", cm->rank, out_feat, expert_n[i], in_feat);
#endif
		// Use T(B) x T(A) = T(C) to produce row-major C
		checkCudaErrors(cublasXgemm(h->getHandle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				hidden_feat, expert_n[i], in_feat,
				&alpha,
				weight1 + i * in_feat * hidden_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				hidden_buf + hidden_feat * ptr, hidden_feat
				));

		checkCudaErrors(cublasXgemm(h->getHandle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_n[i], hidden_feat,
				&alpha,
				weight2 + i * hidden_feat * out_feat, hidden_feat,
				hidden_buf + hidden_feat * ptr, hidden_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_n[i];
	}
	h->sync();

#ifdef MOE_BREAKDOWN
	timestamp(t_mm);
	fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
			1e6);
#endif

	if (cm->size > 1) {
		int send_ptr = 0;
		for (int i = 0; i < n_local_expert; ++i) {
			NCCL_SAFE_CALL(ncclGroupStart());
			for (int j = 0; j < cm->size; ++j) {
				int idx = i + j * num_expert;
				if (all_expert_count[idx]) {
					NCCL_SAFE_CALL(ncclSend(
							output_buf + send_ptr * out_feat,
							all_expert_count[idx] * out_feat * sizeof(scalar_t),
							ncclChar,
							j,
							cm->ncclcomm,
							h->getStream(0)));
					send_ptr += all_expert_count[idx];
				}
				if (expert_count[idx]) {
					NCCL_SAFE_CALL(ncclRecv(
							local_output_buf + expert_ptr[idx] * out_feat, 
							expert_count[idx] * out_feat * sizeof(scalar_t),
							ncclChar, 
							j,
							cm->ncclcomm,
							h->getStream(0)));
				}
			}
			NCCL_SAFE_CALL(ncclGroupEnd());
		}
	}

#ifdef MOE_BREAKDOWN
	h->sync(0);
	timestamp(t_gather);
	fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
			1e6);
#endif
	// Runs on stream 0, i.e. after the NCCL receives queued on that stream.
	if (batch_size > 0) {
		batch_gather_kernel<scalar_t>
			<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, 
					local_output_buf, output); 
		checkCudaErrors(cudaGetLastError());
	}
	h->sync(0);

#ifdef MOE_BREAKDOWN
	timestamp(t_end);
	fprintf(stderr, "Local gather %.3lf us\n", getDuration(t_gather, t_end) *
			1e6);
	fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_end) *
			1e6);
#endif

	// cudaFree(nullptr) is a no-op, so buffers that were never allocated
	// need no special casing. In the single-worker case input_buf and
	// output_buf alias the local buffers and must not be freed twice.
	checkCudaErrors(cudaFree(hidden_buf));
	if (cm->size > 1) {
		checkCudaErrors(cudaFree(input_buf));
		checkCudaErrors(cudaFree(output_buf));
	}
	checkCudaErrors(cudaFree(local_input_buf));
	checkCudaErrors(cudaFree(local_output_buf));
	checkCudaErrors(cudaFree(d_pos));
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
310
311
312
313
314
315
316
317
318
/*
 * Accumulate the weight gradient of every expert, one batch row at a time.
 *
 * For each row the gate is looked up on the host, the shared cuBLAS handle
 * is pointed at the stream belonging to that row's expert, and a k = 1 GEMM
 * (an outer product of the row of grad_output with the row of input) is
 * added into that expert's slice of grad_weight. beta = 1 makes the GEMM
 * accumulate, so grad_weight must hold the desired initial value on entry.
 * All expert streams are synchronised before returning.
 *
 * NOTE(review): the GEMM writes an out_feat x in_feat result with leading
 * dimension out_feat; if cublasXgemm follows cuBLAS column-major
 * conventions that is the transpose of a row-major [out_feat x in_feat]
 * slice -- confirm against the layout the caller expects.
 */
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {

    auto h = getCudaStreamManager(num_expert);

    // alpha scales the outer product, beta = 1 keeps what is already there.
    scalar_t alpha = 1, beta = 1;

    // Host copy of the per-row expert indices.
    int* gate_host = new int[batch_size];
    checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));

    for (size_t row = 0; row < batch_size; ++row) {
        const int expert = gate_host[row];
        const scalar_t* grad_row = grad_output + row * out_feat;
        const scalar_t* input_row = input + row * in_feat;
        scalar_t* expert_grad = grad_weight + expert * out_feat * in_feat;

        checkCudaErrors(cublasSetStream(h->handles[0], h->streams[expert]));
        checkCudaErrors(cublasXgemm(h->handles[0],
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            out_feat,
            in_feat,
            1,
            &alpha,
            grad_row,
            out_feat,
            input_row,
            in_feat,
            &beta,
            expert_grad,
            out_feat));
    }

    for (size_t e = 0; e < num_expert; ++e) {
        checkCudaErrors(cudaStreamSynchronize(h->streams[e]));
    }

    delete[] gate_host;
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
348

Jiezhong Qiu's avatar
Jiezhong Qiu committed
349
/*
 * PyTorch entry point for the mixture-of-experts forward pass.
 *
 * Shapes as read from the tensors below:
 *   input    [batch_size x in_feat]
 *   gate     [batch_size], accessed as 32-bit int
 *   weight1  [num_expert x hidden_feat x in_feat]
 *   weight2  [num_expert x out_feat x ...]
 *
 * Returns a one-element vector holding the output tensor
 * [batch_size x out_feat], created with the dtype/device of `input`.
 *
 * The implementation works on raw data pointers and therefore assumes dense
 * row-major storage; every tensor is made contiguous first so that strided
 * views (e.g. transposed or sliced tensors) do not silently produce wrong
 * results. For tensors that are already contiguous this is a no-op.
 */
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight1,
        torch::Tensor weight2
		) {
    input = input.contiguous();
    gate = gate.contiguous();
    weight1 = weight1.contiguous();
    weight2 = weight2.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight1.size(0);
    const auto out_feat = weight2.size(1);
    const auto hidden_feat = weight1.size(1);
    const auto in_feat = weight1.size(2);
            
#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, hidden_feat = %ld,out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, hidden_feat, out_feat);
#endif
    auto output = input.new_zeros({batch_size, out_feat});
    
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight1.data_ptr<scalar_t>(),
                    weight2.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
                    hidden_feat,
                    out_feat,
                    num_expert
                );
    }));
    
    return {output, };           
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
384
/*
 * PyTorch entry point for the mixture-of-experts backward pass.
 *
 * Returns {grad_input, grad_weight}. Only grad_weight is computed at the
 * moment (via moe_cuda_grad_weight); the grad_input computation is disabled
 * below, so grad_input is returned as an all-zero tensor of shape
 * [batch_size x in_feat].
 *
 * The kernels work on raw data pointers and assume dense row-major storage,
 * so every tensor is made contiguous first; this is a no-op for tensors that
 * are already contiguous.
 */
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    grad_output = grad_output.contiguous();
    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    // grad_input is easy to compute, exactly the same as forward
	/* TODO: Backward currently broken
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N
        );
    }));
	*/

    // grad_weight starts at zero and is accumulated row by row.
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
434
435

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
436
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
437
438
439
440
441
442
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
443
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
444
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
445
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
446

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
447
448
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
449
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
450
451
452
453
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
454

Jiezhong Qiu's avatar
Jiezhong Qiu committed
455
456
457
458
459
460
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
461
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
462
463
464
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
465
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
466
467
468
469
470
471
472
473
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
474
}
Rick Ho's avatar
Rick Ho committed
475
*/