moe_cuda_kernel.cu 8.66 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
10
11
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
12

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
13
// #include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
14

Rick Ho's avatar
Rick Ho committed
15
16
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
17

Rick Ho's avatar
Rick Ho committed
18
// Integer ceiling division ceil(_x_ / _y_) for non-negative _x_ and positive
// _y_ (e.g. number of blocks needed to cover _x_ elements). Returns 0 when
// _x_ == 0 for both signed and unsigned operands. _y_ is evaluated twice and
// _x_ + _y_ - 1 must not overflow the operand type.
#define CEIL(_x_,_y_) (((_x_) + (_y_) - 1) / (_y_))
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19
20


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
21
22
// Writes ptrs[k] = base + stride * offset[k] for every k in [0, n).
//
// Launch: any 1-D grid/block shape; the grid-stride loop covers all n
// elements even when the grid has fewer threads than n. No shared memory.
// Indices are computed in 64-bit so large n cannot overflow the
// blockIdx.x * blockDim.x product.
// NOTE(review): offset[k] is converted to size_t before the multiply, so a
// negative offset wraps; offsets are presumably non-negative — confirm.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const int* offset, const scalar_t** ptrs) {
    const size_t step = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
            idx < n; idx += step) {
        ptrs[idx] = base + stride * offset[idx];
    }
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
30

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
31
// Mixture-of-experts linear layer: scatter rows by expert, run one GEMM per
// expert, gather the results back into the original row order.
//
// For every row b in [0, batch_size), with e = d_gate[b]:
//   transb == CUBLAS_OP_T : output[b] = W_e   * input[b], where weight[e] is
//                           stored row-major as [out_feat x in_feat]
//   otherwise (CUBLAS_OP_N): output[b] = W_e^T * input[b], where weight[e] is
//                           stored row-major as [in_feat x out_feat]
// The second form lets the backward pass compute grad_input = grad_output * W
// by calling this with in_feat/out_feat swapped.
//
// input  : device, [batch_size x in_feat], row-major
// d_gate : device, [batch_size], expert index per row
// output : device, [batch_size x out_feat], row-major; every row is written
// Throws std::out_of_range if a gate value is outside [0, num_expert).
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* d_gate,
        const scalar_t* weight,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        cublasOperation_t transb) {
    auto h = getCudaStreamManager(num_expert);
    if (batch_size == 0) {
        return;  // nothing to scatter, multiply or gather
    }

    // The scatter/gather is driven from the host, so fetch the routing first
    // (cudaMemcpy blocks until the values are on the host).
    std::vector<int> gate(batch_size);
    checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
                cudaMemcpyDeviceToHost));

    // Rows per expert. Validate before indexing so a bad gate cannot overrun
    // the host arrays; done before any device allocation so nothing leaks.
    std::vector<int> expert_count(num_expert, 0);
    for (size_t i = 0; i < batch_size; ++i) {
        const int e = gate[i];
        if (e < 0 || static_cast<size_t>(e) >= num_expert) {
            throw std::out_of_range("moe_cuda_forward_impl: gate index out of range");
        }
        ++expert_count[e];
    }

    // Exclusive prefix sum: first buffer row of each expert's segment.
    std::vector<size_t> expert_ptr(num_expert, 0);
    for (size_t e = 1; e < num_expert; ++e) {
        expert_ptr[e] = expert_ptr[e - 1] + expert_count[e - 1];
    }

    scalar_t *input_buf, *output_buf;
    checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
                in_feat));
    checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
                out_feat));

    // Scatter: row i goes to the next free slot of its expert's segment; the
    // copy is issued on that expert's stream. row_slot remembers the slot so
    // the gather below can undo the permutation.
    std::vector<size_t> row_slot(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        row_slot[i] = expert_ptr[gate[i]]++;
#ifdef MOE_DEBUG_SCATTER
        fprintf(stderr, "aln idx %zu gate %d tgt %zu\n", i, gate[i], row_slot[i]);
#endif
        checkCudaErrors(cudaMemcpyAsync(input_buf + row_slot[i] * in_feat,
                    input + i * in_feat, sizeof(scalar_t) * in_feat,
                    cudaMemcpyDeviceToDevice,
                    h->getStream(gate[i])));
    }

    // cuBLAS is column-major, so each GEMM computes C^T = op(W_e) * X_e^T,
    // which lands in output_buf as row-major [expert_count x out_feat].
    // A row-major [out_feat x in_feat] W_e is column-major [in_feat x out_feat]
    // (ld in_feat) and needs OP_T; a row-major [in_feat x out_feat] W_e is
    // column-major [out_feat x in_feat] (ld out_feat) and needs OP_N.
    const bool w_out_by_in = (transb == CUBLAS_OP_T);
    const cublasOperation_t op_w = w_out_by_in ? CUBLAS_OP_T : CUBLAS_OP_N;
    const size_t ld_w = w_out_by_in ? in_feat : out_feat;
    scalar_t alpha = 1, beta = 0;

    size_t seg_begin = 0;
    for (size_t e = 0; e < num_expert; ++e) {
        if (expert_count[e] == 0) {
            continue;
        }
#ifdef MOE_DEBUG_SCATTER
        fprintf(stderr, "gemm %zu sz %d\n", e, expert_count[e]);
        fprintf(stderr, "GeMM %zu x %d x %zu\n", out_feat, expert_count[e],
                in_feat);
#endif
        // NOTE(review): assumes h->getHandle(e) is bound to h->getStream(e),
        // so this GEMM is ordered after the scatter copies for expert e —
        // confirm in cuda_stream_manager.h.
        checkCudaErrors(cublasXgemm(h->getHandle(e),
                op_w,
                CUBLAS_OP_N,
                out_feat, expert_count[e], in_feat,
                &alpha,
                weight + e * in_feat * out_feat, ld_w,
                input_buf + seg_begin * in_feat, in_feat,
                &beta,
                output_buf + seg_begin * out_feat, out_feat
                ));
        seg_begin += expert_count[e];
    }

    // Gather: move each result row back to its original position.
    for (size_t i = 0; i < batch_size; ++i) {
#ifdef MOE_DEBUG_SCATTER
        fprintf(stderr, "cb idx %zu gate %d tgt %zu\n", i, gate[i], row_slot[i]);
#endif
        checkCudaErrors(cudaMemcpyAsync(output + i * out_feat,
                    output_buf + row_slot[i] * out_feat,
                    sizeof(scalar_t) * out_feat,
                    cudaMemcpyDeviceToDevice,
                    h->getStream(gate[i])));
    }

    h->sync();

    checkCudaErrors(cudaFree(input_buf));
    checkCudaErrors(cudaFree(output_buf));
}

