profile_gemm_splitk.cpp 4.57 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <cstdio>
#include <string>
#include <type_traits>
#include <half.hpp>
#include "profile_gemm_splitk_impl.hpp"

// return true if test pass
bool profile_gemm_splitk(int argc, char* argv[])
{
    enum struct GemmMatrixLayout
    {
        MK_KN_MN, // 0
        MK_NK_MN, // 1
        KM_KN_MN, // 2
        KM_NK_MN, // 3
    };

    enum struct GemmDataType
    {
        F32_F32_F32, // 0
        F16_F16_F16, // 1
    };

    if(argc != 15)
    {
Chao Liu's avatar
clean  
Chao Liu committed
28
        printf("arg1: tensor operation (gemm: GEMMSplitK)\n");
Chao Liu's avatar
Chao Liu committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
        printf("arg8: print tensor value (0: no; 1: yes)\n");
        printf("arg7: run kernel # of times (>1)\n");
        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
        printf("arg14: split k into mulitiple batch\n");
        return false;
    }

    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
    const int nrepeat          = std::stoi(argv[7]);

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);
    const int KBatch  = std::stoi(argv[14]);

Chao Liu's avatar
clean  
Chao Liu committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
    auto profile = [&](auto a_type,
                       auto b_type,
                       auto c_type,
                       auto a_layout,
                       auto b_layout,
                       auto c_layout) {
        using ADataType = decltype(a_type);
        using BDataType = decltype(b_type);
        using CDataType = decltype(c_type);
        using ALayout   = decltype(a_layout);
        using BLayout   = decltype(b_layout);
        using CLayout   = decltype(c_layout);

        return ck::profiler::
            profile_gemm_splitk_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
                do_verification,
                init_method,
                do_log,
                nrepeat,
                M,
                N,
                K,
                (StrideA < 0) ? K : StrideA,
                (StrideB < 0) ? N : StrideB,
                (StrideC < 0) ? N : StrideC,
                KBatch);
    };

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

Chao Liu's avatar
Chao Liu committed
90
91
    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
92
        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
93
94
95
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
96
        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
97
98
99
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
100
        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
101
102
103
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
104
        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
105
106
107
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
108
        return profile(float{}, float{}, float{}, Row{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
109
110
111
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
112
        return profile(float{}, float{}, float{}, Row{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
113
114
115
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
116
        return profile(float{}, float{}, float{}, Col{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
117
118
119
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
120
        return profile(float{}, float{}, float{}, Col{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
121
122
123
124
125
126
127
128
    }
    else
    {
        std::cout << "this data_type & layout is not implemented" << std::endl;

        return true;
    }
}