// moe_cuda_kernel.cu
//
// Host-side drivers and device kernels for the mixture-of-experts linear
// layer exposed to PyTorch (moe_cuda_forward / moe_cuda_backward).

#include <torch/extension.h>
#include <torch/torch.h>

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include "timer.hh"

#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"

// Ceiling of _x_ / _y_ for integers. Only meaningful for _x_ >= 1: the
// expression subtracts one from _x_ before dividing.
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

// Uncomment to print a per-phase timing breakdown from moe_cuda_forward_impl.
// #define MOE_BREAKDOWN
// Uncomment to print the problem sizes on entry to moe_cuda_forward and
// moe_cuda_backward.
// #define MOE_DEBUG

// Per-thread holder of the CUDA streams and cuBLAS handles used in this file.
// It is set up lazily by moe_cuda_forward / moe_cuda_backward the first time
// smgr.streams is found to be NULL.
thread_local CudaStreamManager smgr;

// Builds a table of pointers into `base`: ptrs[i] = base + stride * offset[i]
// for every i in [0, n).
//
// Launch layout: 1-D grid, one thread per table entry. Threads whose global
// index is >= n do nothing, so any launch with gridDim.x * blockDim.x >= n
// covers the table.
//
//   n      - number of entries to write
//   base   - start of the buffer the pointers refer to
//   stride - number of scalar_t elements between consecutive slots
//   offset - per-entry slot index, read on the device
//   ptrs   - output table, written on the device
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	// Widen before multiplying so the global index cannot wrap at 2^32.
	const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x
		+ threadIdx.x;
	if (idx < n) {
		ptrs[idx] = base + stride * offset[idx];
	}
}

// Copies row b of `inbuf` to row pos[b] of `oubuf`, where every row holds
// `wid` contiguous elements.
//
// Launch layout: one block per source row (blockIdx.x is the row index); the
// threads of a block stride over the row, so any blockDim.x works.
//
//   wid   - number of elements per row
//   pos   - destination row for each source row, read on the device
//   inbuf - source rows
//   oubuf - destination rows
template <typename scalar_t>
__global__
void batch_scatter_kernel(int wid, int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	// Row offsets are formed in 64 bits: wid * row overflows 32-bit
	// arithmetic once a buffer holds more than 2^31 elements.
	const long long row_len = wid;
	inbuf += row_len * blockIdx.x;
	oubuf += row_len * pos[blockIdx.x];
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}

// Copies row pos[b] of `inbuf` to row b of `oubuf`, where every row holds
// `wid` contiguous elements. This is the inverse permutation of
// batch_scatter_kernel when given the same `pos`.
//
// Launch layout: one block per destination row (blockIdx.x is the row
// index); the threads of a block stride over the row, so any blockDim.x works.
//
//   wid   - number of elements per row
//   pos   - source row for each destination row, read on the device
//   inbuf - source rows
//   oubuf - destination rows
template <typename scalar_t>
__global__
void batch_gather_kernel(int wid, int* pos,
		const scalar_t* inbuf, scalar_t* oubuf) {
	// Row offsets are formed in 64 bits: wid * row overflows 32-bit
	// arithmetic once a buffer holds more than 2^31 elements.
	const long long row_len = wid;
	inbuf += row_len * pos[blockIdx.x];
	oubuf += row_len * blockIdx.x;
	for (int i = threadIdx.x; i < wid; i += blockDim.x) {
		oubuf[i] = inbuf[i];
	}
}


template <typename scalar_t>
Jiezhong Qiu's avatar
Jiezhong Qiu committed
61
void moe_cuda_forward_impl(
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
62
        const scalar_t* input,
Rick Ho's avatar
Rick Ho committed
63
        const int* d_gate,
Rick Ho's avatar
Rick Ho committed
64
        const scalar_t* weight,
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
65
        scalar_t* output,
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
66
67
        const size_t batch_size,
        const size_t in_feat,
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
68
        const size_t out_feat,
Rick Ho's avatar
Rick Ho committed
69
70
        const size_t num_expert, 
		cublasOperation_t transb) {
Rick Ho's avatar
Rick Ho committed
71

Rick Ho's avatar
Rick Ho committed
72
73
74
75
#ifdef MOE_BREAKDOWN
	timestamp(t_init);
#endif

Rick Ho's avatar
Rick Ho committed
76
	scalar_t *input_buf, *output_buf;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
77

Rick Ho's avatar
Rick Ho committed
78
79
80
81
	checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
				in_feat));
	checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
				out_feat));
Rick Ho's avatar
Rick Ho committed
82
83
84
85
86
87

#ifdef MOE_BREAKDOWN
	timestamp(t_malloc);
	fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_init, t_malloc) *
			1e6);
#endif
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
88

Rick Ho's avatar
Rick Ho committed
89
90
91
    int *gate = new int[batch_size];
	int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
	memset(expert_count, 0, sizeof(int) * num_expert);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
92

Rick Ho's avatar
Rick Ho committed
93
94
	checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));
Rick Ho's avatar
Rick Ho committed
95
96
97
98
99
100
101

