// moe_cuda_kernel.cu
#include <torch/extension.h>
#include <torch/torch.h>

#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <tuple>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>

#include "timer.hh"

#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
17

Rick Ho's avatar
Rick Ho committed
18
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19

Rick Ho's avatar
Rick Ho committed
20
// #define MOE_BREAKDOWN
21
// #define MOE_DEBUG
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
22

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
23
24
// Fill ptrs[k] = base + stride * offset[k] for every k in [0, n).
//
// One thread per output element; threads whose flat index is >= n do nothing,
// so the grid may be rounded up past n.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	const size_t k = blockIdx.x * blockDim.x + threadIdx.x;
	if (k >= n) {
		return;
	}
	ptrs[k] = base + offset[k] * stride;
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// Copy row r of `inbuf` into row pos[r] of `oubuf`; every row has `wid`
// elements.
//
// Launch layout: one block per source row (gridDim.x must equal the number of
// entries in `pos`); any blockDim.x works because each block strides over its
// row. `pos` must contain valid destination rows of `oubuf`.
// Row offsets are computed in size_t: the previous 32-bit products
// (wid * blockIdx.x, wid * pos[r]) wrapped / overflowed for large batches.
template <typename scalar_t>
__global__
void batch_scatter_kernel(int wid, const int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	if (wid <= 0) {
		return;
	}
	const size_t width = static_cast<size_t>(wid);
	const size_t row = blockIdx.x;
	inbuf += width * row;
	oubuf += width * static_cast<size_t>(pos[row]);
	for (size_t i = threadIdx.x; i < width; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

// Copy row pos[r] of `inbuf` into row r of `oubuf`; every row has `wid`
// elements. This is the inverse permutation of batch_scatter_kernel.
//
// Launch layout: one block per destination row (gridDim.x must equal the
// number of entries in `pos`); any blockDim.x works because each block strides
// over its row. `pos` must contain valid source rows of `inbuf`.
// Row offsets are computed in size_t: the previous 32-bit products
// (wid * pos[r], wid * blockIdx.x) wrapped / overflowed for large batches.
template <typename scalar_t>
__global__
void batch_gather_kernel(int wid, const int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	if (wid <= 0) {
		return;
	}
	const size_t width = static_cast<size_t>(wid);
	const size_t row = blockIdx.x;
	inbuf += width * static_cast<size_t>(pos[row]);
	oubuf += width * row;
	for (size_t i = threadIdx.x; i < width; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
57
// Forward pass of a top-1 mixture-of-experts linear layer.
//
// Row r of `input` (in_feat values) is routed to expert e = d_gate[r] and
// multiplied by that expert's weight matrix, stored row-major as
// weight[e][out_feat][in_feat]. The result is written to row r of `output`
// (out_feat values). Rows are first scattered into expert-contiguous order so
// each non-empty expert needs exactly one GEMM, then gathered back.
//
// All pointers are device pointers. d_gate is copied to the host with a
// blocking cudaMemcpy to build the per-expert layout.
//
// Throws std::out_of_range if any gate value lies outside [0, num_expert).
// Validation happens before any device allocation, so throwing leaks nothing.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* d_gate,
        const scalar_t* weight,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {
	if (batch_size == 0) {
		// Nothing to route; also avoids a zero-sized grid, which is an
		// invalid launch configuration.
		return;
	}

	auto h = getCudaStreamManager(num_expert);

#ifdef MOE_BREAKDOWN
	timestamp(t_init);
#endif

	// Routing decisions to the host (blocking copy).
	std::vector<int> gate(batch_size);
	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

#ifdef MOE_BREAKDOWN
	timestamp(t_cpy);
	fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_init, t_cpy) *
			1e6);
#endif

	// Rows per expert; reject gate values that would index out of bounds.
	std::vector<int> expert_count(num_expert, 0);
	for (size_t i = 0; i < batch_size; ++i) {
		const int e = gate[i];
		if (e < 0 || static_cast<size_t>(e) >= num_expert) {
			throw std::out_of_range(
					"moe_cuda_forward_impl: gate value out of range [0, num_expert)");
		}
		++expert_count[e];
	}

	// Exclusive prefix sum: first row of each expert's segment.
	std::vector<int> expert_ptr(num_expert, 0);
	for (size_t i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
	}

	// pos[r] = slot of input row r in the expert-sorted buffer; rows of the
	// same expert keep their original relative order.
	std::vector<int> pos(batch_size);
	for (size_t i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}

	int *d_pos;
	checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));
	checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));

#ifdef MOE_BREAKDOWN
	timestamp(t_expert);
	fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
			1e6);
#endif

	scalar_t *input_buf, *output_buf;
	checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
				in_feat));
	checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
				out_feat));

#ifdef MOE_BREAKDOWN
	timestamp(t_malloc);
	fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_expert, t_malloc) *
			1e6);
