"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "325a06c2deef25067f8b37e73358b4569e13def7"
moe_cuda_kernel.cu 10.2 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
10
11
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
12

Rick Ho's avatar
Rick Ho committed
13
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
14

Rick Ho's avatar
Rick Ho committed
15
16
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
17

Rick Ho's avatar
Rick Ho committed
18
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19

Rick Ho's avatar
Rick Ho committed
20
21
// #define MOE_BREAKDOWN
#define MOE_DEBUG
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
22

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
23
24
/*
 * Build an array of row pointers: ptrs[i] = base + stride * offset[i].
 *
 * Launch: 1-D grid of 1-D blocks. The grid-stride loop makes the result
 * independent of the launch size (a single block is valid, e.g. when
 * debugging). Indices are computed in size_t so that
 * blockIdx.x * blockDim.x cannot overflow 32 bits for very large n.
 *
 * n      - number of pointers to produce
 * base   - base device pointer
 * stride - distance (in elements) between consecutive logical rows
 * offset - device array of n row indices (presumably non-negative; each is
 *          converted to size_t before scaling — TODO confirm with callers)
 * ptrs   - device output array of n pointers
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
		const int* offset, const scalar_t** ptrs) {
	const size_t grid_stride = static_cast<size_t>(gridDim.x) * blockDim.x;
	for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
			idx < n; idx += grid_stride) {
		ptrs[idx] = base + stride * offset[idx];
	}
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
33

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
34
/*
 * Two-layer mixture-of-experts forward pass (no activation is applied
 * between the two GEMMs here):
 *
 *   output[i] = W2[g] * (W1[g] * input[i]),   g = gate[i]
 *
 * All matrices are row-major. Input rows are scattered into an
 * expert-sorted staging buffer, each expert runs two GEMMs on its contiguous
 * slice, and the results are gathered back into the original row order.
 *
 * input   - device, batch_size x in_feat
 * d_gate  - device, batch_size expert ids; each must be in [0, num_expert)
 * weight1 - device, num_expert x hidden_feat x in_feat
 * weight2 - device, num_expert x out_feat x hidden_feat
 * output  - device, batch_size x out_feat (every row is overwritten)
 *
 * Throws std::out_of_range if any gate id is outside [0, num_expert). The
 * check runs before any device allocation or copy is issued, so nothing
 * leaks on that path. Calls h->sync() before releasing the staging buffers.
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* d_gate,
        const scalar_t* weight1,
        const scalar_t* weight2,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t hidden_feat,
        const size_t out_feat,
        const size_t num_expert) {

    auto h = getCudaStreamManager(num_expert);

#ifdef MOE_BREAKDOWN
	timestamp(t_init);
#endif

	// The scatter/gather bookkeeping runs on the host, so fetch the gate ids.
	// Host memory lives in std::vector, so it is released on every exit path.
	std::vector<int> gate(batch_size);
	checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
				cudaMemcpyDeviceToHost));

#ifdef MOE_BREAKDOWN
	timestamp(t_cpy);
	fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_init, t_cpy) *
			1e6);
#endif

	// Count rows per expert, rejecting ids that would index out of bounds.
	// Then an exclusive prefix sum gives the first staging row of each expert.
	std::vector<int> expert_count(num_expert, 0);
	for (size_t i = 0; i < batch_size; ++i) {
		const int e = gate[i];
		if (e < 0 || static_cast<size_t>(e) >= num_expert) {
			throw std::out_of_range(
					"moe_cuda_forward_impl: gate value out of range [0, num_expert)");
		}
		++expert_count[e];
	}
	std::vector<int> expert_ptr(num_expert, 0);
	for (size_t e = 1; e < num_expert; ++e) {
		expert_ptr[e] = expert_ptr[e - 1] + expert_count[e - 1];
	}

#ifdef MOE_BREAKDOWN
	timestamp(t_expert);
	fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
			1e6);
#endif

	// Staging buffers are allocated only after validation succeeded.
	scalar_t *input_buf, *hidden_buf, *output_buf;
	checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
				in_feat));
	checkCudaErrors(cudaMalloc(&hidden_buf, sizeof(scalar_t) * batch_size *
				hidden_feat));
	checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
				out_feat));

#ifdef MOE_BREAKDOWN
	timestamp(t_malloc);
	fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_expert, t_malloc) *
			1e6);
#endif

	// Scatter: copy each input row into its expert's contiguous slice, on that
	// expert's stream. After this loop, expert_ptr[e] is one past the end of
	// expert e's slice.
	for (size_t i = 0; i < batch_size; ++i) {
		const int target_idx = expert_ptr[gate[i]]++;
#ifdef MOE_DEBUG_SCATTER
		fprintf(stderr, "aln idx %zu gate %d tgt %d\n", i, gate[i], target_idx);
#endif
		checkCudaErrors(cudaMemcpyAsync(input_buf + target_idx * in_feat,
					input + i * in_feat, sizeof(scalar_t) * in_feat,
					cudaMemcpyDeviceToDevice,
					h->getStream(gate[i])));
	}

#ifdef MOE_BREAKDOWN
	h->sync();
	timestamp(t_scatter);
	fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_malloc, t_scatter) *
			1e6);
#endif

	scalar_t alpha = 1, beta = 0;

	// Per-expert GEMMs on the expert's handle. ptr is the first staging row of
	// the current expert.
	size_t ptr = 0;
	for (int e = 0; e < static_cast<int>(num_expert); ++e) {
		if (expert_count[e] == 0) {
			continue;
		}
#ifdef MOE_DEBUG_SCATTER
		fprintf(stderr, "gemm %d sz %d\n", e, expert_count[e]);
		fprintf(stderr, "GeMM %zu x %d x %zu\n", out_feat, expert_count[e],
				in_feat);
#endif
		// Use T(B) x T(A) = T(C) to produce row-major C (cuBLAS is
		// column-major). First layer: hidden = input * W1[e]^T.
		checkCudaErrors(cublasXgemm(h->getHandle(e),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				hidden_feat, expert_count[e], in_feat,
				&alpha,
				weight1 + e * in_feat * hidden_feat, in_feat,
				input_buf + ptr * in_feat, in_feat,
				&beta,
				hidden_buf + hidden_feat * ptr, hidden_feat
				));

		// Second layer: out = hidden * W2[e]^T.
		checkCudaErrors(cublasXgemm(h->getHandle(e),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count[e], hidden_feat,
				&alpha,
				weight2 + e * hidden_feat * out_feat, hidden_feat,
				hidden_buf + hidden_feat * ptr, hidden_feat,
				&beta,
				output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_count[e];
	}

#ifdef MOE_BREAKDOWN
	h->sync();
	timestamp(t_mm);
	fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
			1e6);
#endif

	// Gather: walk rows backwards while decrementing expert_ptr, so each row
	// reads back exactly the staging row the scatter loop assigned to it.
	for (size_t i = batch_size; i-- > 0; ) {
		const int target_idx = --expert_ptr[gate[i]];
#ifdef MOE_DEBUG_SCATTER
		fprintf(stderr, "cb idx %zu gate %d tgt %d\n", i, gate[i], target_idx);
#endif
		checkCudaErrors(cudaMemcpyAsync(output + i * out_feat,
					output_buf + target_idx * out_feat,
					sizeof(scalar_t) * out_feat,
					cudaMemcpyDeviceToDevice,
					h->getStream(gate[i])));
	}

	h->sync();

#ifdef MOE_BREAKDOWN
	timestamp(t_gather);
	fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
			1e6);
	fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_gather) *
			1e6);
#endif

	// Release all three staging buffers (hidden_buf was previously leaked).
	checkCudaErrors(cudaFree(input_buf));
	checkCudaErrors(cudaFree(hidden_buf));
	checkCudaErrors(cudaFree(output_buf));
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
183
184
185
186
187
188
189
190
191
/*
 * Accumulate per-expert weight gradients of a single MoE linear layer:
 *
 *   grad_weight[g][o][k] += grad_output[i][o] * input[i][k],   g = gate[i]
 *
 * grad_weight is row-major [num_expert x out_feat x in_feat]. This is the
 * same layout the forward GEMMs read weights in (CUBLAS_OP_T, lda = in_feat).
 * Results are accumulated (beta = 1), so the caller must zero-initialise
 * grad_weight.
 *
 * input       - device, batch_size x in_feat
 * gate        - device, batch_size expert ids; each must be in [0, num_expert)
 * grad_output - device, batch_size x out_feat
 *
 * Throws std::out_of_range if any gate id is outside [0, num_expert). All ids
 * are checked before any GEMM is enqueued, so grad_weight is untouched on
 * that path. Synchronizes every expert stream before returning.
 */
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {

    auto h = getCudaStreamManager(num_expert);

    std::vector<int> gate_host(batch_size);
    checkCudaErrors(cudaMemcpy(gate_host.data(), gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    for (size_t i = 0; i < batch_size; ++i) {
        if (gate_host[i] < 0 || static_cast<size_t>(gate_host[i]) >= num_expert) {
            throw std::out_of_range(
                    "moe_cuda_grad_weight: gate value out of range [0, num_expert)");
        }
    }

    scalar_t alpha = 1, beta = 1;
    for (size_t i = 0; i < batch_size; ++i) {
        const int e = gate_host[i];
        // cuBLAS is column-major. A row-major out_feat x in_feat matrix is a
        // column-major in_feat x out_feat one, so compute (column-major)
        //   C (in_feat x out_feat) += input[i] (in_feat x 1) * grad_output[i]^T (1 x out_feat)
        // which stores grad_weight[e][o][k] = grad_output[i][o] * input[i][k]
        // at row-major offset o * in_feat + k. (The previous call wrote the
        // transpose.)
        checkCudaErrors(cublasSetStream(h->handles[0], *(h->streams + e)));
        checkCudaErrors(cublasXgemm(h->handles[0],
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            in_feat,
            out_feat,
            1,
            &alpha,
            input + i * in_feat,
            in_feat,
            grad_output + i * out_feat,
            out_feat,
            &beta,
            grad_weight + e * out_feat * in_feat,
            in_feat));
    }
    for (size_t i = 0; i < num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
    }
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
221

Jiezhong Qiu's avatar
Jiezhong Qiu committed
222
/*
 * Python-facing entry point of the two-layer MoE forward pass.
 *
 * input   - [batch_size x in_feat], CUDA, floating point
 * gate    - [batch_size], CUDA, int32 expert ids in [0, num_expert)
 * weight1 - [num_expert x hidden_feat x in_feat], CUDA, same dtype as input
 * weight2 - [num_expert x out_feat x hidden_feat], CUDA, same dtype as input
 *
 * Returns {output}, where output is [batch_size x out_feat].
 *
 * Shapes are validated up front because the kernels index raw data pointers.
 * A mismatch would otherwise become an out-of-bounds device read. Inputs are
 * made contiguous for the same reason.
 */
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight1,
        torch::Tensor weight2
		) {
    TORCH_CHECK(input.is_cuda() && gate.is_cuda() && weight1.is_cuda() && weight2.is_cuda(),
            "moe_cuda_forward: all tensors must be CUDA tensors");
    TORCH_CHECK(input.dim() == 2, "moe_cuda_forward: input must be [batch_size x in_feat]");
    TORCH_CHECK(gate.dim() == 1, "moe_cuda_forward: gate must be [batch_size]");
    TORCH_CHECK(weight1.dim() == 3 && weight2.dim() == 3,
            "moe_cuda_forward: weight1 and weight2 must be 3-D");

    // Dense row-major storage is assumed by every pointer offset downstream.
    input = input.contiguous();
    gate = gate.contiguous();
    weight1 = weight1.contiguous();
    weight2 = weight2.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight1.size(0);
    const auto out_feat = weight2.size(1);
	const auto hidden_feat = weight1.size(1);
    const auto in_feat = weight1.size(2);

    TORCH_CHECK(input.size(1) == in_feat,
            "moe_cuda_forward: input.size(1) must equal weight1.size(2)");
    TORCH_CHECK(gate.size(0) == batch_size,
            "moe_cuda_forward: gate.size(0) must equal input.size(0)");
    TORCH_CHECK(weight2.size(0) == num_expert && weight2.size(2) == hidden_feat,
            "moe_cuda_forward: weight2 must be [num_expert x out_feat x hidden_feat]");

#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, hidden_feat = %ld,out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, hidden_feat, out_feat);
#endif
    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight1.data_ptr<scalar_t>(),
                    weight2.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
					hidden_feat,
                    out_feat,
                    num_expert
                );
    }));

    return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
