"vscode:/vscode.git/clone" did not exist on "856b8ec1318c49d0b76650374014dfba4b9f4257"
gpt_gemm_func.cc 33.9 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

lvhan028's avatar
lvhan028 committed
17
#include "src/turbomind/utils/gemm_test/gpt_gemm_func.h"
Li Zhang's avatar
Li Zhang committed
18

lvhan028's avatar
lvhan028 committed
19
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

// Returns true only when m, n and k are all multiples of 8; the cuSPARSELt tuning loop
// skips shapes for which this is false.
bool isSparseGemmAvailable(size_t m, size_t n, size_t k)
{
    const bool any_misaligned = (m % 8 != 0) || (n % 8 != 0) || (k % 8 != 0);
    return !any_misaligned;
}

template<typename T>
void generate_gpt_gemm_config(int   batch_size,
                              int   beam_width,
                              int   max_input_len,
                              int   head_num,
                              int   size_per_head,
                              int   inter_size,
                              int   vocab_size,
                              int   tensor_para_size,
                              void* buffer_in,
                              bool  isAppend)
{
    FT_CHECK(head_num % tensor_para_size == 0);
    void* cublas_workspace;
    void* buffer;
    int   workSpaceSize;
Li Zhang's avatar
Li Zhang committed
42
#if 0
Li Zhang's avatar
Li Zhang committed
43
44
45
46
47
48
49
    bool  workspace_flag = std::is_same<T, half>::value;
#ifdef ENABLE_FP8
    workspace_flag = workspace_flag || std::is_same<T, __nv_fp8_e4m3>::value;
#endif
#if ENABLE_BF16
    workspace_flag = workspace_flag || std::is_same<T, __nv_bfloat16>::value;
#endif
Li Zhang's avatar
Li Zhang committed
50
51
52
#endif
    // algorithms with workspace perform worse than evaluated
    const bool workspace_flag = 0;
Li Zhang's avatar
Li Zhang committed
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
    if (workspace_flag) {
        // cublas_workspace_ should be the start pointer of cudaMalloc()
        // to ensure 16B alignemnet
        cublas_workspace = buffer_in;
        buffer           = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE);
        workSpaceSize    = CUBLAS_WORKSPACE_SIZE;
    }
    else {
        cublas_workspace = nullptr;
        buffer           = buffer_in;
        workSpaceSize    = 0;
    }

    struct cudaDeviceProp prop;
    check_cuda_error(cudaGetDeviceProperties(&prop, 0));
    printf("Device %s\n", prop.name);

    // check config
    FILE* fd;
    int   line_count = 0;
    if (!isAppend) {
        fd = fopen(GEMM_CONFIG, "w+");
    }
    else {
        fd = fopen(GEMM_CONFIG, "a+");
        std::vector<std::string> config;
        char                     line[1024];
        while (fgets(line, 1024, fd) != NULL) {
            config.push_back(std::string(line));
        }
        line_count = config.size();
        // if (config.size() >= (MAX_CONFIG_NUM * GEMM_NUM + 1))  // 6 cublas/cublasLt, first row is not included
        // {
        //     int startIdx = config.size() - ((MAX_CONFIG_NUM - 1) * GEMM_NUM);
        //     fclose(fd);
        //     fd = fopen(GEMM_CONFIG, "w+");
        //     fprintf(fd, "%s", config[0].c_str());
        //     for (uint i = startIdx; i < config.size(); i++) {
        //         fprintf(fd, "%s", config[i].c_str());
        //     }
        //     line_count = config.size() - (GEMM_NUM + 3);
        // }
    }

    const int hidden_units         = head_num * size_per_head;
    const int local_head_num       = head_num / tensor_para_size;
    const int local_hidden_units   = local_head_num * size_per_head;
    const int max_input_len_padded = (max_input_len + 15) / 16 * 16;
    const int gemm_num             = 11;
    int       M[gemm_num];
    int       N[gemm_num];
    int       K[gemm_num];
    int       batchCount[gemm_num];
    int64_t   strideA[gemm_num];
    int64_t   strideB[gemm_num];
    int64_t   strideD[gemm_num];
    char      mess[gemm_num][256];
    float     exec_times[gemm_num];

    // gemm 0
    M[0]          = batch_size * beam_width * max_input_len;
    K[0]          = hidden_units;
    N[0]          = 3 * local_hidden_units;
    batchCount[0] = 1;
    strideA[0]    = 0;
    strideB[0]    = 0;
    strideD[0]    = 0;
    strcpy(mess[0], "context from_tensor * weightQKV");

    // gemm 1
    M[1]          = max_input_len_padded;
    K[1]          = size_per_head;
    N[1]          = max_input_len_padded;
    batchCount[1] = batch_size * beam_width * local_head_num;
    strideA[1]    = max_input_len_padded * size_per_head;
    strideB[1]    = max_input_len_padded * size_per_head;
    strideD[1]    = max_input_len_padded * max_input_len_padded;
    strcpy(mess[1], "context batch gemm Q*K^T");

    // gemm 2
    M[2]          = max_input_len_padded;
    K[2]          = max_input_len_padded;
    N[2]          = size_per_head;
    batchCount[2] = batch_size * beam_width * local_head_num;
    strideA[2]    = max_input_len_padded * size_per_head;
    strideB[2]    = max_input_len_padded * max_input_len_padded;
    strideD[2]    = max_input_len_padded * size_per_head;
    strcpy(mess[2], "context batch gemm QK*V^T");

    // gemm 3
    M[3]          = batch_size * beam_width * max_input_len;
    K[3]          = local_hidden_units;
    N[3]          = hidden_units;
    batchCount[3] = 1;
    strideA[3]    = 0;
    strideB[3]    = 0;
    strideD[3]    = 0;
    strcpy(mess[3], "context attr * output_kernel");

    // gemm 4
    M[4]          = batch_size * beam_width * max_input_len;
    K[4]          = hidden_units;
    N[4]          = inter_size / tensor_para_size;
    batchCount[4] = 1;
    strideA[4]    = 0;
    strideB[4]    = 0;
    strideD[4]    = 0;
    strcpy(mess[4], "context ffn gemm 1");

    // gemm 5
    M[5]          = batch_size * beam_width * max_input_len;
    K[5]          = inter_size / tensor_para_size;
    N[5]          = hidden_units;
    batchCount[5] = 1;
    strideA[5]    = 0;
    strideB[5]    = 0;
    strideD[5]    = 0;
    strcpy(mess[5], "context ffn gemm 2");

    // gemm 6
    M[6]          = batch_size * beam_width;
    K[6]          = hidden_units;
    N[6]          = 3 * local_hidden_units;
    batchCount[6] = 1;
    strideA[6]    = 0;
    strideB[6]    = 0;
    strideD[6]    = 0;
    strcpy(mess[6], "from_tensor * weightQKV");

    // gemm 7
    M[7]          = batch_size * beam_width;
    K[7]          = local_hidden_units;
    N[7]          = hidden_units;
    batchCount[7] = 1;
    strideA[7]    = 0;
    strideB[7]    = 0;
    strideD[7]    = 0;
    strcpy(mess[7], "attr * output_kernel");

    // gemm 8
    M[8]          = batch_size * beam_width;
    K[8]          = hidden_units;
    N[8]          = inter_size / tensor_para_size;
    batchCount[8] = 1;
    strideA[8]    = 0;
    strideB[8]    = 0;
    strideD[8]    = 0;
    strcpy(mess[8], "ffn gemm 1");

    // gemm 9
    M[9]          = batch_size * beam_width;
    K[9]          = inter_size / tensor_para_size;
    N[9]          = hidden_units;
    batchCount[9] = 1;
    strideA[9]    = 0;
    strideB[9]    = 0;
    strideD[9]    = 0;
    strcpy(mess[9], "ffn gemm 2");

    // gemm 10
    M[10]          = batch_size * beam_width;
    K[10]          = hidden_units;
    N[10]          = ceil(vocab_size / 8.) * 8 / tensor_para_size;
    batchCount[10] = 1;
    strideA[10]    = 0;
    strideB[10]    = 0;
    strideD[10]    = 0;
    strcpy(mess[10], "logits gemm");

    cublasHandle_t cublas_handle;
    check_cuda_error(cublasCreate(&cublas_handle));
    cublasLtHandle_t ltHandle;
    check_cuda_error(cublasLtCreate(&ltHandle));

    cudaDataType_t AType;
    cudaDataType_t BType;
    cudaDataType_t CType;
    cudaDataType_t DType;
    cudaDataType_t DType_FP8[gemm_num];
    cudaDataType_t computeType;
    int            startAlgo, endAlgo;
    const int      ites = 100;
    struct timeval start, end;

    CublasDataType data_type;
    if (std::is_same<T, float>::value) {
        data_type   = FLOAT_DATATYPE;
        AType       = CUDA_R_32F;
        BType       = CUDA_R_32F;
        CType       = CUDA_R_32F;
        DType       = CUDA_R_32F;
        computeType = CUDA_R_32F;
        startAlgo   = (int)CUBLAS_GEMM_DEFAULT;
        endAlgo     = (int)CUBLAS_GEMM_ALGO23;
    }
    else if (std::is_same<T, half>::value) {
        data_type   = HALF_DATATYPE;
        AType       = CUDA_R_16F;
        BType       = CUDA_R_16F;
        CType       = CUDA_R_16F;
        DType       = CUDA_R_16F;
        computeType = CUDA_R_32F;
        startAlgo   = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
        endAlgo     = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
    }
