// profile_gemm.cpp
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_gemm_impl.hpp"
// return true if test pass
bool profile_gemm(int argc, char* argv[])
Chao Liu's avatar
Chao Liu committed
12
{
Chao Liu's avatar
Chao Liu committed
13
14
15
16
17
18
19
    enum struct GemmMatrixLayout
    {
        MK_KN_MN, // 0
        MK_NK_MN, // 1
        KM_KN_MN, // 2
        KM_NK_MN, // 3
    };
Chao Liu's avatar
Chao Liu committed
20

Chao Liu's avatar
Chao Liu committed
21
22
23
24
25
26
27
    enum struct GemmDataType
    {
        F32_F32_F32,    // 0
        F16_F16_F16,    // 1
        BF16_BF16_BF16, // 2
        INT8_INT8_INT8, // 3
    };
Chao Liu's avatar
Chao Liu committed
28

Chao Liu's avatar
Chao Liu committed
29
    if(argc != 14)
Chao Liu's avatar
Chao Liu committed
30
31
    {
        printf("arg1: tensor operation (gemm: GEMM)\n");
32
        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
Chao Liu's avatar
Chao Liu committed
33
34
        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
ltqin's avatar
ltqin committed
35
36
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
Chao Liu's avatar
Chao Liu committed
37
38
39
40
41
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
        printf("arg8: print tensor value (0: no; 1: yes)\n");
        printf("arg7: run kernel # of times (>1)\n");
        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
Chao Liu's avatar
Chao Liu committed
42
        return false;
Chao Liu's avatar
Chao Liu committed
43
44
    }

Chao Liu's avatar
Chao Liu committed
45
46
    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
Chao Liu's avatar
Chao Liu committed
47
48
49
50
51
52
53
54
55
56
57
58
59
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
    const int nrepeat          = std::stoi(argv[7]);

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);

Chao Liu's avatar
clean  
Chao Liu committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
    auto profile =
        [&](auto a_type, auto b_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) {
            using ADataType = decltype(a_type);
            using BDataType = decltype(b_type);
            using CDataType = decltype(c_type);
            using ALayout   = decltype(a_layout);
            using BLayout   = decltype(b_layout);
            using CLayout   = decltype(c_layout);

            return ck::profiler::
                profile_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
                    do_verification,
                    init_method,
                    do_log,
                    nrepeat,
                    M,
                    N,
                    K,
                    (StrideA < 0) ? K : StrideA,
                    (StrideB < 0) ? N : StrideB,
                    (StrideC < 0) ? N : StrideC);
        };

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

Chao Liu's avatar
Chao Liu committed
86
87
    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
88
        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
89
90
91
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
92
        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
93
94
95
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
96
        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
97
98
99
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
100
        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
101
102
103
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
104
        return profile(float{}, float{}, float{}, Row{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
105
106
107
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
108
        return profile(float{}, float{}, float{}, Row{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
109
110
111
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
112
        return profile(float{}, float{}, float{}, Col{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
113
114
115
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
116
        return profile(float{}, float{}, float{}, Col{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
117
    }
118
119
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
120
        return profile(int8_t{}, int8_t{}, int8_t{}, Row{}, Row{}, Row{});
121
    }
122
123
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
124
        return profile(int8_t{}, int8_t{}, int8_t{}, Row{}, Col{}, Row{});
125
    }
126
127
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
128
        return profile(int8_t{}, int8_t{}, int8_t{}, Col{}, Row{}, Row{});
129
130
131
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
132
        return profile(int8_t{}, int8_t{}, int8_t{}, Col{}, Col{}, Row{});
133
134
135
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
136
        return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Row{}, Row{}, Row{});
137
    }
138
139
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
140
        return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Row{}, Col{}, Row{});
141
    }
142
143
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
144
        return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Col{}, Row{}, Row{});
145
146
147
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
    {
Chao Liu's avatar
clean  
Chao Liu committed
148
        return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Col{}, Col{}, Row{});
149
    }
Chao Liu's avatar
Chao Liu committed
150
151
    else
    {
Chao Liu's avatar
Chao Liu committed
152
        std::cout << "this data_type & layout is not implemented" << std::endl;
Chao Liu's avatar
Chao Liu committed
153

Chao Liu's avatar
Chao Liu committed
154
155
        return true;
    }
Chao Liu's avatar
Chao Liu committed
156
}