257
/*
 * Backward pass of a single-layer MoE linear map whose weight is
 * [num_expert x out_feat x in_feat].
 *
 * Returns {grad_input, grad_weight}:
 *   grad_input  - [batch_size x in_feat]. NOT computed yet: returned as zeros
 *                 (TODO).
 *   grad_weight - [num_expert x out_feat x in_feat], accumulated by
 *                 moe_cuda_grad_weight.
 *
 * NOTE(review): moe_cuda_forward is a two-layer (weight1, weight2) network,
 * while this backward still takes a single weight. Confirm how the two are
 * wired together on the Python side.
 */
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    TORCH_CHECK(grad_output.is_cuda() && input.is_cuda() && gate.is_cuda(),
            "moe_cuda_backward: grad_output, input and gate must be CUDA tensors");
    TORCH_CHECK(grad_output.dim() == 2 && input.dim() == 2 && gate.dim() == 1 && weight.dim() == 3,
            "moe_cuda_backward: expected grad_output/input 2-D, gate 1-D, weight 3-D");

    // moe_cuda_grad_weight indexes raw pointers assuming dense row-major storage.
    grad_output = grad_output.contiguous();
    input = input.contiguous();
    gate = gate.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

    TORCH_CHECK(input.size(1) == in_feat,
            "moe_cuda_backward: input.size(1) must equal weight.size(2)");
    TORCH_CHECK(grad_output.size(0) == batch_size && grad_output.size(1) == out_feat,
            "moe_cuda_backward: grad_output must be [batch_size x out_feat]");
    TORCH_CHECK(gate.size(0) == batch_size,
            "moe_cuda_backward: gate.size(0) must equal input.size(0)");

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    // TODO: grad_input is not implemented and stays zero. Per row it would be
    // grad_output[i] (1 x out_feat) @ weight[gate[i]] (out_feat x in_feat).
    // moe_cuda_forward_impl cannot be reused directly because it is now a
    // two-layer kernel.

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
307
308

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
309
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
310
311
312
313
314
315
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
316
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
317
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
318
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
319

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
320
321
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
322
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
323
324
325
326
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
327

Jiezhong Qiu's avatar
Jiezhong Qiu committed
328
329
330
331
332
333
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
334
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
335
336
337
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
338
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
339
340
341
342
343
344
345
346
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
347
}
Rick Ho's avatar
Rick Ho committed
348
*/