#ifdef ENABLE_BF16
    else if (std::is_same<T, __nv_bfloat16>::value) {
        data_type   = BFLOAT16_DATATYPE;
        AType       = CUDA_R_16BF;
        BType       = CUDA_R_16BF;
        CType       = CUDA_R_16BF;
        DType       = CUDA_R_16BF;
        computeType = CUDA_R_32F;
        startAlgo   = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
        endAlgo     = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
    }
#endif
#ifdef ENABLE_FP8
    else if (std::is_same<T, __nv_fp8_e4m3>::value) {
        data_type = FP8_DATATYPE;
        AType     = CUDA_R_8F_E4M3;
        BType     = CUDA_R_8F_E4M3;
        CType     = CUDA_R_16BF;
#ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE
        DType = CUDA_R_16BF
#else
        DType_FP8[0] = CUDA_R_8F_E4M3;
        DType_FP8[1] = CUDA_R_16BF;
        DType_FP8[2] = CUDA_R_8F_E4M3;
        DType_FP8[3] = CUDA_R_16BF;
        DType_FP8[4] = CUDA_R_16BF;
        DType_FP8[5] = CUDA_R_16BF;
#ifdef FP8_MHA
        DType_FP8[6] = CUDA_R_8F_E4M3;
#else
        DType_FP8[6] = CUDA_R_16BF;
#endif
        DType_FP8[7] = CUDA_R_16BF;
        DType_FP8[8] = CUDA_R_16BF;
        DType_FP8[9] = CUDA_R_16BF;
#endif
            computeType = CUDA_R_32F;
        startAlgo       = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
        endAlgo         = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
    }