#ifdef MOE_BREAKDOWN
	timestamp(t_cpy);
	fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_malloc, t_cpy) *
			1e6);
#endif

Rick Ho's avatar
Rick Ho committed
102
103
104
105
106
107
	for (int i = 0; i < batch_size; ++i) {
		++expert_count[gate[i]];
	}
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
Rick Ho's avatar
Rick Ho committed
108
	}
Rick Ho's avatar
Rick Ho committed
109

110
111
112
113
114
115
116
117
118
119
	int *pos = new int[batch_size];
	int *d_pos;
	checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));

	for (int i = 0; i < batch_size; ++i) {
		pos[i] = expert_ptr[gate[i]]++;
	}
	checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
				cudaMemcpyHostToDevice));

Rick Ho's avatar
Rick Ho committed
120
121
122
123
124
125
#ifdef MOE_BREAKDOWN
	timestamp(t_expert);
	fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
			1e6);
#endif

126
	batch_scatter_kernel<scalar_t>
Rick Ho's avatar
Rick Ho committed
127
		<<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
128
				input_buf); 
Rick Ho's avatar
Rick Ho committed
129
	// smgr.sync(0);
Rick Ho's avatar
Rick Ho committed
130

Rick Ho's avatar
Rick Ho committed
131
#ifdef MOE_BREAKDOWN
Rick Ho's avatar
Rick Ho committed
132
	// h->sync();
Rick Ho's avatar
Rick Ho committed
133
134
135
136
137
	timestamp(t_scatter);
	fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
			1e6);
#endif

Rick Ho's avatar
Rick Ho committed
138
139
	scalar_t alpha = 1, beta = 0; 

Rick Ho's avatar
Rick Ho committed
140
141
142
143
144
145
146
147
148
149
	for (int i = 0, ptr = 0; i < num_expert; ++i) {
		if (expert_count[i] == 0) {
			continue;
		}
#ifdef MOE_DEBUG_SCATTER
		fprintf(stderr, "gemm %d sz %d\n", i, expert_count[i]);
		fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_count[i],
				in_feat);
#endif
		// Use T(B) x T(A) = T(C) to produce row-major C
Rick Ho's avatar
Rick Ho committed
150
		checkCudaErrors(cublasXgemm(smgr.handles[0], // h->getHandle(i),
Rick Ho's avatar
Rick Ho committed
151
				CUBLAS_OP_T,
Rick Ho's avatar
Rick Ho committed
152
				CUBLAS_OP_N,
Rick Ho's avatar
Rick Ho committed
153
				out_feat, expert_count[i], in_feat,
Rick Ho's avatar
Rick Ho committed
154
				&alpha,
Rick Ho's avatar
Rick Ho committed
155
				weight + i * in_feat * out_feat, in_feat,
Rick Ho's avatar
Rick Ho committed
156
				input_buf + ptr * in_feat, in_feat,
Rick Ho's avatar
Rick Ho committed
157
				&beta,
Rick Ho's avatar
Rick Ho committed
158
				output_buf + out_feat * ptr, out_feat
Rick Ho's avatar
Rick Ho committed
159
				));
Rick Ho's avatar
Rick Ho committed
160

Rick Ho's avatar
Rick Ho committed
161
162
		ptr += expert_count[i];
	}
Rick Ho's avatar
Rick Ho committed
163
164
165
166
167
168
169

#ifdef MOE_BREAKDOWN
	timestamp(t_mm);
	fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
			1e6);
#endif

Rick Ho's avatar
Rick Ho committed
170
	// h->sync();
171
	batch_gather_kernel<scalar_t>
Rick Ho's avatar
Rick Ho committed
172
		<<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
173
				output); 
Rick Ho's avatar
Rick Ho committed
174
	// h->sync(0);
Rick Ho's avatar
Rick Ho committed
175

Rick Ho's avatar
Rick Ho committed
176
177
178
179
180
181
182
183
#ifdef MOE_BREAKDOWN
	timestamp(t_gather);
	fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
			1e6);
	fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_gather) *
			1e6);
#endif

Rick Ho's avatar
Rick Ho committed
184
185
	cudaFree(input_buf);
	cudaFree(output_buf);
186
187
188
	cudaFree(d_pos);
	delete [] pos;
	delete [] gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
189
190
}

