"vscode:/vscode.git/clone" did not exist on "92c20fdae643d70a33d1a7e3b34f6a1338cf5e44"
moe_cuda_kernel.cu 8.2 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
10
11
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
Jiezhong Qiu committed
12
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
13

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
14
// #include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
15

Rick Ho's avatar
Rick Ho committed
16
17
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
18

Rick Ho's avatar
Rick Ho committed
19
// Integer ceiling division: the number of size-_y_ chunks needed to cover _x_
// items. Yields 0 when _x_ is 0 (the bare ((_x_)-1)/(_y_)+1 form wraps around
// for an unsigned zero). _x_ is evaluated twice, so pass only side-effect-free
// expressions. Requires _y_ > 0.
#define CEIL(_x_,_y_) (((_x_) == 0) ? 0 : (((_x_)-1)/(_y_)+1))
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
20
21


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
22
23
// Fills ptrs[i] = base + stride * offset[i] for every i in [0, n).
//
// Launch layout: 1-D grid of 1-D blocks; only the x dimension is used. The
// kernel walks the index range with a grid-stride loop, so any grid size
// covers all n entries. No shared memory is used.
//
//   n       number of pointers to generate
//   base    start address the generated pointers are relative to
//   stride  distance, in elements, between consecutive offset slots
//   offset  n slot indices, read from device code
//   ptrs    output array of n pointers, written from device code
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const int* offset, const scalar_t** ptrs) {
    // Widen before multiplying: blockDim.x * blockIdx.x is a 32-bit product.
    const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t step = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t idx = first; idx < n; idx += step) {
        ptrs[idx] = base + stride * offset[idx];
    }
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
31

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
32
// Multiplies every input row by the weight matrix of the expert selected for
// that row, using a single batched cuBLAS GEMM.
//
// For each i in [0, batch_size) a (1 x in_feat) * (in_feat x out_feat) product
// is issued with
//   A_i = input  + i * in_feat
//   B_i = weight + gate[i] * out_feat * in_feat   (computed on the device)
//   C_i = output + i * out_feat
// and alpha = 1, beta = 0, so output rows are overwritten.
//
// Parameters:
//   input       device pointer to batch_size rows of in_feat elements
//   gate        device pointer to batch_size expert indices; they are used
//               unchecked as offsets into weight
//   weight      device pointer to the expert matrices, out_feat * in_feat
//               elements apart
//   output      device pointer to batch_size rows of out_feat elements
//   num_expert  forwarded to getCudaStreamManager
//   transb      cuBLAS operation applied to every B_i; the leading dimension
//               of B_i is out_feat for CUBLAS_OP_T and in_feat otherwise
//   device      forwarded to getCudaStreamManager
//
// All work is queued on the manager's first stream and the function waits for
// that stream before returning.
//
// NOTE(review): with CUBLAS_OP_T each expert's weights are addressed as a
// column-major out_feat x in_feat matrix — confirm this matches the layout
// the caller stores.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* gate,
        const scalar_t* weight,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        cublasOperation_t transb,
        const int device) {

    auto* h = getCudaStreamManager(num_expert, device);
    auto stream = *(h->streams);

    checkCudaErrors(cublasSetStream(h->handle, stream));

    // An empty batch has nothing to compute; returning here also avoids a
    // kernel launch with a zero-sized grid.
    if (batch_size == 0) {
        return;
    }

    // Host-side tables of per-row A and C pointers.
    std::vector<const scalar_t*> aptrs;
    std::vector<scalar_t*> cptrs;
    aptrs.reserve(batch_size);
    cptrs.reserve(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        aptrs.push_back(input + in_feat * i);
        cptrs.push_back(output + out_feat * i);
    }

    // Device-side pointer tables consumed by the batched GEMM.
    const scalar_t **Aarray = nullptr;
    const scalar_t **Barray = nullptr;
    scalar_t **Carray = nullptr;
    checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*)));
    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*)));
    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*)));

    checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(),
                batch_size * sizeof(const scalar_t*), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(Carray, cptrs.data(),
                batch_size * sizeof(scalar_t*), cudaMemcpyHostToDevice));

    // B pointers depend on gate, which lives on the device, so they are
    // generated there instead of copying gate back to the host.
    const unsigned int threads_per_block = 256;
    const dim3 blockdim(threads_per_block);
    const dim3 griddim(CEIL(batch_size, threads_per_block));
    generate_ptr_offset_kernel<<<griddim, blockdim, 0, stream>>>(
            batch_size, weight, out_feat * in_feat, gate, Barray);
    // Kernel launches do not return a status; pick up launch errors explicitly.
    checkCudaErrors(cudaGetLastError());

    scalar_t alpha = 1, beta = 0;

    checkCudaErrors(cublasXgemmBatched(h->handle,
                CUBLAS_OP_N,
                transb,
                1, out_feat, in_feat,
                &alpha,
                Aarray, 1,
                Barray, (transb == CUBLAS_OP_T) ? out_feat : in_feat,
                &beta,
                Carray, 1,
                batch_size));

    // The GEMM reads the pointer tables asynchronously, so wait for the stream
    // before releasing them.
    checkCudaErrors(cudaStreamSynchronize(stream));

    checkCudaErrors(cudaFree(Aarray));
    checkCudaErrors(cudaFree(Barray));
    checkCudaErrors(cudaFree(Carray));
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
91
92
93
94
95
96
97
98
99
// Accumulates the weight gradient of every expert from per-sample outer
// products.
//
// For each sample i a GEMM with inner dimension 1 adds
//   grad_output[i] (out_feat values)  x  input[i] (in_feat values)
// into the out_feat * in_feat block of grad_weight that belongs to expert
// gate[i]. beta is 1, so grad_weight is accumulated into, not overwritten:
// the caller must pass it in already initialised.
//
// Parameters:
//   input        device pointer to batch_size rows of in_feat elements
//   gate         device pointer to batch_size expert indices; copied to the
//                host here. Values are used unchecked as indices into the
//                stream array and into grad_weight, so the caller must
//                guarantee 0 <= gate[i] < num_expert.
//   grad_output  device pointer to batch_size rows of out_feat elements
//   grad_weight  device pointer to num_expert blocks of out_feat * in_feat
//   num_expert   number of streams synchronised at the end; also forwarded
//                to getCudaStreamManager
//   device       forwarded to getCudaStreamManager
//
// The GEMM for sample i is issued on stream gate[i] of the stream manager, and
// all num_expert streams are synchronised before returning.
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        const int device) {

    auto* h = getCudaStreamManager(num_expert, device);

    // The host needs the routing decisions to pick a stream and a weight
    // block per sample.
    std::vector<int> gate_host(batch_size);
    if (batch_size > 0) {
        checkCudaErrors(cudaMemcpy(gate_host.data(), gate,
                    batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    }

    scalar_t alpha = 1, beta = 1;
    const size_t expert_stride = out_feat * in_feat;

    for (size_t i = 0; i < batch_size; ++i) {
        const int expert = gate_host[i];
        // Updates to the same expert share a stream, which keeps their
        // accumulations into the same block ordered.
        checkCudaErrors(cublasSetStream(h->handle, *(h->streams + expert)));
        checkCudaErrors(cublasXgemm(h->handle,
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            out_feat,
            in_feat,
            1,
            &alpha,
            grad_output + i * out_feat,
            out_feat,
            input + i * in_feat,
            in_feat,
            &beta,
            grad_weight + expert * expert_stride,
            out_feat));
    }

    for (size_t i = 0; i < num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
    }
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
130

Jiezhong Qiu's avatar
Jiezhong Qiu committed
131
// Forward pass of the mixture-of-experts linear layer.
//
//   input   [batch_size x in_feat]
//   gate    [batch_size] expert index per row; read as 32-bit int
//   weight  [num_expert x out_feat x in_feat]
//
// Returns a one-element vector holding the [batch_size x out_feat] output,
// allocated with the same options as input.
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight) {
    // The implementation indexes the raw buffers as densely packed rows, so
    // strided or transposed views have to be materialised first.
    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    // Sizes are 64-bit integers; cast so the arguments match %ld everywhere.
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
            static_cast<long>(batch_size), static_cast<long>(num_expert),
            static_cast<long>(in_feat), static_cast<long>(out_feat));
#endif
    int device = device_of(input).value().index();
    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            CUBLAS_OP_T,
            device
        );
    }));

    return {output};
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
164
// Backward pass of the mixture-of-experts linear layer.
//
// Returns {grad_input, grad_weight}:
//   grad_input   [batch_size x in_feat], computed by the forward routine with
//                the feature sizes swapped and the weights not transposed
//   grad_weight  [num_expert x out_feat x in_feat], accumulated per sample on
//                top of a zero-initialised tensor
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    // The implementations index the raw buffers as densely packed rows, so
    // strided or transposed views (common for incoming gradients) have to be
    // materialised first.
    grad_output = grad_output.contiguous();
    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    // Sizes are 64-bit integers; cast so the arguments match %ld everywhere.
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
            static_cast<long>(batch_size), static_cast<long>(num_expert),
            static_cast<long>(in_feat), static_cast<long>(out_feat));
#endif
    int device = device_of(input).value().index();
    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    // Must start at zero: the weight-gradient routine accumulates into it.
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        // grad_input is easy to compute, exactly the same as forward: feed
        // grad_output through the experts with in/out sizes exchanged.
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N,
            device
        );

        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            device
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
215
216

/*
// Disabled standalone benchmark, kept for reference.
// NOTE(review): it calls moe_first_linear_cuda_forward and the timestamp /
// getDuration helpers, none of which are defined in the visible part of this
// file — presumably stale; confirm before re-enabling.
int main() {
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
    data_t *input, *weight;
    data_t *output;
    size_t *gate;

    checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));

    size_t nt = 16;
    double tsum = 0, tmax = 0;

    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    }
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);

    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
        moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
        timestamp(end);
        auto t = getDuration(start, end);
        tsum += t;
        if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
    double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
    printf("%.3lf TFLOPs\n", tflops);
}
*/