#endif
    float alpha = (float)1.0f;
    float beta  = (float)0.0f;

    printf("***Encoder Gemm Testing Begin***\n");
    printf("***Cublas Gemm Testing Begin***\n");
    if (line_count == 0) {
        fprintf(fd,
                "batch_size, seq_len, head_num, size_per_head dataType ### batchCount, n, m, k, algoId, "
                "customOption, tile, numSplitsK, swizzle, reductionScheme, workspaceSize, stages, "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
                "inner_shapeId, cluster_shapeId, "
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
                "mma_shapeId, cga_shapeId, schedule_mode, "
#endif
                "exec_time\n");
    }

    for (int i = 0; i < gemm_num; ++i) {
Li Zhang's avatar
Li Zhang committed
317
318
        // tuning of context gemm and logits gemm is not working yet
        if (i <= 5 || i == 10) {
Li Zhang's avatar
Li Zhang committed
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
            continue;
        }
        int seq_len = i <= 5 ? max_input_len : 1;

        int m = M[i], n = N[i], k = K[i];
        printf("\n-----------------------------\n");
        printf("GEMM test %d: [M: %d, K: %d, N: %d] %s\n", i, m, k, n, mess[i]);
        T* d_A = (T*)buffer;
        T* d_B = d_A + m * k * batchCount[i];
        T* d_C = d_B + k * n * batchCount[i];

        float exec_time = 99999.0f;
        int   fast_algo = 0;
        for (int algo = startAlgo; algo <= endAlgo; algo++) {
            cublasStatus_t status;
            cudaDeviceSynchronize();
            gettimeofday(&start, NULL);
            for (int ite = 0; ite < ites; ++ite) {
                if (i == 1) {
                    status = cublasGemmStridedBatchedEx(cublas_handle,
                                                        CUBLAS_OP_T,
                                                        CUBLAS_OP_N,
                                                        max_input_len,
                                                        max_input_len,
                                                        size_per_head,
                                                        &alpha,
                                                        d_B,
                                                        BType,
                                                        size_per_head,
                                                        max_input_len * size_per_head,
                                                        d_A,
                                                        AType,
                                                        size_per_head,
                                                        max_input_len * size_per_head,
                                                        &beta,
                                                        d_C,
                                                        CUDA_R_32F,  // CType,
                                                        max_input_len,
                                                        max_input_len * max_input_len,
                                                        batchCount[i],
                                                        computeType,
                                                        static_cast<cublasGemmAlgo_t>(algo));
                }
                else if (i == 2) {
                    status = cublasGemmStridedBatchedEx(cublas_handle,
                                                        CUBLAS_OP_N,
                                                        CUBLAS_OP_N,
                                                        size_per_head,
                                                        max_input_len,
                                                        max_input_len,
                                                        &alpha,
                                                        d_B,
                                                        BType,
                                                        size_per_head,
                                                        max_input_len * size_per_head,
                                                        d_A,
                                                        AType,
                                                        max_input_len,
                                                        max_input_len * max_input_len,
                                                        &beta,
                                                        d_C,
                                                        CType,
                                                        size_per_head,
                                                        max_input_len * size_per_head,
                                                        batchCount[i],
                                                        computeType,
                                                        static_cast<cublasGemmAlgo_t>(algo));
                }
                else if (i == 10) {
                    status = cublasGemmEx(cublas_handle,
                                          CUBLAS_OP_T,
                                          CUBLAS_OP_N,
                                          n,
                                          m,
                                          k,
                                          &alpha,
                                          d_B,
                                          BType,
                                          k,
                                          d_A,
                                          AType,
                                          k,
                                          &beta,
                                          d_C,
                                          CType,
                                          n,
                                          computeType,
                                          static_cast<cublasGemmAlgo_t>(algo));
                }
                else {
                    status = cublasGemmEx(cublas_handle,
                                          CUBLAS_OP_N,
                                          CUBLAS_OP_N,
                                          n,
                                          m,
                                          k,
                                          &alpha,
                                          d_B,
                                          BType,
                                          n,
                                          d_A,
                                          AType,
                                          k,
                                          &beta,
                                          d_C,
                                          CType,
                                          n,
                                          computeType,
                                          static_cast<cublasGemmAlgo_t>(algo));
                }

                if (status != CUBLAS_STATUS_SUCCESS) {
                    break;
                }
            }
            cudaDeviceSynchronize();
            gettimeofday(&end, NULL);
            if (status == CUBLAS_STATUS_SUCCESS) {
                printf("algo_%d costs %.3fms \n", algo, diffTime(start, end) / ites);
                if (diffTime(start, end) / ites < exec_time) {
                    exec_time = diffTime(start, end) / ites;
                    fast_algo = algo;
                }
            }
            sync_check_cuda_error();
        }

        printf("fast_algo %d costs %.3f ms\n", fast_algo, exec_time);

        // for fp16 and bf16, we compare cublasLt
        // for fp8, compare cublaslt for all gemm kernels
        if ((data_type != FLOAT_DATATYPE && i != 1 && i != 2 && i != 10) || data_type == FP8_DATATYPE) {
            printf("***cublasLt Gemm Testing Beign***\n");
            // Let try a fixed number of combinations
Li Zhang's avatar
Li Zhang committed
453
            int                ALGO_COMBINATIONS = 10000;
Li Zhang's avatar
Li Zhang committed
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
            customMatmulPerf_t perfResults[ALGO_COMBINATIONS];

            // for gpt, computeType & scaleType should be FP32
            LtHgemmCustomFind<T, float>(ltHandle,
                                        batch_size * beam_width,
                                        i == 1 || i == 2 ? max_input_len : 1,
                                        head_num,
                                        size_per_head,
                                        n,
                                        m,
                                        k,
                                        &alpha,
                                        d_B,
                                        d_A,
                                        &beta,
                                        d_C,
                                        cublas_workspace,
                                        workSpaceSize,
                                        fd,
                                        perfResults,
                                        ALGO_COMBINATIONS,
                                        DType_FP8[i],
                                        batchCount[i],
                                        strideA[i],
                                        strideB[i],
                                        strideD[i]);
            if (perfResults[0].time < exec_time) {
                printPerfStructure(batch_size * beam_width,
                                   seq_len,
                                   head_num,
                                   size_per_head,
                                   n,
                                   m,
                                   k,
                                   perfResults[0],
                                   fd,
                                   data_type,
                                   0,
                                   batchCount[i]);
            }
            else {
                fprintf(fd,
                        "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
                        "-1 -1 "
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
                        "-1 -1 -1 "
#endif
                        "%f\n",
                        batch_size * beam_width,
                        seq_len,
                        head_num,
                        size_per_head,
                        data_type,
                        batchCount[i],
                        n,
                        m,
                        k,
                        fast_algo,
                        exec_time);
            }
            printf("***cublasLt Gemm Testing End***\n");
        }
        else {
            fprintf(fd,
                    "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
                    "-1 -1 "
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
                    "-1 -1 -1 "
#endif
                    "%f\n",
                    batch_size * beam_width,
                    seq_len,
                    head_num,
                    size_per_head,
                    data_type,
                    batchCount[i],
                    n,
                    m,
                    k,
                    fast_algo,
                    exec_time);
        }
        sync_check_cuda_error();
        exec_times[i] = exec_time;
    }
    printf("***cublas Gemm Testing End***\n\n");
    fclose(fd);

