moe_cuda_kernel.cu 10.6 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
10
11
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
12

Rick Ho's avatar
Rick Ho committed
13
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
14

Rick Ho's avatar
Rick Ho committed
15
16
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
17

Rick Ho's avatar
Rick Ho committed
18
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19

Rick Ho's avatar
Rick Ho committed
20
// #define MOE_BREAKDOWN
21
// #define MOE_DEBUG
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
22

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
23
24
// Fill `ptrs` with one pointer per entry: ptrs[k] = base + stride * offset[k].
//
// Launch layout: 1-D grid of 1-D blocks, one thread per entry; threads whose
// flat index is >= n do nothing, so the grid may overshoot n.
//
//   n       number of pointers to generate
//   base    start of the buffer the generated pointers point into
//   stride  number of elements between two consecutive offsets
//   offset  per-entry offset, in units of `stride`
//   ptrs    output array of n pointers
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	const size_t tid = blockDim.x * blockIdx.x + threadIdx.x;
	// Threads past the end of the array have no entry to write.
	if (tid >= n) {
		return;
	}
	const scalar_t* target = base + stride * offset[tid];
	ptrs[tid] = target;
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// Scatter rows: row b of `inbuf` is copied to row pos[b] of `oubuf`.
//
// Launch layout: one block per source row (the row index is blockIdx.x, so
// gridDim.x must equal the number of rows to move). The threads of a block
// stride over the `wid` elements of the row, so any blockDim.x is valid.
//
//   wid    number of elements in one row
//   pos    pos[b] is the destination row of source row b
//   inbuf  source buffer, rows of `wid` elements
//   oubuf  destination buffer, rows of `wid` elements
template <typename scalar_t>
__global__
void batch_scatter_kernel(int wid, int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	// Compute row offsets in a pointer-sized integer: the products
	// `wid * blockIdx.x` and `wid * pos[...]` wrap in 32-bit arithmetic once
	// a buffer holds more than 2^31 elements.
	const ptrdiff_t row_len = wid;
	inbuf += row_len * static_cast<ptrdiff_t>(blockIdx.x);
	oubuf += row_len * static_cast<ptrdiff_t>(pos[blockIdx.x]);
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

// Gather rows: row pos[b] of `inbuf` is copied to row b of `oubuf`
// (the inverse data movement of batch_scatter_kernel for the same `pos`).
//
// Launch layout: one block per destination row (the row index is blockIdx.x,
// so gridDim.x must equal the number of rows to move). The threads of a block
// stride over the `wid` elements of the row, so any blockDim.x is valid.
//
//   wid    number of elements in one row
//   pos    pos[b] is the source row of destination row b
//   inbuf  source buffer, rows of `wid` elements
//   oubuf  destination buffer, rows of `wid` elements
template <typename scalar_t>
__global__
void batch_gather_kernel(int wid, int* pos, 
		const scalar_t* inbuf, scalar_t* oubuf) { 
	// Compute row offsets in a pointer-sized integer: the products
	// `wid * pos[...]` and `wid * blockIdx.x` wrap in 32-bit arithmetic once
	// a buffer holds more than 2^31 elements.
	const ptrdiff_t row_len = wid;
	inbuf += row_len * static_cast<ptrdiff_t>(pos[blockIdx.x]);
	oubuf += row_len * static_cast<ptrdiff_t>(blockIdx.x);
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
57
// Forward pass of a two-layer mixture-of-experts block with one expert per
// sample. For sample b routed to expert e = gate[b] this computes
//     output[b] = W2[e] * (W1[e] * input[b])
// with no non-linearity between the two products.
//
// Steps:
//   1. copy the gate indices to the host and count samples per expert;
//   2. scatter the input rows so samples of the same expert are contiguous;
//   3. run two GEMMs per non-empty expert on that expert's cuBLAS handle;
//   4. gather the result rows back into the original sample order.
//
//   input        device, row-major [batch_size x in_feat]
//   d_gate       device, batch_size expert indices, each in [0, num_expert)
//   weight1      device, row-major [num_expert x hidden_feat x in_feat]
//   weight2      device, row-major [num_expert x out_feat x hidden_feat]
//   output       device, row-major [batch_size x out_feat], overwritten
//
// Raises a c10::Error (via TORCH_CHECK) if a gate index is out of range.
// CUDA / cuBLAS failures are reported through checkCudaErrors.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* d_gate,
        const scalar_t* weight1,
        const scalar_t* weight2,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t hidden_feat,
        const size_t out_feat,
        const size_t num_expert) {

    auto h = getCudaStreamManager(num_expert);

	// An empty batch leaves nothing to compute, and a grid of zero blocks is
	// not a valid launch configuration.
	if (batch_size == 0) {
		return;
	}

#ifdef MOE_BREAKDOWN
	timestamp(t_init);
#endif

	// Device scratch buffers holding the samples grouped by expert.
	scalar_t *input_buf, *hidden_buf, *output_buf;

	checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
				in_feat));
	checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
				out_feat));
	checkCudaErrors(cudaMalloc(&hidden_buf, sizeof(scalar_t) * batch_size *
				hidden_feat));

