moe_cuda_kernel.cu 8.32 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
10
11
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
Jiezhong Qiu committed
12
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
13

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
14
// #include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
15

Rick Ho's avatar
Rick Ho committed
16
17
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
18

Rick Ho's avatar
Rick Ho committed
19
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
20
21


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
22
23
/**
 * Fills a table of pointers: for every i in [0, n),
 *     ptrs[i] = base + stride * offset[i].
 *
 * Each thread handles at most one entry, selected by its flat 1-D index
 * (blockIdx.x * blockDim.x + threadIdx.x); threads whose index is >= n do
 * nothing, so the launch only has to supply at least n threads in total.
 *
 * @param n       number of entries to write
 * @param base    pointer that every generated pointer is offset from
 * @param stride  element count multiplied with each offset
 * @param offset  n integer offsets, read by the kernel
 * @param ptrs    n-entry output table, written by the kernel
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const int* offset, const scalar_t** ptrs) {
    const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= n) {
        return;  // tail threads of the last block have no entry to fill
    }
    ptrs[tid] = base + offset[tid] * stride;
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
31

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
32
/**
 * Applies, to every row of `input`, the weight slab selected by `gate` using
 * a single batched GEMM.
 *
 * For each i in [0, batch_size) the batched call computes
 *     C_i (1 x out_feat) = A_i (1 x in_feat) * op(B_i)
 * with A_i = input + i * in_feat, C_i = output + i * out_feat and
 * B_i = weight + gate[i] * out_feat * in_feat. The leading dimension of B_i
 * is out_feat when transb == CUBLAS_OP_T and in_feat otherwise.
 *
 * All device work is issued on the first stream of the stream manager and
 * the function waits for that stream before returning.
 *
 * NOTE(review): with transb == CUBLAS_OP_T and ldb == out_feat, cuBLAS reads
 * each slab as a column-major out_feat x in_feat matrix. Confirm this is the
 * memory layout the caller intends for its weight tensor.
 *
 * @param input       device pointer, batch_size rows of in_feat elements
 * @param gate        device pointer, batch_size slab indices
 * @param weight      device pointer, slabs of out_feat * in_feat elements
 * @param output      device pointer, batch_size rows of out_feat elements
 * @param batch_size  number of rows; 0 is a no-op
 * @param in_feat     reduction dimension (k) of the GEMM
 * @param out_feat    output dimension (n) of the GEMM
 * @param num_expert  forwarded to getCudaStreamManager
 * @param transb      cuBLAS operation applied to every B_i
 * @param device      forwarded to getCudaStreamManager
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* gate,
        const scalar_t* weight,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        cublasOperation_t transb,
        const int device) {

    // An empty batch has nothing to compute. Returning here also keeps
    // CEIL(0, 256) from underflowing on size_t and producing a bogus grid.
    if (batch_size == 0) {
        return;
    }

    auto* h = getCudaStreamManager(num_expert, device);
    auto stream = *(h->streams);

    checkCudaErrors(cublasSetStream(h->handle, stream));

    // Host-side tables of per-row input (A) and output (C) pointers.
    std::vector<const scalar_t*> aptrs;
    std::vector<scalar_t*> cptrs;
    aptrs.reserve(batch_size);
    cptrs.reserve(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        aptrs.push_back(input + in_feat * i);
        cptrs.push_back(output + out_feat * i);
    }

    // Device-side pointer tables consumed by the batched GEMM.
    const scalar_t **Aarray;
    const scalar_t **Barray;
    scalar_t **Carray;
    checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*)));
    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*)));
    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*)));

    // Upload on the same stream as the GEMM so the copies are ordered before
    // their consumer regardless of how that stream relates to the default one.
    // aptrs/cptrs stay alive until the stream is synchronized below.
    checkCudaErrors(cudaMemcpyAsync(Aarray, aptrs.data(),
                batch_size * sizeof(const scalar_t*),
                cudaMemcpyHostToDevice, stream));
    checkCudaErrors(cudaMemcpyAsync(Carray, cptrs.data(),
                batch_size * sizeof(scalar_t*),
                cudaMemcpyHostToDevice, stream));

    // The B pointers depend on `gate`, which lives on the device, so they are
    // generated by a kernel: Barray[i] = weight + gate[i] * out_feat * in_feat.
    dim3 griddim(CEIL(batch_size, 256));
    dim3 blockdim(256);
    generate_ptr_offset_kernel<<<griddim, blockdim, 0, stream>>>(
            batch_size, weight, out_feat * in_feat, gate, Barray);
    // Launches do not return a status; pick up configuration errors here.
    checkCudaErrors(cudaGetLastError());

    scalar_t alpha = 1, beta = 0;

    checkCudaErrors(cublasXgemmBatched(h->handle,
                CUBLAS_OP_N,
                transb,
                1, out_feat, in_feat,
                &alpha,
                Aarray, 1,
                Barray, (transb == CUBLAS_OP_T) ? out_feat : in_feat,
                &beta,
                Carray, 1,
                batch_size));

    // Wait for the GEMM before releasing the pointer tables it reads.
    checkCudaErrors(cudaStreamSynchronize(stream));
    checkCudaErrors(cudaFree(Aarray));
    checkCudaErrors(cudaFree(Barray));
    checkCudaErrors(cudaFree(Carray));
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
94
95
96
97
98
99
100
101
102
/**
 * Accumulates the per-sample weight gradient into the slab selected by `gate`.
 *
 * For each i in [0, batch_size) one GEMM with k == 1 adds the outer product
 *     grad_output_i (out_feat x 1) * input_i^T (1 x in_feat)
 * to the slab at grad_weight + gate[i] * out_feat * in_feat, using leading
 * dimension out_feat for the result. Because beta == 1, the result is added
 * to whatever grad_weight already holds, so the caller must pass it zeroed
 * (or holding a value it wants to accumulate onto).
 *
 * The gate indices are copied to the host; each GEMM is issued on the stream
 * at index gate[i] of the stream manager, and the first num_expert streams
 * are synchronized before returning.
 *
 * @param input        device pointer, batch_size rows of in_feat elements
 * @param gate         device pointer, batch_size indices in [0, num_expert)
 * @param grad_output  device pointer, batch_size rows of out_feat elements
 * @param grad_weight  device pointer, num_expert slabs of out_feat * in_feat
 * @param batch_size   number of rows; 0 is a no-op
 * @param in_feat      elements per input row
 * @param out_feat     elements per grad_output row
 * @param num_expert   number of slabs / streams
 * @param device       forwarded to getCudaStreamManager
 */
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        const int device) {

    if (batch_size == 0) {
        return;  // no samples, nothing to accumulate
    }

    auto h = getCudaStreamManager(num_expert, device);

    // Bring the gate indices to the host; they choose the stream and the
    // destination slab of every GEMM below.
    std::vector<int> gate_host(batch_size);
    checkCudaErrors(cudaMemcpy(gate_host.data(), gate,
                batch_size * sizeof(int), cudaMemcpyDeviceToHost));

    // Validate everything before queueing any work: an out-of-range index
    // would read past h->streams and write past grad_weight.
    for (size_t i = 0; i < batch_size; ++i) {
        const int expert = gate_host[i];
        TORCH_CHECK(expert >= 0 && static_cast<size_t>(expert) < num_expert,
                "moe_cuda_grad_weight: gate[", i, "] = ", expert,
                " is outside [0, ", num_expert, ")");
    }

    scalar_t alpha = 1, beta = 1;
    const size_t slab_size = out_feat * in_feat;

    for (size_t i = 0; i < batch_size; ++i) {
        const size_t expert = static_cast<size_t>(gate_host[i]);
        // Samples of the same expert share a stream, so their accumulations
        // into the same slab are serialized.
        checkCudaErrors(cublasSetStream(h->handle, *(h->streams + expert)));
        checkCudaErrors(cublasXgemm(h->handle,
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            out_feat,
            in_feat,
            1,
            &alpha,
            grad_output + i * out_feat,
            out_feat,
            input + i * in_feat,
            in_feat,
            &beta,
            grad_weight + expert * slab_size,
            out_feat));
    }

    for (size_t i = 0; i < num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
    }
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
133

Jiezhong Qiu's avatar
Jiezhong Qiu committed
134
/**
 * Forward pass entry point: validates the tensors and dispatches
 * moe_cuda_forward_impl on the input's floating-point type.
 *
 * @param input   [batch_size x in_feat] CUDA tensor
 * @param gate    int32 CUDA tensor with one expert index per input row
 * @param weight  [num_expert x out_feat x in_feat] CUDA tensor
 * @return        single-element vector holding the [batch_size x out_feat]
 *                output tensor
 *
 * Raises (via TORCH_CHECK) when a tensor is not on CUDA, the tensors are on
 * different devices, or the shapes/dtypes do not match the layout above.
 */
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight) {
    TORCH_CHECK(input.is_cuda() && gate.is_cuda() && weight.is_cuda(),
            "moe_cuda_forward: input, gate and weight must be CUDA tensors");
    TORCH_CHECK(gate.device() == input.device() && weight.device() == input.device(),
            "moe_cuda_forward: input, gate and weight must be on the same device");
    TORCH_CHECK(input.dim() == 2, "moe_cuda_forward: input must be 2-D [batch_size x in_feat]");
    TORCH_CHECK(weight.dim() == 3,
            "moe_cuda_forward: weight must be 3-D [num_expert x out_feat x in_feat]");
    TORCH_CHECK(gate.scalar_type() == torch::kInt32, "moe_cuda_forward: gate must be int32");

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

    TORCH_CHECK(input.size(1) == in_feat,
            "moe_cuda_forward: input has ", input.size(1),
            " features but weight expects ", in_feat);
    TORCH_CHECK(gate.numel() == batch_size,
            "moe_cuda_forward: gate has ", gate.numel(),
            " entries but batch_size is ", batch_size);

    // The implementation walks raw pointers with dense strides.
    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    // Make the input's GPU current so the raw cudaMalloc and kernel launches
    // in the implementation target the device that holds the data.
    const c10::cuda::CUDAGuard device_guard(input.device());
    int device = device_of(input).value().index();

    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            CUBLAS_OP_T,
            device
        );
    }));

    return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