#ifdef SPARSITY_ENABLED
    bool do_sparse_test = false;
    if (prop.major == 8 && (prop.minor == 0 || prop.minor == 6) && sizeof(T) == sizeof(half)) {
        do_sparse_test = true;
    }
    if (do_sparse_test) {
        printf("***cusparseLt Gemm Testing Begin***\n");
        // Only first 8 cases can be sparse
        // - QKV kernel, Projection, FC1, FC2 in context or decoding.
        const int spgemm_num = 8;
        if (!isAppend) {
            fd = fopen(SPGEMM_CONFIG, "w+");
        }
        else {
            fd = fopen(SPGEMM_CONFIG, "a+");
            std::vector<std::string> config;
            char                     line[1024];
            while (fgets(line, 1024, fd) != NULL) {
                config.push_back(std::string(line));
            }
            line_count = config.size();
            // gemm_num configs (cublas/cublasLt), first row is not included
            if (config.size() >= (MAX_CONFIG_NUM * spgemm_num + 1)) {
                int startIdx = config.size() - ((MAX_CONFIG_NUM - 1) * spgemm_num);
                fclose(fd);
                fd = fopen(SPGEMM_CONFIG, "w+");
                fprintf(fd, "%s", config[0].c_str());
                for (uint i = startIdx; i < config.size(); i++) {
                    fprintf(fd, "%s", config[i].c_str());
                }
                line_count = config.size() - (spgemm_num + 3);
            }
        }
        if (line_count == 0) {
            // header line
            fprintf(fd,
                    "batch_size, seq_len, head_num, size_per_head dataType "
                    "### batchCount, m, n, k, algoId, exec_time\n");
        }

        cusparseLtHandle_t handle;
        CHECK_CUSPARSE(cusparseLtInit(&handle));
        cusparseOrder_t     order = CUSPARSE_ORDER_COL;
        cusparseOperation_t opA   = CUSPARSE_OPERATION_NON_TRANSPOSE;
        cusparseOperation_t opB   = CUSPARSE_OPERATION_NON_TRANSPOSE;
        // let's make this optional
        cusparseComputeType compute_type = CUSPARSE_COMPUTE_16F;
        unsigned            alignment    = 16;
        cudaStream_t        stream       = 0;
        float               alpha2       = 1.0f;
        float               beta2        = 0.0f;
        for (int i = 0; i < gemm_num; ++i) {
            // skip qk or attn or logit gemms.
            if (i == 1 || i == 2 || i == 10) {
                continue;
            }

            // seq_len is always 1 except context gemms.
            int seq_len = i <= 5 ? max_input_len : 1;

            // to be compatible with spgemm wrapper, we let A be the weight matrix
            // so m and n are swapped
            // A: mxk B: kxn C:mxn
            int m = N[i], n = M[i], k = K[i];
            printf("\n-----------------------------\n");
            printf("GEMM test %d: [M: %d, K: %d, N: %d]\n", i, m, k, n);

            if (n % 8 != 0) {
                n = div_up(n, 8) * 8;  // pad n to be multiple of 8 as FT does.
            }

            T* d_A = (T*)buffer;
            T* d_B = d_A + m * k * batchCount[i];
            T* d_C = d_B + k * n * batchCount[i];
            T* dA_compressed;
            {
AllentDan's avatar
AllentDan committed
620
                cusparseLtMatDescriptor_t mat_A;
Li Zhang's avatar
Li Zhang committed
621
                CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
AllentDan's avatar
AllentDan committed
622
                    &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
Li Zhang's avatar
Li Zhang committed
623
                CHECK_CUSPARSE(
AllentDan's avatar
AllentDan committed
624
                    cusparseLtSpMMAPrune2(&handle, &mat_A, true, opA, d_A, d_A, CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
Li Zhang's avatar
Li Zhang committed
625
                size_t compressed_size;
AllentDan's avatar
AllentDan committed
626
                CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&handle, &mat_A, &compressed_size))
Li Zhang's avatar
Li Zhang committed
627
                check_cuda_error(cudaMalloc((void**)&dA_compressed, compressed_size));
AllentDan's avatar
AllentDan committed
628
                CHECK_CUSPARSE(cusparseLtSpMMACompress2(&handle, &mat_A, true, opA, d_A, dA_compressed, stream))
Li Zhang's avatar
Li Zhang committed
629
630
631
632
633
634
635
            }

            float exec_time = 99999.0f;
            int   fast_algo = 0;
            if (isSparseGemmAvailable(m, n, k)) {
                for (int alg = 0; alg < 4; ++alg) {
                    cudaDeviceSynchronize();
AllentDan's avatar
AllentDan committed
636
                    cusparseLtMatDescriptor_t mat_A, mat_B, mat_C;
Li Zhang's avatar
Li Zhang committed
637
638
639
640
                    void*                     d_workspace = nullptr;
                    int                       num_streams = 1;
                    cudaStream_t              streams[1]  = {stream};
                    CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit(
AllentDan's avatar
AllentDan committed
641
642
643
644
                        &handle, &mat_A, m, k, m, alignment, CUDA_R_16F, order, CUSPARSELT_SPARSITY_50_PERCENT))
                    CHECK_CUSPARSE(cusparseLtDenseDescriptorInit(&handle, &mat_B, k, n, k, alignment, CUDA_R_16F, order))
                    CHECK_CUSPARSE(
                        cusparseLtDenseDescriptorInit(&handle, &mat_C, m, n, m, alignment, CUDA_R_16F, order))
Li Zhang's avatar
Li Zhang committed
645
646
647
648
649
650
651
652
653
654
                    cudaDeviceSynchronize();
                    gettimeofday(&start, NULL);
                    for (int ite = 0; ite < ites; ++ite) {
                        // initializing MatDesc takes a lot of time
                        // and these descs can be stored to other place
                        // whereas storing MatMulPlan to other place will cause errors
                        cusparseLtMatmulDescriptor_t   matmul;
                        cusparseLtMatmulAlgSelection_t alg_sel;
                        cusparseLtMatmulPlan_t         plan;
                        CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
AllentDan's avatar
AllentDan committed
655
                            &handle, &matmul, opA, opB, &mat_A, &mat_B, &mat_C, &mat_C, compute_type))
Li Zhang's avatar
Li Zhang committed
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
                        CHECK_CUSPARSE(
                            cusparseLtMatmulAlgSelectionInit(&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
                        CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
                            &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, sizeof(alg)))
                        size_t workspace_size;
                        CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&handle, &alg_sel, &workspace_size))
                        CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel, workspace_size))
                        CHECK_CUSPARSE(cusparseLtMatmul(&handle,
                                                        &plan,
                                                        &alpha2,
                                                        dA_compressed,
                                                        d_B,
                                                        &beta2,
                                                        d_C,
                                                        d_C,
                                                        d_workspace,
                                                        streams,
                                                        num_streams))
                        CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan))
                    }
                    cudaDeviceSynchronize();
                    gettimeofday(&end, NULL);
                    printf("algo_%d costs %.3fms \n", alg, diffTime(start, end) / ites);
                    if (diffTime(start, end) < exec_time) {
                        exec_time = diffTime(start, end);
                        fast_algo = alg;
                    }
                }
            }
            exec_time /= ites;
            if (exec_time >= exec_times[i]) {
                fast_algo = -1;
            }
            printf("fast_algo %d\n", fast_algo);
            fprintf(fd,
                    "%d %d %d %d %d ### %d %d %d %d %d %f\n",
                    batch_size * beam_width,
                    seq_len,
                    head_num,
                    size_per_head,
                    data_type,
                    batchCount[i],
                    m,
                    n,
                    k,
                    fast_algo,
                    exec_time);
            cudaFree(dA_compressed);
        }
        CHECK_CUSPARSE(cusparseLtDestroy(&handle))
        fclose(fd);
        printf("***cusparseLt Gemm Testing End***\n");
    }