// Accumulates per-expert weight gradients:
//   grad_weight[e] += outer(grad_output[b], input[b])  for every b with gate[b] == e
// where grad_weight[e] is row-major [out_feat x in_feat]. Accumulation uses
// beta = 1, so the caller must pass a zero-initialised grad_weight.
// Throws std::out_of_range if a gate value is outside [0, num_expert).
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {

    auto h = getCudaStreamManager(num_expert);
    if (batch_size == 0) {
        return;  // no rows, no gradient contributions
    }

    std::vector<int> gate_host(batch_size);
    checkCudaErrors(cudaMemcpy(gate_host.data(), gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    for (const int e : gate_host) {
        if (e < 0 || static_cast<size_t>(e) >= num_expert) {
            throw std::out_of_range("moe_cuda_grad_weight: gate index out of range");
        }
    }

    // All updates go through handles[0], rebinding it per row so each expert's
    // updates serialise on that expert's stream. Save the handle's current
    // stream and restore it afterwards so other users of this handle (e.g.
    // the forward GEMMs) keep their original stream binding.
    cublasHandle_t handle = h->handles[0];
    cudaStream_t saved_stream;
    checkCudaErrors(cublasGetStream(handle, &saved_stream));

    scalar_t alpha = 1, beta = 1;
    for (size_t i = 0; i < batch_size; ++i) {
        const int e = gate_host[i];
        checkCudaErrors(cublasSetStream(handle, h->streams[e]));
        // Row-major [out_feat x in_feat] block == column-major C of size
        // in_feat x out_feat (ld in_feat); C += x_i (in_feat x 1) * g_i^T (1 x out_feat).
        checkCudaErrors(cublasXgemm(handle,
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            in_feat,
            out_feat,
            1,
            &alpha,
            input + i * in_feat,
            in_feat,
            grad_output + i * out_feat,
            out_feat,
            &beta,
            grad_weight + e * out_feat * in_feat,
            in_feat));
    }
    checkCudaErrors(cublasSetStream(handle, saved_stream));

    for (size_t e = 0; e < num_expert; ++e) {
        checkCudaErrors(cudaStreamSynchronize(h->streams[e]));
    }
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
158

Jiezhong Qiu's avatar
Jiezhong Qiu committed
159
// Python-facing forward pass of the MoE linear layer.
//   input  : [batch_size x in_feat], CUDA
//   gate   : [batch_size] int32 expert index per row, CUDA
//   weight : [num_expert x out_feat x in_feat], CUDA
// Returns {output} with output of shape [batch_size x out_feat].
// The kernels read raw data pointers as dense row-major arrays, so every
// tensor is validated and made contiguous first.
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight) {
    TORCH_CHECK(input.is_cuda() && gate.is_cuda() && weight.is_cuda(),
            "moe_cuda_forward: input, gate and weight must be CUDA tensors");
    TORCH_CHECK(input.dim() == 2, "moe_cuda_forward: input must be 2-D");
    TORCH_CHECK(weight.dim() == 3, "moe_cuda_forward: weight must be 3-D");
    TORCH_CHECK(input.size(1) == weight.size(2),
            "moe_cuda_forward: input.size(1) must equal weight.size(2)");
    TORCH_CHECK(gate.numel() == input.size(0),
            "moe_cuda_forward: gate must have one entry per input row");

    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
                    out_feat,
                    num_expert,
                    CUBLAS_OP_T
                );
    }));

    return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
