"sgl-kernel/pyproject_cpu.toml" did not exist on "496dde849180d1a9275ab04ad6d739090f20e52a"
moe_cuda_kernel.cu 7.92 KB
Newer Older
Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
1
2
#include <torch/extension.h>
#include <torch/torch.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
Jiezhong Qiu committed
7

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
8
9
10
11
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>                                                                                          
#include <helper_cuda.h> 
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
12

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
13
// #include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
14

Rick Ho's avatar
Rick Ho committed
15
16
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
17

Rick Ho's avatar
Rick Ho committed
18
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19
20


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
21
22
// Fills ptrs[i] = base + stride * offset[i] for i in [0, n).
// Used to build the per-sample B-pointer array for cublasXgemmBatched:
// each sample i selects the weight block of expert offset[i].
// Launch with a 1-D grid of at least n threads; out-of-range threads
// are guarded by the bounds check.
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const int* offset, const scalar_t** ptrs) {
	size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
	if (idx < n) {
		ptrs[idx] = base + stride * offset[idx];
	}
}
// Applies each sample's selected expert weight matrix via one batched GEMM:
// for every i in [0, batch_size):
//     output[i] = input[i] * op(weight[gate[i]])
// where op is controlled by transb (CUBLAS_OP_T in the forward pass,
// CUBLAS_OP_N when reused for the grad_input pass).
//
//   input  : [batch_size x in_feat]               (device)
//   gate   : [batch_size] int expert index        (device)
//   weight : [num_expert x out_feat x in_feat]    (device)
//   output : [batch_size x out_feat]              (device, overwritten; beta=0)
//
// A and C pointer arrays are assembled on the host and copied to the device;
// B pointers are computed on-device from the gate indices.
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* gate,
        const scalar_t* weight,
        scalar_t* output,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        cublasOperation_t transb) {

    auto* h = getCudaStreamManager(num_expert);

    checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));

    // Set up Aarray, Barray and Carray for the batched GEMM.
    std::vector<const scalar_t*> aptrs;
    std::vector<scalar_t*> cptrs;

    const scalar_t **Aarray;
    const scalar_t **Barray;
    scalar_t **Carray;
    checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*)));
    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*)));
    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*)));

    aptrs.reserve(batch_size);
    cptrs.reserve(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        aptrs.push_back(input + in_feat * i);
        cptrs.push_back(output + out_feat * i);
    }

    checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(),
            batch_size * sizeof(const scalar_t*), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(Carray, cptrs.data(),
            batch_size * sizeof(scalar_t*), cudaMemcpyHostToDevice));

    // Barray[i] = weight + (out_feat * in_feat) * gate[i], computed on-device.
    dim3 griddim(CEIL(batch_size, 256));
    dim3 blockdim(256);
    generate_ptr_offset_kernel<<<griddim, blockdim, 0,
        *(h->streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
    // Catch launch-configuration errors (the launch itself returns nothing).
    checkCudaErrors(cudaGetLastError());

    scalar_t alpha = 1, beta = 0;

    // Each member GEMM is (1 x in_feat) * op(in_feat x out_feat) -> (1 x out_feat).
    checkCudaErrors(cublasXgemmBatched(h->handle,
                CUBLAS_OP_N,
                transb,
                1, out_feat, in_feat,
                &alpha,
                Aarray, 1,
                Barray, (transb == CUBLAS_OP_T) ? out_feat : in_feat,
                &beta,
                Carray, 1,
                batch_size));

    checkCudaErrors(cudaStreamSynchronize(*(h->streams)));

    // Fix: the original cudaMalloc'd these pointer arrays on every call and
    // never freed them, leaking device memory per invocation.
    checkCudaErrors(cudaFree(Aarray));
    checkCudaErrors(cudaFree(Barray));
    checkCudaErrors(cudaFree(Carray));
}
// Accumulates the per-expert weight gradient. For each sample i, adds the
// rank-1 outer product grad_output[i]^T * input[i] into the gradient slab of
// expert gate[i] (beta = 1, so contributions from samples routed to the same
// expert accumulate). Each sample's GEMM is issued on that expert's stream so
// updates to the same expert slab are serialized on one stream.
//
//   input       : [batch_size x in_feat]             (device)
//   gate        : [batch_size] int expert index      (device)
//   grad_output : [batch_size x out_feat]            (device)
//   grad_weight : [num_expert x out_feat x in_feat]  (device, accumulated into)
template <typename scalar_t>
void moe_cuda_grad_weight(
        const scalar_t* input,
        const int* gate,
        const scalar_t* grad_output,
        scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert) {

    auto h = getCudaStreamManager(num_expert);

    // Copy gate indices to the host to route each sample to its expert's
    // stream. std::vector replaces the original raw new[]/delete[] pair
    // (RAII: no manual cleanup path to get wrong).
    std::vector<int> gate_host(batch_size);
    scalar_t alpha = 1, beta = 1;  // beta = 1: accumulate into grad_weight
    checkCudaErrors(cudaMemcpy(gate_host.data(), gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
    for (size_t i = 0; i < batch_size; ++i) {
        checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i])));
        // Rank-1 update: (out_feat x 1) * (1 x in_feat), added to the
        // selected expert's gradient slab.
        checkCudaErrors(cublasXgemm(h->handle,
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            out_feat,
            in_feat,
            1,
            &alpha,
            grad_output + i * out_feat,
            out_feat,
            input + i * in_feat,
            in_feat,
            &beta,
            grad_weight + gate_host[i] * out_feat * in_feat,
            out_feat));
    }
    // Wait for every expert stream before returning.
    for (size_t i = 0; i < num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(*(h->streams + i)));
    }
}
// PyTorch entry point for the MoE expert-parallel forward pass.
//   input  : [batch_size x in_feat]   floating tensor on CUDA
//   gate   : [batch_size]             int32 expert index per sample
//   weight : [num_expert x out_feat x in_feat]
// Returns {output} where output is [batch_size x out_feat].
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight) {
    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
    auto output = input.new_zeros({batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
                moe_cuda_forward_impl<scalar_t>(
                    input.data_ptr<scalar_t>(),
                    gate.data_ptr<int>(),
                    weight.data_ptr<scalar_t>(),
                    output.data_ptr<scalar_t>(),
                    batch_size,
                    in_feat,
                    out_feat,
                    num_expert,
                    // weight is stored [out_feat x in_feat] per expert, so the
                    // forward product multiplies by the transpose.
                    CUBLAS_OP_T
                );
    }));

    return {output, };
}
// PyTorch entry point for the MoE backward pass.
// Produces {grad_input, grad_weight}:
//   grad_input  = grad_output * W[gate[i]]      (reuses the forward impl with
//                                                CUBLAS_OP_N and swapped feat dims)
//   grad_weight = per-expert sum of grad_output[i]^T * input[i]
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
    torch::Tensor gate,  // [batch_size]
    torch::Tensor weight // [num_expert x out_feat x in_feat]
) {
    const auto batch_size = input.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

// Fix: guard the debug print like moe_cuda_forward does, instead of
// printing unconditionally on every backward call.
#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat

    // grad_input is easy to compute, exactly the same as forward:
    // multiply grad_output by the (untransposed) expert weight, so the
    // roles of in_feat/out_feat are swapped in the call below.
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            in_feat,
            num_expert,
            CUBLAS_OP_N
        );
    }));

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_grad_weight<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int>(),
            grad_output.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert
        );
    }));

    return {grad_input, grad_weight};
}
/*
Standalone benchmark driver (kept for reference, not compiled):

int main() {
    typedef float data_t;
    size_t batch_size = 4096;
    size_t top_k = 2;
    size_t num_expert = 128;
    size_t in_feat = 1024;
    size_t out_feat = 4096;
    data_t *input, *weight;
    data_t *output;
    size_t *gate;

    checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));

    size_t nt = 16;
    double tsum = 0, tmax = 0;

    size_t *gate_host = new size_t[batch_size * top_k];
    for (size_t i=0; i<batch_size * top_k; ++i) {
        gate_host[i] = rand() % num_expert;
    }
    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));

    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);

    for (size_t i=0; i<nt; ++i) {
        timestamp(start);
        moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
        timestamp(end);
        auto t = getDuration(start, end);
        tsum += t;
        if (t > tmax) tmax = t;
    }
    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
    double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
    printf("%.3lf TFLOPs\n", tflops);
}
*/