#endif

    printf("***GPT Gemm Testing End***\n");
    return;
}

// Explicit instantiations of generate_gpt_gemm_config for each supported element type.
// The template is defined in this .cc file rather than a header, so these instantiations
// are what other translation units link against. float and half are always built;
// __nv_bfloat16 and __nv_fp8_e4m3 are built only when ENABLE_BF16 / ENABLE_FP8 is defined.
template void generate_gpt_gemm_config<float>(int   batch_size,
                                              int   beam_width,
                                              int   max_input_len,
                                              int   head_num,
                                              int   size_per_head,
                                              int   inter_size,
                                              int   vocab_size,
                                              int   tensor_para_size,
                                              void* buffer_in,
                                              bool  isAppend);

template void generate_gpt_gemm_config<half>(int   batch_size,
                                             int   beam_width,
                                             int   max_input_len,
                                             int   head_num,
                                             int   size_per_head,
                                             int   inter_size,
                                             int   vocab_size,
                                             int   tensor_para_size,
                                             void* buffer_in,
                                             bool  isAppend);

#ifdef ENABLE_BF16
template void generate_gpt_gemm_config<__nv_bfloat16>(int   batch_size,
                                                      int   beam_width,
                                                      int   max_input_len,
                                                      int   head_num,
                                                      int   size_per_head,
                                                      int   inter_size,
                                                      int   vocab_size,
                                                      int   tensor_para_size,
                                                      void* buffer_in,
                                                      bool  isAppend);