190
// Python-facing backward pass of the MoE linear layer.
//   grad_output : [batch_size x out_feat], CUDA
//   input       : [batch_size x in_feat], CUDA
//   gate        : [batch_size] int32 expert index per row, CUDA
//   weight      : [num_expert x out_feat x in_feat], CUDA
// Returns {grad_input [batch_size x in_feat],
//          grad_weight [num_expert x out_feat x in_feat]}.
// Autograd may hand in a non-contiguous grad_output (e.g. an expanded
// gradient); the kernels read raw pointers as dense row-major arrays, so all
// tensors are made contiguous first.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    TORCH_CHECK(grad_output.is_cuda() && input.is_cuda() && gate.is_cuda() && weight.is_cuda(),
            "moe_cuda_backward: all tensors must be CUDA tensors");
    TORCH_CHECK(input.dim() == 2 && grad_output.dim() == 2 && weight.dim() == 3,
            "moe_cuda_backward: expected 2-D input/grad_output and 3-D weight");
    TORCH_CHECK(input.size(1) == weight.size(2),
            "moe_cuda_backward: input.size(1) must equal weight.size(2)");
    TORCH_CHECK(grad_output.size(0) == input.size(0) && grad_output.size(1) == weight.size(1),
            "moe_cuda_backward: grad_output must be [batch_size x weight.size(1)]");
    TORCH_CHECK(gate.numel() == input.size(0),
            "moe_cuda_backward: gate must have one entry per input row");

    grad_output = grad_output.contiguous();
    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    // Zero-initialised: moe_cuda_grad_weight accumulates into it (beta = 1).
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        // grad_input reuses the forward scatter/GEMM/gather path with the
        // feature sizes swapped and CUBLAS_OP_N.
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N
        );
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
238
239

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
240
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
241
242
243
244
245
246
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
247
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
248
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
249
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
250

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
251
252
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
253
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
254
255
256
257
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
258

Jiezhong Qiu's avatar
Jiezhong Qiu committed
259
260
261
262
263
264
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
265
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
266
267
268
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
269
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
270
271
272
273
274
275
276
277
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
278
}
Rick Ho's avatar
Rick Ho committed
279
*/