profile_gemm.cpp 7.05 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

Chao Liu's avatar
Chao Liu committed
4
5
6
7
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
Chao Liu's avatar
Chao Liu committed
8
9

#include "profiler/include/profile_gemm_impl.hpp"
Chao Liu's avatar
Chao Liu committed
10

Chao Liu's avatar
Chao Liu committed
11
// Matrix-layout selector for the GEMM profiler, parsed from the third command-line
// argument (argv[3]). Each letter pair names the index order of one operand:
// MK = A[m, k], KM = A[k, m], KN = B[k, n], NK = B[n, k], MN = C[m, n].
// Explicit values pin the integer codes accepted on the command line.
enum struct GemmMatrixLayout
{
    MK_KN_MN = 0, // A[m, k] * B[k, n] = C[m, n]
    MK_NK_MN = 1, // A[m, k] * B[n, k] = C[m, n]
    KM_KN_MN = 2, // A[k, m] * B[k, n] = C[m, n]
    KM_NK_MN = 3, // A[k, m] * B[n, k] = C[m, n]
};

Chao Liu's avatar
Chao Liu committed
19
// Element-type selector for the GEMM profiler, parsed from the second command-line
// argument (argv[2]). Each name lists the A, B and C element types; the accumulator
// type is not encoded here. Explicit values pin the integer codes accepted on the
// command line.
enum struct GemmDataType
{
    F32_F32_F32    = 0, // fp32
    F16_F16_F16    = 1, // fp16
    BF16_BF16_BF16 = 2, // bf16
    INT8_INT8_INT8 = 3, // int8
};

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
// Writes the usage text for the "gemm" profiler operation (positional arguments 1-13)
// to standard output, followed by a blank line, and flushes the stream.
static void print_helper_msg()
{
    // Adjacent string literals are joined by the compiler into a single usage text.
    const char* const usage_text = "arg1: tensor operation (gemm: GEMM)\n"
                                   "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
                                   "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
                                   "                     1: A[m, k] * B[n, k] = C[m, n];\n"
                                   "                     2: A[k, m] * B[k, n] = C[m, n];\n"
                                   "                     3: A[k, m] * B[n, k] = C[m, n])\n"
                                   "arg4: verification (0: no; 1: yes)\n"
                                   "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
                                   "arg6: print tensor value (0: no; 1: yes)\n"
                                   "arg7: time kernel (0: no, 1: yes)\n"
                                   "arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n";

    // std::endl adds the terminating blank line and flushes.
    std::cout << usage_text << std::endl;
}

Chao Liu's avatar
Chao Liu committed
43
44
int profile_gemm(int argc, char* argv[])
{
Chao Liu's avatar
Chao Liu committed
45
    if(argc != 14)
Chao Liu's avatar
Chao Liu committed
46
    {
47
        print_helper_msg();
Chao Liu's avatar
Chao Liu committed
48
49
50
        exit(1);
    }

Chao Liu's avatar
Chao Liu committed
51
52
    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
Chao Liu's avatar
Chao Liu committed
53
54
55
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
JD's avatar
JD committed
56
    const bool time_kernel     = std::stoi(argv[7]);
Chao Liu's avatar
Chao Liu committed
57
58
59
60
61
62
63
64
65

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);

Chao Liu's avatar
Chao Liu committed
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
    using F32   = float;
    using F16   = ck::half_t;
    using BF16  = ck::bhalf_t;
    using INT8  = int8_t;
    using INT32 = int32_t;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    auto profile = [&](auto a_type,
                       auto b_type,
                       auto acc_type,
                       auto c_type,
                       auto a_layout,
                       auto b_layout,
                       auto c_layout) {
        using ADataType   = decltype(a_type);
        using BDataType   = decltype(b_type);
        using AccDataType = decltype(acc_type);
        using CDataType   = decltype(c_type);

        using ALayout = decltype(a_layout);
        using BLayout = decltype(b_layout);
        using CLayout = decltype(c_layout);

        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

        bool pass =
            ck::profiler::profile_gemm_impl<ADataType,
                                            BDataType,
                                            AccDataType,
                                            CDataType,
                                            ALayout,
                                            BLayout,
                                            CLayout>(do_verification,
                                                     init_method,
                                                     do_log,
                                                     time_kernel,
                                                     M,
                                                     N,
                                                     K,
                                                     (StrideA < 0) ? DefaultStrideA : StrideA,
                                                     (StrideB < 0) ? DefaultStrideB : StrideB,
                                                     (StrideC < 0) ? DefaultStrideC : StrideC);

        return pass ? 0 : 1;
    };

    if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
Chao Liu's avatar
Chao Liu committed
117
    {
118
        return profile(Row{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{});
Chao Liu's avatar
Chao Liu committed
119
    }
Chao Liu's avatar
Chao Liu committed
120
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
Chao Liu's avatar
Chao Liu committed
121
    {
122
        return profile(Row{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{});
Chao Liu's avatar
Chao Liu committed
123
    }
Chao Liu's avatar
Chao Liu committed
124
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
Chao Liu's avatar
Chao Liu committed
125
    {
126
        return profile(Col{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{});
Chao Liu's avatar
Chao Liu committed
127
    }
Chao Liu's avatar
Chao Liu committed
128
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
Chao Liu's avatar
Chao Liu committed
129
    {
130
        return profile(Col{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{});
Chao Liu's avatar
Chao Liu committed
131
    }
Chao Liu's avatar
Chao Liu committed
132
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
Chao Liu's avatar
Chao Liu committed
133
    {
134
        return profile(Row{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{});
Chao Liu's avatar
Chao Liu committed
135
    }
Chao Liu's avatar
Chao Liu committed
136
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
Chao Liu's avatar
Chao Liu committed
137
    {
138
        return profile(Row{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{});
Chao Liu's avatar
Chao Liu committed
139
    }
Chao Liu's avatar
Chao Liu committed
140
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
Chao Liu's avatar
Chao Liu committed
141
    {
142
        return profile(Col{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{});
Chao Liu's avatar
Chao Liu committed
143
    }
Chao Liu's avatar
Chao Liu committed
144
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
Chao Liu's avatar
Chao Liu committed
145
    {
146
        return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{});
Chao Liu's avatar
Chao Liu committed
147
    }
Chao Liu's avatar
Chao Liu committed
148
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
149
    {
150
        return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
151
    }
Chao Liu's avatar
Chao Liu committed
152
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
153
    {
154
        return profile(Row{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
155
    }
Chao Liu's avatar
Chao Liu committed
156
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
157
    {
158
        return profile(Col{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
159
    }
Chao Liu's avatar
Chao Liu committed
160
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
161
    {
162
        return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
163
    }
Chao Liu's avatar
Chao Liu committed
164
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
165
    {
166
        return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
167
    }
Chao Liu's avatar
Chao Liu committed
168
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
169
    {
170
        return profile(Row{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
171
    }
Chao Liu's avatar
Chao Liu committed
172
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
173
    {
174
        return profile(Col{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
175
    }
Chao Liu's avatar
Chao Liu committed
176
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
177
    {
178
        return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
179
    }
Chao Liu's avatar
Chao Liu committed
180
181
    else
    {
Chao Liu's avatar
Chao Liu committed
182
        std::cout << "this data_type & layout is not implemented" << std::endl;
Chao Liu's avatar
Chao Liu committed
183

Chao Liu's avatar
Chao Liu committed
184
185
        return 1;
    }
Chao Liu's avatar
Chao Liu committed
186
}