#endif

	// One block per row: input row r -> input_buf row pos[r].
	batch_scatter_kernel<scalar_t>
		<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
				input_buf);
	checkCudaErrors(cudaGetLastError());
	// The expert GEMMs run on other streams, so the scatter must be done first.
	h->sync(0);

#ifdef MOE_BREAKDOWN
	h->sync();
	timestamp(t_scatter);
	fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_malloc, t_scatter) *
			1e6);
#endif

	scalar_t alpha = 1, beta = 0;

	// One GEMM per non-empty expert, each on that expert's cuBLAS handle.
	// cuBLAS is column-major: T(W_e) x T(X_e) = T(Y_e) yields row-major
	// Y_e = X_e * W_e^T with X_e [count x in_feat] and Y_e [count x out_feat].
	size_t ptr = 0;
	for (size_t i = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			continue;
		}
#ifdef MOE_DEBUG_SCATTER
		fprintf(stderr, "gemm %zu sz %d\n", i, expert_count[i]);
		fprintf(stderr, "GeMM %zu x %d x %zu\n", out_feat, expert_count[i],
				in_feat);
#endif
		checkCudaErrors(cublasXgemm(h->getHandle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[i], in_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));
		ptr += expert_count[i];
	}

#ifdef MOE_BREAKDOWN
	timestamp(t_mm);
	fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
			1e6);
#endif

	// Every expert GEMM must finish before rows are gathered back.
	h->sync();
	// One block per row: output_buf row pos[r] -> output row r.
	batch_gather_kernel<scalar_t>
		<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, output_buf,
				output);
	checkCudaErrors(cudaGetLastError());
	h->sync(0);

#ifdef MOE_BREAKDOWN
	timestamp(t_gather);
	fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
			1e6);
	fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_gather) *
			1e6);
#endif

	checkCudaErrors(cudaFree(input_buf));
	checkCudaErrors(cudaFree(output_buf));
	checkCudaErrors(cudaFree(d_pos));
	// Host-side arrays are std::vector and are released automatically.
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
189
190
191
192
193
194
195
196
197
// Accumulate the weight gradient of the top-1 MoE linear layer.
//
// For every row i routed to expert e = gate[i], adds the outer product
// grad_output[i] (out_feat) x input[i] (in_feat) into grad_weight[e]. The
// gradient is stored row-major as [num_expert][out_feat][in_feat], the same
// layout the forward GEMM reads the weights in. beta = 1, so the caller must
// initialise grad_weight (e.g. with zeros).
//
// Rows of different experts are issued on different streams. Rows of the same
// expert share a stream, so their accumulations into the same matrix are
// serialised. Blocks until every expert stream has finished.
//
// Throws std::out_of_range if any gate value lies outside [0, num_expert).
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {
    if (batch_size == 0) {
        return;
    }

    auto h = getCudaStreamManager(num_expert);

    std::vector<int> gate_host(batch_size);
    checkCudaErrors(cudaMemcpy(gate_host.data(), gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    // Validate before issuing any work: a bad gate would index past both
    // h->streams and grad_weight.
    for (size_t i = 0; i < batch_size; ++i) {
        if (gate_host[i] < 0 || static_cast<size_t>(gate_host[i]) >= num_expert) {
            throw std::out_of_range(
                    "moe_cuda_grad_weight: gate value out of range [0, num_expert)");
        }
    }

    // handles[0] is re-bound per row below; remember its stream so it can be
    // restored instead of being left on an arbitrary expert stream.
    cudaStream_t saved_stream;
    checkCudaErrors(cublasGetStream(h->handles[0], &saved_stream));

    scalar_t alpha = 1, beta = 1;
    for (size_t i = 0; i < batch_size; ++i) {
        const size_t e = static_cast<size_t>(gate_host[i]);
        checkCudaErrors(cublasSetStream(h->handles[0], h->streams[e]));
        // Column-major: C = A x T(B) with A = input row (in_feat x 1) and
        // B = grad_output row (out_feat x 1), giving C in_feat x out_feat with
        // ldc = in_feat. That is row-major [out_feat][in_feat] with element
        // (o, j) = grad_output[i][o] * input[i][j]. The previous operand order
        // produced the transpose of this layout.
        checkCudaErrors(cublasXgemm(h->handles[0],
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            in_feat,
            out_feat,
            1,
            &alpha,
            input + i * in_feat,
            in_feat,
            grad_output + i * out_feat,
            out_feat,
            &beta,
            grad_weight + e * out_feat * in_feat,
            in_feat));
    }
    checkCudaErrors(cublasSetStream(h->handles[0], saved_stream));

    for (size_t i = 0; i < num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(h->streams[i]));
    }
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
227

Jiezhong Qiu's avatar
Jiezhong Qiu committed
228
// Python-facing forward of the top-1 MoE linear layer.
//
//   input:  [batch_size x in_feat] CUDA tensor
//   gate:   [batch_size] int32 CUDA tensor, expert index of each row
//   weight: [num_expert x out_feat x in_feat] CUDA tensor, same dtype as input
//
// Returns {output} with output = [batch_size x out_feat].
// The kernels index raw pointers with dense row-major strides, so the inputs
// are validated and made contiguous here. Before this, mismatched shapes or
// strided views read out of bounds silently.
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight
		) {
    TORCH_CHECK(input.is_cuda() && gate.is_cuda() && weight.is_cuda(),
            "moe_cuda_forward: input, gate and weight must be CUDA tensors");
    TORCH_CHECK(input.dim() == 2,
            "moe_cuda_forward: input must be 2-D [batch_size x in_feat]");
    TORCH_CHECK(weight.dim() == 3,
            "moe_cuda_forward: weight must be 3-D [num_expert x out_feat x in_feat]");
    TORCH_CHECK(input.size(1) == weight.size(2),
            "moe_cuda_forward: input.size(1) must equal weight.size(2) (in_feat)");
    TORCH_CHECK(gate.dim() == 1 && gate.size(0) == input.size(0),
            "moe_cuda_forward: gate must be 1-D with one entry per input row");
    TORCH_CHECK(gate.scalar_type() == torch::kInt,
            "moe_cuda_forward: gate must be an int32 tensor");
    TORCH_CHECK(weight.scalar_type() == input.scalar_type(),
            "moe_cuda_forward: weight and input must have the same dtype");

    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
                    out_feat,
                    num_expert
                );
    }));

    return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