167
/**
 * Backward pass entry point: validates the tensors, then computes the
 * gradient with respect to the input (by reusing moe_cuda_forward_impl with
 * the feature dimensions swapped and CUBLAS_OP_N) and the gradient with
 * respect to the weight (moe_cuda_grad_weight).
 *
 * @return {grad_input [batch_size x in_feat],
 *          grad_weight [num_expert x out_feat x in_feat]}
 *
 * Raises (via TORCH_CHECK) when a tensor is not on CUDA, the tensors are on
 * different devices, or the shapes/dtypes do not match the annotated layout.
 */
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    TORCH_CHECK(grad_output.is_cuda() && input.is_cuda() && gate.is_cuda() && weight.is_cuda(),
            "moe_cuda_backward: grad_output, input, gate and weight must be CUDA tensors");
    TORCH_CHECK(grad_output.device() == input.device()
            && gate.device() == input.device()
            && weight.device() == input.device(),
            "moe_cuda_backward: all tensors must be on the same device");
    TORCH_CHECK(input.dim() == 2, "moe_cuda_backward: input must be 2-D [batch_size x in_feat]");
    TORCH_CHECK(grad_output.dim() == 2,
            "moe_cuda_backward: grad_output must be 2-D [batch_size x out_feat]");
    TORCH_CHECK(weight.dim() == 3,
            "moe_cuda_backward: weight must be 3-D [num_expert x out_feat x in_feat]");
    TORCH_CHECK(gate.scalar_type() == torch::kInt32, "moe_cuda_backward: gate must be int32");

    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

    TORCH_CHECK(input.size(1) == in_feat,
            "moe_cuda_backward: input has ", input.size(1),
            " features but weight expects ", in_feat);
    TORCH_CHECK(grad_output.size(0) == batch_size && grad_output.size(1) == out_feat,
            "moe_cuda_backward: grad_output must be [", batch_size, " x ", out_feat, "]");
    TORCH_CHECK(gate.numel() == batch_size,
            "moe_cuda_backward: gate has ", gate.numel(),
            " entries but batch_size is ", batch_size);

    // The implementations walk raw pointers with dense strides.
    grad_output = grad_output.contiguous();
    input = input.contiguous();
    gate = gate.contiguous();
    weight = weight.contiguous();

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    // Make the input's GPU current so the raw cudaMalloc and kernel launches
    // in the implementations target the device that holds the data.
    const c10::cuda::CUDAGuard device_guard(input.device());
    int device = device_of(input).value().index();

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    // Must start at zero: moe_cuda_grad_weight accumulates into it.
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        // grad_input: same batched GEMM as the forward pass, but with the
        // in/out feature sizes swapped and the weight not transposed.
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N,
            device
        );

        // grad_weight: per-sample contributions accumulated per expert.
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            device
        );
    }));

    return {grad_input, grad_weight};
}

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
218
219

/*
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
220
int main() {
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
221
222
223
224
225
226
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
227
	data_t *input, *weight;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
228
	data_t *output;
Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
229
	size_t *gate;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
230

Jiezhong Qiu's avatar
updatre  
Jiezhong Qiu committed
231
232
	checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
	checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));	
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
233
	checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
Jiezhong Qiu's avatar
Jiezhong Qiu committed
234
235
236
237
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    
    size_t nt = 16;
    double tsum = 0, tmax = 0;
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
238

Jiezhong Qiu's avatar
Jiezhong Qiu committed
239
240
241
242
243
244
    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    } 
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
245
    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
246
247
248
    
    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
249
		moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
Jiezhong Qiu's avatar
Jiezhong Qiu committed
250
251
252
253
254
255
256
257
		timestamp(end);
		auto t = getDuration(start, end);
		tsum += t;
		if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
	double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
	printf("%.3lf TFLOPs\n", tflops);
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
258
}
Rick Ho's avatar
Rick Ho committed
259
*/