#include <torch/extension.h>
#include <torch/torch.h>

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>

// #include "timer.hh"

#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"

// Ceiling division: number of blocks needed to cover _x_ items with _y_ per block.
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

// Resolves one device pointer per batch row:
//   ptrs[idx] = base + stride * offset[idx]
// Used to build the B-pointer array for cublasXgemmBatched, where `base` is
// the weight tensor, `stride` is one expert's matrix size, and `offset` holds
// the expert index chosen by the gate for each row.
//
// Launch: 1-D grid/block. The grid-stride loop makes the kernel correct for
// any launch configuration, not just grids that cover all n threads.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const int* offset, const scalar_t** ptrs) {
	size_t step = (size_t)gridDim.x * blockDim.x;
	for (size_t idx = threadIdx.x + (size_t)blockDim.x * blockIdx.x; idx < n;
			idx += step) {
		ptrs[idx] = base + stride * offset[idx];
	}
}
// Applies each row's gated expert matrix with one batched GEMM:
//   output[i] = input[i] * op(weight[gate[i]])
// where op is controlled by `transb`:
//   CUBLAS_OP_T — forward pass: W stored [out_feat x in_feat], row * W^T
//   CUBLAS_OP_N — used by the backward pass (caller swaps in/out feature dims)
//
// All pointers are device pointers; `gate` holds one expert id per row.
// Blocks on *(h->streams) until the result is ready.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* gate,
        const scalar_t* weight,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        cublasOperation_t transb) {
    auto* h = getCudaStreamManager(num_expert);

    checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));

    // Per-row pointer arrays for cublasXgemmBatched:
    //   A = input rows, B = expert weight matrices, C = output rows.
    std::vector<const scalar_t*> aptrs;
    std::vector<scalar_t*> cptrs;
    aptrs.reserve(batch_size);
    cptrs.reserve(batch_size);

    const scalar_t **Aarray;
    const scalar_t **Barray;
    scalar_t **Carray;
    checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*)));
    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*)));
    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*)));

    for (size_t i = 0; i < batch_size; ++i) {
        aptrs.push_back(input + in_feat * i);
        cptrs.push_back(output + out_feat * i);
    }

    checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(),
                batch_size * sizeof(const scalar_t*), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(Carray, cptrs.data(),
                batch_size * sizeof(scalar_t*), cudaMemcpyHostToDevice));

    // Fill Barray on-device: Barray[i] = weight + gate[i] * out_feat * in_feat.
    dim3 griddim(CEIL(batch_size, 256));
    dim3 blockdim(256);
    generate_ptr_offset_kernel<<<griddim, blockdim, 0, *(h->streams)>>>(
            batch_size, weight, out_feat * in_feat, gate, Barray);
    checkCudaErrors(cudaGetLastError());  // catch launch-configuration errors

    scalar_t alpha = 1, beta = 0;

    // m = 1: each batch member is a [1 x in_feat] * op(W) product.
    checkCudaErrors(cublasXgemmBatched(h->handle,
                CUBLAS_OP_N,
                transb,
                1, out_feat, in_feat,
                &alpha,
                Aarray, 1,
                Barray, (transb == CUBLAS_OP_T) ? out_feat : in_feat,
                &beta,
                Carray, 1,
                batch_size));

    checkCudaErrors(cudaStreamSynchronize(*(h->streams)));

    // Fix: the original leaked these per-call device allocations.
    checkCudaErrors(cudaFree(Aarray));
    checkCudaErrors(cudaFree(Barray));
    checkCudaErrors(cudaFree(Carray));
}
// Accumulates the weight gradient for each expert as a sum of rank-1 updates:
//   grad_weight[gate[i]] += grad_output[i]^T * input[i]   (outer product)
// Each row is issued on the CUDA stream of its expert so updates to different
// experts can overlap; beta = 1 makes the GEMM accumulate into grad_weight,
// which the caller must therefore pre-zero.
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {
    auto h = getCudaStreamManager(num_expert);

    // Gate indices are needed host-side to pick streams and output offsets.
    // std::vector instead of raw new[]: no leak if checkCudaErrors bails out.
    std::vector<int> gate_host(batch_size);
    scalar_t alpha = 1, beta = 1;
    checkCudaErrors(cudaMemcpy(gate_host.data(), gate,
                batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    for (size_t i = 0; i < batch_size; ++i) {
        checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i])));
        // [out_feat x 1] * [1 x in_feat] accumulated into the expert's slab.
        checkCudaErrors(cublasXgemm(h->handle,
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            out_feat,
            in_feat,
            1,
            &alpha,
            grad_output + i * out_feat,
            out_feat,
            input + i * in_feat,
            in_feat,
            &beta,
            grad_weight + gate_host[i] * out_feat * in_feat,
            out_feat));
    }
    // Drain every expert stream before the caller reads grad_weight.
    for (size_t i = 0; i < num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
    }
}
// Forward MoE linear layer: output[i] = input[i] * weight[gate[i]]^T.
//
// input : [batch_size x in_feat]  floating-point CUDA tensor
// gate  : [batch_size]            int32 CUDA tensor of expert indices
// weight: [num_expert x out_feat x in_feat] CUDA tensor
// Returns {output} with output shaped [batch_size x out_feat].
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight) {
    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

    // Fail fast with a readable error instead of a CUDA fault downstream.
    TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(gate.scalar_type() == torch::kInt32, "gate must be int32");
    TORCH_CHECK(input.size(1) == in_feat,
            "input feature dim must match weight's in_feat");

#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
                    out_feat,
                    num_expert,
                    CUBLAS_OP_T
                );
    }));

    return {output, };
}
// Backward pass for the MoE linear layer.
//
// grad_output: [batch_size x out_feat] CUDA tensor
// input      : [batch_size x in_feat]  CUDA tensor
// gate       : [batch_size]            int32 CUDA tensor
// weight     : [num_expert x out_feat x in_feat] CUDA tensor
// Returns {grad_input, grad_weight}.
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat] (fixed: comment said out_feat)
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

    // Fail fast with a readable error instead of a CUDA fault downstream.
    TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(gate.scalar_type() == torch::kInt32, "gate must be int32");
    TORCH_CHECK(grad_output.size(1) == out_feat,
            "grad_output feature dim must match weight's out_feat");

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    // grad_input[i] = grad_output[i] * weight[gate[i]] — reuses the forward
    // batched-GEMM helper with in/out feature dims swapped and no transpose.
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N
        );
    }));

    // grad_weight: per-row outer products accumulated into each expert's slab
    // (grad_weight was zero-initialized above; the helper accumulates).
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}
/*
// Standalone benchmark driver (kept for reference; calls an older entry point
// `moe_first_linear_cuda_forward` that no longer exists in this file).
int main() {
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
    data_t *input, *weight;
    data_t *output;
    size_t *gate;

    checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));

    size_t nt = 16;
    double tsum = 0, tmax = 0;

    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    }
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);

    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
        moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
        timestamp(end);
        auto t = getDuration(start, end);
        tsum += t;
        if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
    double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
    printf("%.3lf TFLOPs\n", tflops);
}
*/