259
// Python-facing backward of the top-1 MoE linear layer.
//
// Returns {grad_input, grad_weight}:
//   grad_input  [batch_size x in_feat]:  row r = grad_output[r] x W_e, e = gate[r]
//   grad_weight [num_expert x out_feat x in_feat]: sum over rows of expert e of
//               the outer product grad_output[r] x input[r]
//
// grad_input used to be returned as all zeros: its computation was commented
// out as broken. It is now computed with one matmul per non-empty expert,
// using torch ops on rows grouped by expert.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    TORCH_CHECK(grad_output.is_cuda() && input.is_cuda() && gate.is_cuda() && weight.is_cuda(),
            "moe_cuda_backward: all tensors must be CUDA tensors");
    TORCH_CHECK(input.dim() == 2 && grad_output.dim() == 2 && weight.dim() == 3,
            "moe_cuda_backward: expected 2-D input/grad_output and 3-D weight");
    TORCH_CHECK(input.size(1) == weight.size(2),
            "moe_cuda_backward: input.size(1) must equal weight.size(2) (in_feat)");
    TORCH_CHECK(grad_output.size(0) == input.size(0) && grad_output.size(1) == weight.size(1),
            "moe_cuda_backward: grad_output must be [batch_size x out_feat]");
    TORCH_CHECK(gate.dim() == 1 && gate.size(0) == input.size(0),
            "moe_cuda_backward: gate must be 1-D with one entry per input row");
    TORCH_CHECK(gate.scalar_type() == torch::kInt,
            "moe_cuda_backward: gate must be an int32 tensor");
    TORCH_CHECK(input.scalar_type() == grad_output.scalar_type() &&
            weight.scalar_type() == grad_output.scalar_type(),
            "moe_cuda_backward: grad_output, input and weight must share a dtype");

    // The grad_weight kernel indexes raw pointers with dense row-major strides.
    grad_output = grad_output.contiguous();
    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    if (batch_size > 0) {
        auto gate_long = gate.to(torch::kLong);
        const auto gate_min = gate_long.min().item<int64_t>();
        const auto gate_max = gate_long.max().item<int64_t>();
        TORCH_CHECK(gate_min >= 0 && gate_max < num_expert,
                "moe_cuda_backward: gate value out of range [0, num_expert)");

        // Group rows by expert: `order` lists row indices sorted by gate, and
        // counts[e] is the length of expert e's run within it.
        auto order = std::get<1>(gate_long.sort());
        auto counts = torch::bincount(gate_long, {}, num_expert).cpu();
        auto counts_acc = counts.accessor<int64_t, 1>();

        int64_t offset = 0;
        for (int64_t e = 0; e < num_expert; ++e) {
            const int64_t n = counts_acc[e];
            if (n == 0) {
                continue;
            }
            auto rows = order.narrow(0, offset, n);
            // [n x out_feat] x [out_feat x in_feat] -> [n x in_feat]
            grad_input.index_copy_(0, rows,
                    grad_output.index_select(0, rows).mm(weight[e]));
            offset += n;
        }
    }

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
309
310

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
311
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
312
313
314
315
316
317
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
318
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
319
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
320
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
321

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
322
323
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
324
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
325
326
327
328
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
329

Jiezhong Qiu's avatar
Jiezhong Qiu committed
330
331
332
333
334
335
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
336
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
337
338
339
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
340
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
341
342
343
344
345
346
347
348
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
349
}
Rick Ho's avatar
Rick Ho committed
350
*/