#ifdef MOE_BREAKDOWN
	timestamp(t_malloc);
	fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_init, t_malloc) *
			1e6);
#endif

	// Host-side bookkeeping; std::vector releases the memory on every exit
	// path (the raw arrays used previously leaked expert_count/expert_ptr).
	std::vector<int> gate(batch_size);
	std::vector<int> expert_count(num_expert, 0), expert_ptr(num_expert);

	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

#ifdef MOE_BREAKDOWN
	timestamp(t_cpy);
	fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_malloc, t_cpy) *
			1e6);
#endif

	// Count samples per expert, rejecting indices that would otherwise
	// write outside expert_count.
	for (size_t i = 0; i < batch_size; ++i) {
		const int e = gate[i];
		if (e < 0 || static_cast<size_t>(e) >= num_expert) {
			// Release the scratch buffers before raising.
			cudaFree(input_buf);
			cudaFree(hidden_buf);
			cudaFree(output_buf);
			TORCH_CHECK(false, "moe_cuda_forward_impl: gate[", i, "] = ", e,
					" is outside [0, ", num_expert, ")");
		}
		++expert_count[e];
	}
	// expert_ptr[e] = first row of expert e in the grouped buffers
	// (exclusive prefix sum of expert_count).
	expert_ptr[0] = 0;
	for (size_t i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	// pos[b] = row of sample b inside the grouped buffers.
	std::vector<int> pos(batch_size);
	int *d_pos;
	checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));

	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}
	checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));

#ifdef MOE_BREAKDOWN
	timestamp(t_expert);
	fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
			1e6);
#endif

	// One block per sample; finished before any GEMM reads input_buf.
	batch_scatter_kernel<scalar_t>
		<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
				input_buf); 
	checkCudaErrors(cudaGetLastError());
	h->sync(0);

#ifdef MOE_BREAKDOWN
	h->sync();
	timestamp(t_scatter);
	fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
			1e6);
#endif

	scalar_t alpha = 1, beta = 0; 

	// `ptr` is the first grouped row belonging to expert i.
	for (size_t i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			continue;
		}
#ifdef MOE_DEBUG_SCATTER
		fprintf(stderr, "gemm %zu sz %d\n", i, expert_count[i]);
		fprintf(stderr, "GeMM %zu x %d x %zu\n", out_feat, expert_count[i],
				in_feat);
#endif
		// Use T(B) x T(A) = T(C) to produce row-major C
		// hidden = input_rows * W1[i]^T
		checkCudaErrors(cublasXgemm(h->getHandle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				hidden_feat, expert_count[i], in_feat,
				&alpha,
				weight1 + i * in_feat * hidden_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				hidden_buf + hidden_feat * ptr, hidden_feat
				));

		// output = hidden * W2[i]^T, issued on the same handle as the
		// first product.
		checkCudaErrors(cublasXgemm(h->getHandle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[i], hidden_feat,
				&alpha,
				weight2 + i * hidden_feat * out_feat, hidden_feat,
				hidden_buf + hidden_feat * ptr, hidden_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_count[i];
	}

#ifdef MOE_BREAKDOWN
	timestamp(t_mm);
	fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
			1e6);
#endif

	// NOTE(review): presumably waits for every expert stream so that all
	// GEMM results are visible to the gather - confirm in
	// cuda_stream_manager.h.
	h->sync();
	batch_gather_kernel<scalar_t>
		<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, output_buf,
				output); 
	checkCudaErrors(cudaGetLastError());
	h->sync(0);