#endif

#ifdef ENABLE_FP8
template void generate_gpt_gemm_config<__nv_fp8_e4m3>(int   batch_size,
                                                      int   beam_width,
                                                      int   max_input_len,
                                                      int   head_num,
                                                      int   size_per_head,
                                                      int   inter_size,
                                                      int   vocab_size,
                                                      int   tensor_para_size,
                                                      void* buffer_in,
                                                      bool  isAppend);
#endif

/**
 * Computes the size, in bytes, of the scratch buffer needed for GPT GEMM testing
 * with the given shape parameters.
 *
 * Four GEMM families are considered: context QKV, context batched attention,
 * context FFN, and the vocabulary projection. For each family, the element counts
 * of its operands are summed. The largest of these counts is multiplied by
 * sizeof(float), because some buffers are always float, so float is used as the
 * word size for every data type. For half / bfloat16 / fp8, CUBLAS_WORKSPACE_SIZE
 * extra bytes are appended.
 *
 * @param batch_size        batch dimension
 * @param beam_width        beam dimension (multiplies the batch dimension)
 * @param max_input_len     maximum input sequence length
 * @param head_num          attention head count before the tensor-parallel split
 * @param size_per_head     hidden size of one head
 * @param inter_size        FFN intermediate size before the tensor-parallel split
 * @param vocab_size        vocabulary size; padded up to a multiple of 8
 * @param tensor_para_size  tensor-parallel degree; must be > 0 (it is a divisor)
 * @param data_type         GEMM data type; decides whether cuBLAS workspace is added
 * @return required buffer size in bytes
 */