template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
200
        const size_t num_expert) {
Jiezhong Qiu's avatar
Jiezhong Qiu committed
201
202
203
204
205

    int* gate_host = new int[batch_size];
    scalar_t alpha = 1, beta = 1;
    checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    for (size_t i=0; i<batch_size; ++i) {
Rick Ho's avatar
Rick Ho committed
206
        // checkCudaErrors(cublasSetStream);
Rick Ho's avatar
Rick Ho committed
207
        checkCudaErrors(cublasXgemm(smgr.handles[0],
Jiezhong Qiu's avatar
Jiezhong Qiu committed
208
            CUBLAS_OP_N, 
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
209
            CUBLAS_OP_T,
Jiezhong Qiu's avatar
Jiezhong Qiu committed
210
211
212
213
214
215
216
            out_feat, 
            in_feat, 
            1,
            &alpha,
            grad_output + i * out_feat,
            out_feat,
            input + i * in_feat,
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
217
            in_feat,
Jiezhong Qiu's avatar
Jiezhong Qiu committed
218
219
220
221
            &beta,
            grad_weight + gate_host[i] * out_feat * in_feat,
            out_feat));
    }
Jiezhong Qiu's avatar
Jiezhong Qiu committed
222
    for (size_t i=0; i<num_expert; ++i) {
223
        checkCudaErrors(cudaStreamSynchronize(*(smgr.streams + i)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
224
    }
Jiezhong Qiu's avatar
Jiezhong Qiu committed
225
226
    delete[] gate_host;
}
// PyTorch entry point for the forward pass.
//
//   input  - [batch_size x in_feat] CUDA tensor
//   gate   - batch_size expert indices (32-bit int) on the same device
//   weight - [num_expert x out_feat x in_feat] CUDA tensor, same dtype as input
//
// Returns a one-element vector holding the [batch_size x out_feat] output.
// Raises a c10::Error if the shapes, devices or dtypes are inconsistent.
//
// NOTE(review): smgr is only set up while smgr.streams is NULL, so a later
// call with a different device or a larger num_expert reuses the first
// configuration -- confirm that is intended.
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight
		) {
    TORCH_CHECK(input.is_cuda() && gate.is_cuda() && weight.is_cuda(),
            "moe_cuda_forward: input, gate and weight must be CUDA tensors");
    TORCH_CHECK(input.dim() == 2,
            "moe_cuda_forward: input must be [batch_size x in_feat]");
    TORCH_CHECK(weight.dim() == 3,
            "moe_cuda_forward: weight must be [num_expert x out_feat x in_feat]");
    TORCH_CHECK(input.size(1) == weight.size(2),
            "moe_cuda_forward: input has ", input.size(1),
            " features but weight expects ", weight.size(2));
    TORCH_CHECK(gate.numel() == input.size(0),
            "moe_cuda_forward: gate has ", gate.numel(),
            " entries but the batch has ", input.size(0), " rows");
    TORCH_CHECK(weight.scalar_type() == input.scalar_type(),
            "moe_cuda_forward: input and weight must share a dtype");

    // The implementation indexes raw pointers as dense row-major storage.
    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    // Make the tensors' device current so the raw cudaMalloc / cudaMemcpy
    // calls in the implementation act on it.
    const c10::cuda::OptionalCUDAGuard device_guard(device_of(input));

    const int device = device_of(input).value().index();
    if (smgr.streams == NULL) {
        smgr.setup(num_expert, device);
    }
    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
                    out_feat,
                    num_expert,
					CUBLAS_OP_T
                );
    }));

    return {output, };
}

// PyTorch entry point for the backward pass.
//
// Returns {grad_input, grad_weight}. Only grad_weight is computed; the
// grad_input path is disabled below, so grad_input is returned as zeros.
// Raises a c10::Error if the shapes, devices or dtypes are inconsistent.
//
// NOTE(review): smgr is only set up while smgr.streams is NULL, so a later
// call with a different device or a larger num_expert reuses the first
// configuration -- confirm that is intended.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    TORCH_CHECK(grad_output.is_cuda() && input.is_cuda() && gate.is_cuda() &&
            weight.is_cuda(),
            "moe_cuda_backward: all arguments must be CUDA tensors");
    TORCH_CHECK(grad_output.dim() == 2 && input.dim() == 2,
            "moe_cuda_backward: grad_output and input must be 2-D");
    TORCH_CHECK(weight.dim() == 3,
            "moe_cuda_backward: weight must be [num_expert x out_feat x in_feat]");
    TORCH_CHECK(input.size(1) == weight.size(2),
            "moe_cuda_backward: input has ", input.size(1),
            " features but weight expects ", weight.size(2));
    TORCH_CHECK(grad_output.size(0) == input.size(0) &&
            grad_output.size(1) == weight.size(1),
            "moe_cuda_backward: grad_output must be [batch_size x out_feat]");
    TORCH_CHECK(gate.numel() == input.size(0),
            "moe_cuda_backward: gate has ", gate.numel(),
            " entries but the batch has ", input.size(0), " rows");
    TORCH_CHECK(grad_output.scalar_type() == input.scalar_type(),
            "moe_cuda_backward: grad_output and input must share a dtype");

    // The implementation indexes raw pointers as dense row-major storage.
    grad_output = grad_output.contiguous();
    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    // Make the tensors' device current so the raw CUDA calls in the
    // implementation act on it.
    const c10::cuda::OptionalCUDAGuard device_guard(device_of(input));

    const int device = device_of(input).value().index();
    if (smgr.streams == NULL) {
        smgr.setup(num_expert, device);
    }

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    // grad_input is easy to compute, exactly the same as forward
	/* TODO: Backward currently broken
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N
        );
    }));
	*/

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}


/*
// Stand-alone benchmark driver, kept for reference only. It calls
// moe_first_linear_cuda_forward, which is not defined in this file.
int main() {
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
	data_t *input, *weight;
	data_t *output;
	size_t *gate;

	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));

    size_t nt = 16;
    double tsum = 0, tmax = 0;

    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    }
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);

    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
}
*/