// moe.cpp
#include <torch/extension.h>
#include <torch/torch.h>

#include <cstdio>
#include <iostream>
#include <vector>

// CUDA runtime
#include <cuda_runtime.h>
#include <cublas_v2.h>

// CUDA and CUBLAS functions
//#include <helper_functions.h>
#include <helper_cuda.h>  // checkCudaErrors (from the CUDA samples)


const int num_stream = 512;  // CUDA streams used to overlap the per-sample GEMMs below


// Type-dispatching overloads so the templated code below can call a single
// name for float, double and half batched GEMMs.
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb,
                                  int m, int n, int k,
                                  const float *alpha,
                                  const float *const Aarray[], int lda,
                                  const float *const Barray[], int ldb,
                                  const float *beta,
                                  float *const Carray[], int ldc,
                                  int batchCount)
{
    return cublasSgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb,
                                  int m, int n, int k,
                                  const double *alpha,
                                  const double *const Aarray[], int lda,
                                  const double *const Barray[], int ldb,
                                  const double *beta,
                                  double *const Carray[], int ldc,
                                  int batchCount)
{
    return cublasDgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb,
                                  int m, int n, int k,
                                  const __half *alpha,
                                  const __half *const Aarray[], int lda,
                                  const __half *const Barray[], int ldb,
                                  const __half *beta,
                                  __half *const Carray[], int ldc,
                                  int batchCount)
{
    return cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int64_t* gate,   // host pointer: one expert index per (sample, k)
        const scalar_t* weight,
        scalar_t* output,
        size_t batch_size,
        size_t top_k,
        size_t in_feat,
        size_t out_feat) {

    cublasHandle_t handle;
    checkCudaErrors(cublasCreate(&handle));

    // Set up Aarray, Barray and Carray: one (x_i, W_gate(i,k), y_ik) pointer
    // triple per (sample, selected expert) pair, stored back-to-back in a
    // single device buffer.
    std::vector<const scalar_t*> aptrs, bptrs;
    std::vector<scalar_t*> cptrs;
    scalar_t **ptrs;
    checkCudaErrors(cudaMalloc(&ptrs, batch_size * sizeof(scalar_t*) * top_k * 3));
    for (size_t i=0; i<batch_size; ++i) {
        for (size_t k=0; k<top_k; ++k) {
            aptrs.push_back(input + in_feat * i);
            bptrs.push_back(weight + out_feat * in_feat * gate[i * top_k + k]);
            cptrs.push_back(output + out_feat * (i * top_k + k));
        }
    }
    checkCudaErrors(cudaMemcpy(ptrs, aptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k, bptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k * 2, cptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));

    // Each GEMM computes the 1 x out_feat row y = x * W^T; there is one GEMM
    // per (sample, expert) pair, so batchCount is batch_size * top_k.
    scalar_t alpha = 1, beta = 0;
    checkCudaErrors(cublasXgemmBatched(handle,
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            1, out_feat, in_feat,
            &alpha,
            ptrs, 1,
            ptrs + batch_size * top_k, out_feat,
            &beta,
            ptrs + batch_size * top_k * 2, 1,
            batch_size * top_k));
    checkCudaErrors(cudaDeviceSynchronize());

    checkCudaErrors(cudaFree(ptrs));
    checkCudaErrors(cublasDestroy(handle));
}
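
// Host-side reference sketch (our addition, for sanity-checking the batched
// GEMM above): computes output[i, k, :] = W[gate[i * top_k + k]] * input[i, :]
// with plain loops. All pointers here are assumed to be host (CPU) buffers.
template <typename scalar_t>
void moe_forward_reference(
        const scalar_t* input,   // [batch_size x in_feat]
        const int64_t* gate,     // [batch_size x top_k]
        const scalar_t* weight,  // [num_expert x out_feat x in_feat]
        scalar_t* output,        // [batch_size x top_k x out_feat]
        size_t batch_size, size_t top_k, size_t in_feat, size_t out_feat) {
    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t k = 0; k < top_k; ++k) {
            // Row-major [out_feat x in_feat] weight block of the chosen expert.
            const scalar_t* w = weight + out_feat * in_feat * gate[i * top_k + k];
            for (size_t o = 0; o < out_feat; ++o) {
                scalar_t acc = 0;
                for (size_t d = 0; d < in_feat; ++d) {
                    acc += input[i * in_feat + d] * w[o * in_feat + d];
                }
                output[(i * top_k + k) * out_feat + o] = acc;
            }
        }
    }
}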


void moe_cuda_forward(
        torch::Tensor input,  // [B x D_model]
        torch::Tensor gate,   // [B x K]
        torch::Tensor weight  // [N x D_ffn x D_model]
        ) {
    /*
        The bias term should already have been merged into the weight, using the fact that
            Wx + b = [W b] [x]
                           [1]
        (see the merge_bias sketch after this function).
    */
    const auto batch_size = input.size(0);
    const auto top_k = gate.size(1);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);
    
    printf("b=%d, expert=%d, in_feat (d_model)=%d, out_feat (d_ffn)=%d, topk=%d\n", batch_size, num_expert, d_model, d_ffn, top_k);
    auto output = input.new_zeros({batch_size, top_k, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_forward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            gate.data_ptr<int64_t>(),  // expert indices, read on the host (CPU tensor)
            weight.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            top_k,
            in_feat,
            out_feat
        );
    }));

}
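
// A minimal sketch of the bias-merging trick referenced in the comment above.
// This helper is our addition for illustration (the name and its use are not
// part of the original code): it folds the bias b into an augmented weight
// [W b] and appends a constant 1 feature to the input, so Wx + b becomes a
// single matrix product.
inline std::pair<torch::Tensor, torch::Tensor> merge_bias(
        torch::Tensor input,   // [B x D_model]
        torch::Tensor weight,  // [N x D_ffn x D_model]
        torch::Tensor bias) {  // [N x D_ffn]
    // [W b]: append the bias as one extra input column per expert.
    auto weight_aug = torch::cat({weight, bias.unsqueeze(2)}, 2);  // [N x D_ffn x (D_model+1)]
    // [x; 1]: append a constant 1 to every input row.
    auto ones = torch::ones({input.size(0), 1}, input.options());
    auto input_aug = torch::cat({input, ones}, 1);                 // [B x (D_model+1)]
    return {weight_aug, input_aug};
}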


// std::vector<torch::Tensor> 
void moe_cuda_forward_v1(
        torch::Tensor input, // [B x D_model]
        torch::Tensor gate,  // [B x N]
        torch::Tensor weight, // [N x D_model x D_ffn]
        torch::Tensor bias // [N x D_ffn]
        ) {
    const auto batch_size = input.size(0);
    const auto num_expert = gate.size(1);
    const auto d_model = weight.size(1);
    const auto d_ffn = weight.size(2);
    printf("b=%d, expert=%d, d_model=%d, d_ffn=%d\n", batch_size, num_expert, d_model, d_ffn);
    auto output = input.new_zeros({batch_size, num_expert, d_ffn});
    

    cublasHandle_t handle;
    checkCudaErrors(cublasCreate(&handle));
    
    cudaStream_t stream[num_stream];
    for (size_t i=0; i<num_stream; ++i) {
        checkCudaErrors(cudaStreamCreate(&stream[i]));
    }

    cudaEvent_t start, stop;
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));
    // Record the start event
    checkCudaErrors(cudaEventRecord(start, NULL));
    
    size_t s;
    for (size_t i=0; i<batch_size; ++i) {
        for (size_t j=0; j<num_expert; ++j) {
            s = (i * num_expert + j) % num_stream;
            // printf("i=%d j=%d goes to stream %d\n", i, j, s);
            checkCudaErrors(cublasSetStream(handle, stream[s]));
            if (input.scalar_type() == torch::ScalarType::Float) {
                float alpha = 1.0;
                float beta = 0.0;
                checkCudaErrors(cublasSgemm(handle, 
                    CUBLAS_OP_N, 
                    CUBLAS_OP_N,
                    1, // m
                    d_ffn, // n
                    d_model, // k
                    &alpha,
                    input[i].data_ptr<float>(),
                    1,
                    weight[gate[i][j].item<int64_t>()].data_ptr<float>(),
                    d_model,
                    &beta,
                    output[i][j].data_ptr<float>(),
                    1));
            } else {
                printf("only support float!!!\n");
            }
        }
    }
    // checkCudaErrors(cudaDeviceSynchronize());
    // Record the stop event
    checkCudaErrors(cudaEventRecord(stop, NULL));

    // Wait for the stop event to complete
    checkCudaErrors(cudaEventSynchronize(stop));

    float msecTotal = 0.0f;
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    // Compute and print the performance
    float msecPerMatrixMul = msecTotal / batch_size / num_expert;
    double flopsPerMatrixMul = 2.0 * (double)d_model * (double)d_ffn;
    double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
    printf(
        "Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n",
        gigaFlops,
        msecPerMatrixMul,
        flopsPerMatrixMul);

    // std::cout << output << std::endl;
    
    for (size_t i=0; i<num_stream; ++i) {
        checkCudaErrors(cudaStreamDestroy(stream[i]));
    }
    checkCudaErrors(cublasDestroy(handle));
}


// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
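
// The macros above are not used in this file yet; a typical exported entry
// point would begin with, for example:
//   CHECK_INPUT(input);
//   CHECK_INPUT(weight);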


int main() {
    int device = 2;  // hard-coded GPU index; adjust for your machine
    torch::Tensor input = torch::randn({2048, 512}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
    torch::Tensor gate = torch::zeros({2048, 2}, torch::dtype(torch::kInt64));  // kept on the CPU: expert indices are read on the host
    torch::Tensor weight = torch::randn({2, 512, 2048}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
    torch::Tensor bias = torch::randn({2, 2048}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
    checkCudaErrors(cudaSetDevice(device));
    moe_cuda_forward_v1(input, gate, weight, bias);
}
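
// Build sketch (untested; the paths and flags below are assumptions, adjust
// to your libtorch / CUDA install; helper_cuda.h ships with the CUDA samples):
//   g++ -std=c++14 moe.cpp -o moe \
//     -I$TORCH/include -I$TORCH/include/torch/csrc/api/include \
//     -I$CUDA_HOME/include -I$CUDA_SAMPLES/common/inc \
//     -L$TORCH/lib -ltorch -lc10 \
//     -L$CUDA_HOME/lib64 -lcudart -lcublas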