size_t calGptGemmTestBufSizeInByte(int            batch_size,
                                   int            beam_width,
                                   int            max_input_len,
                                   int            head_num,
                                   int            size_per_head,
                                   int            inter_size,
                                   int            vocab_size,
                                   int            tensor_para_size,
                                   CublasDataType data_type)
{
    // Widen every dimension to size_t *before* multiplying. A product such as
    // batch * beam * head_num * max_input_len^2 formed in int overflows (undefined
    // behavior) for realistic shapes, e.g. batch 64, 32 heads, max_input_len 2048.
    const size_t tp                 = static_cast<size_t>(tensor_para_size);
    const size_t batch_beam         = static_cast<size_t>(batch_size) * static_cast<size_t>(beam_width);
    const size_t seq_len            = static_cast<size_t>(max_input_len);
    const size_t m                  = batch_beam * seq_len;
    const size_t hidden_units       = static_cast<size_t>(head_num) * static_cast<size_t>(size_per_head);
    const size_t local_head_num     = static_cast<size_t>(head_num / tensor_para_size);
    const size_t local_hidden_units = local_head_num * static_cast<size_t>(size_per_head);
    const size_t inter              = static_cast<size_t>(inter_size);

    // Vocabulary padded to a multiple of 8, then split across tensor-parallel ranks.
    // Exact integer arithmetic avoids double rounding/truncation. Rounding the
    // per-rank share up means the buffer is never under-sized when tp does not
    // divide the padded vocabulary evenly.
    const size_t padded_vocab = (static_cast<size_t>(vocab_size) + 7) / 8 * 8;
    const size_t local_vocab  = (padded_vocab + tp - 1) / tp;

    // Total operand element count required by each tested GEMM family.
    const size_t qkv_words = m * hidden_units + hidden_units * 3 * local_hidden_units + m * 3 * local_hidden_units;
    const size_t attn_words =
        m * local_hidden_units + m * local_hidden_units + batch_beam * static_cast<size_t>(head_num) * seq_len * seq_len;
    const size_t ffn_words   = m * inter / tp + hidden_units * inter / tp + m * hidden_units;
    const size_t vocab_words = m * hidden_units + hidden_units * local_vocab + m * local_vocab;

    // The buffer is reused across families, so it only needs to fit the largest one.
    const size_t candidates[] = {qkv_words, attn_words, ffn_words, vocab_words};
    size_t       max_words    = 0;
    for (const size_t words : candidates) {
        if (words > max_words) {
            max_words = words;
        }
    }

    size_t buf_size_in_byte = max_words * sizeof(float);
    if (data_type == HALF_DATATYPE || data_type == BFLOAT16_DATATYPE || data_type == FP8_DATATYPE) {
        buf_size_in_byte += CUBLAS_WORKSPACE_SIZE;
    }
    return buf_size_in_byte;
}

lvhan028's avatar
lvhan028 committed
807
}  // namespace turbomind