#ifdef MOE_BREAKDOWN
	timestamp(t_gather);
	fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
			1e6);
	fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_gather) *
			1e6);
#endif

	checkCudaErrors(cudaFree(input_buf));
	checkCudaErrors(cudaFree(hidden_buf));
	checkCudaErrors(cudaFree(output_buf));
	checkCudaErrors(cudaFree(d_pos));
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
205
206
207
208
209
210
211
212
213
// Accumulate the weight gradient of a single-layer mixture-of-experts:
// for every sample b routed to expert e = gate[b],
//     grad_weight[e] += grad_output[b] (outer product) input[b]
// stored row-major as [out_feat x in_feat].
//
// The products are accumulated (beta = 1), so the caller must pass a
// zero-initialised grad_weight to obtain the plain gradient.
//
//   input        device, row-major [batch_size x in_feat]
//   gate         device, batch_size expert indices, each in [0, num_expert)
//   grad_output  device, row-major [batch_size x out_feat]
//   grad_weight  device, row-major [num_expert x out_feat x in_feat]
//
// Raises a c10::Error (via TORCH_CHECK) if a gate index is out of range.
// CUDA / cuBLAS failures are reported through checkCudaErrors.
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {

    auto h = getCudaStreamManager(num_expert);

    std::vector<int> gate_host(batch_size);
    checkCudaErrors(cudaMemcpy(gate_host.data(), gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));

    // Validate up front: the indices select a stream and a weight slice, so
    // an out-of-range value would read and write out of bounds.
    for (size_t i = 0; i < batch_size; ++i) {
        TORCH_CHECK(gate_host[i] >= 0 &&
                static_cast<size_t>(gate_host[i]) < num_expert,
                "moe_cuda_grad_weight: gate[", i, "] = ", gate_host[i],
                " is outside [0, ", num_expert, ")");
    }

    if (batch_size > 0) {
        // Handle 0 is re-bound to a different stream per sample below;
        // remember its current stream so other users of the handle are not
        // left with whichever stream happened to be used last.
        cudaStream_t saved_stream;
        checkCudaErrors(cublasGetStream(h->handles[0], &saved_stream));

        scalar_t alpha = 1, beta = 1;
        for (size_t i = 0; i < batch_size; ++i) {
            const size_t expert = static_cast<size_t>(gate_host[i]);
            // Updates of the same expert are issued on the same stream, so
            // accumulations into one weight slice run in order.
            checkCudaErrors(cublasSetStream(h->handles[0], *(h->streams + expert)));
            // cuBLAS is column-major: a column-major [in_feat x out_feat]
            // result with leading dimension in_feat is exactly the row-major
            // [out_feat x in_feat] slice declared for grad_weight.
            // C(k, o) += input[b][k] * grad_output[b][o]
            checkCudaErrors(cublasXgemm(h->handles[0],
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                in_feat,
                out_feat,
                1,
                &alpha,
                input + i * in_feat,
                in_feat,
                grad_output + i * out_feat,
                out_feat,
                &beta,
                grad_weight + expert * out_feat * in_feat,
                in_feat));
        }

        checkCudaErrors(cublasSetStream(h->handles[0], saved_stream));
    }

    // Wait for the work queued on every expert stream.
    for (size_t i=0; i<num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
    }
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
243

Jiezhong Qiu's avatar
Jiezhong Qiu committed
244
// PyTorch entry point for the mixture-of-experts forward pass.
//
//   input    [batch_size x in_feat], floating point, CUDA
//   gate     at least batch_size int32 expert indices, CUDA
//   weight1  [num_expert x hidden_feat x in_feat], CUDA
//   weight2  [num_expert x out_feat x hidden_feat], CUDA
//
// Returns a one-element vector holding the [batch_size x out_feat] output.
// Raises a c10::Error when a tensor is not on a CUDA device or the shapes do
// not agree. Non-contiguous tensors are copied to contiguous storage first,
// because the implementation indexes the raw pointers as dense row-major
// arrays.
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight1,
        torch::Tensor weight2
		) {
    TORCH_CHECK(input.is_cuda() && gate.is_cuda() && weight1.is_cuda() &&
            weight2.is_cuda(),
            "moe_cuda_forward: all tensors must be CUDA tensors");
    TORCH_CHECK(input.dim() == 2,
            "moe_cuda_forward: input must be [batch_size x in_feat]");
    TORCH_CHECK(weight1.dim() == 3 && weight2.dim() == 3,
            "moe_cuda_forward: weights must be 3-dimensional");

    const auto batch_size = input.size(0);
    const auto num_expert = weight1.size(0);
    const auto out_feat = weight2.size(1);
	const auto hidden_feat = weight1.size(1);
    const auto in_feat = weight1.size(2);

    TORCH_CHECK(input.size(1) == in_feat,
            "moe_cuda_forward: input feature size does not match weight1");
    TORCH_CHECK(weight2.size(0) == num_expert &&
            weight2.size(2) == hidden_feat,
            "moe_cuda_forward: weight2 shape does not match weight1");
    TORCH_CHECK(gate.numel() >= batch_size,
            "moe_cuda_forward: gate must hold one expert index per input row");

    // The kernels and GEMMs assume dense row-major storage.
    input = input.contiguous();
    gate = gate.contiguous();
    weight1 = weight1.contiguous();
    weight2 = weight2.contiguous();
            
#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, hidden_feat = %ld,out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, hidden_feat, out_feat);
#endif
    auto output = input.new_zeros({batch_size, out_feat});
    
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight1.data_ptr<scalar_t>(),
                    weight2.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
					hidden_feat,
                    out_feat,
                    num_expert
                );
    }));
    
    return {output, };           
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
279
// PyTorch entry point for the mixture-of-experts backward pass.
//
// Returns {grad_input, grad_weight}. Only grad_weight is computed; the
// grad_input computation is disabled (see the TODO below), so grad_input is
// returned as an all-zero [batch_size x in_feat] tensor.
//
// Raises a c10::Error when a tensor is not on a CUDA device or the shapes do
// not agree. Non-contiguous tensors are copied to contiguous storage first,
// because the implementation indexes the raw pointers as dense row-major
// arrays.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    TORCH_CHECK(grad_output.is_cuda() && input.is_cuda() && gate.is_cuda() &&
            weight.is_cuda(),
            "moe_cuda_backward: all tensors must be CUDA tensors");
    TORCH_CHECK(input.dim() == 2 && grad_output.dim() == 2,
            "moe_cuda_backward: input and grad_output must be 2-dimensional");
    TORCH_CHECK(weight.dim() == 3,
            "moe_cuda_backward: weight must be 3-dimensional");

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

    TORCH_CHECK(input.size(1) == in_feat,
            "moe_cuda_backward: input feature size does not match weight");
    TORCH_CHECK(grad_output.size(0) == batch_size &&
            grad_output.size(1) == out_feat,
            "moe_cuda_backward: grad_output must be [batch_size x out_feat]");
    TORCH_CHECK(gate.numel() >= batch_size,
            "moe_cuda_backward: gate must hold one expert index per input row");

    // The GEMMs assume dense row-major storage.
    grad_output = grad_output.contiguous();
    input = input.contiguous();
    gate = gate.contiguous();

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    // grad_input is easy to compute, exactly the same as forward
    // TODO: Backward currently broken. The call below was written for a
    // different moe_cuda_forward_impl parameter list than the one defined in
    // this file and is kept for reference only; until it is ported,
    // grad_input stays all zeros.
    //
    // AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
    //     moe_cuda_forward_impl<scalar_t>(
    //         grad_output.data_ptr<scalar_t>(),
    //         gate.data_ptr<int>(),
    //         weight.data_ptr<scalar_t>(),
    //         grad_input.data_ptr<scalar_t>(),
    //         batch_size,
    //         out_feat,
    //         in_feat,
    //         num_expert,
    //         CUBLAS_OP_N
    //     );
    // }));

    // grad_weight was created zero-filled above; the implementation
    // accumulates one outer product per sample into it.
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
329
330

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
331
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
332
333
334
335
336
337
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
338
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
339
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
340
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
341

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
342
343
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
344
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
345
346
347
348
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
349

Jiezhong Qiu's avatar
Jiezhong Qiu committed
350
351
352
353
354
355
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
356
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
357
358
359
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
360
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
361
362
363
364
365
366
367
368
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
369
}
Rick Ho's avatar
Rick Ho committed
370
*/