"test/vscode:/vscode.git/clone" did not exist on "812a2dc684f5540268381cde19c58598013e90ee"
profile_gemm.cpp 7.05 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

Chao Liu's avatar
Chao Liu committed
4
5
6
7
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
Chao Liu's avatar
Chao Liu committed
8
9

#include "profiler/include/profile_gemm_impl.hpp"
Chao Liu's avatar
Chao Liu committed
10

Chao Liu's avatar
Chao Liu committed
11
/// Operand storage layouts selected by the "matrix layout" argument (arg3).
/// Naming is A_B_C using index order: MK = A is [m, k], KM = A is [k, m],
/// KN = B is [k, n], NK = B is [n, k]; C is always [m, n].
/// The explicit values are the integers the user passes on the command line.
enum struct GemmMatrixLayout
{
    MK_KN_MN = 0, // A[m, k] * B[k, n] = C[m, n]
    MK_NK_MN = 1, // A[m, k] * B[n, k] = C[m, n]
    KM_KN_MN = 2, // A[k, m] * B[k, n] = C[m, n]
    KM_NK_MN = 3, // A[k, m] * B[n, k] = C[m, n]
};

Chao Liu's avatar
Chao Liu committed
19
/// Element-type combination (A_B_C) selected by the "data type" argument (arg2).
/// The explicit values are the integers the user passes on the command line.
enum struct GemmDataType
{
    F32_F32_F32    = 0, // fp32 A, B and C
    F16_F16_F16    = 1, // fp16 A, B and C
    BF16_BF16_BF16 = 2, // bf16 A, B and C
    INT8_INT8_INT8 = 3, // int8 A, B and C
};

Chao Liu's avatar
Chao Liu committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
/// Print the usage text of the "gemm" profiler sub-command to stdout,
/// describing the meaning of argv[1] .. argv[13].
static void print_helper_msg()
{
    // Adjacent string literals are joined by the compiler into a single usage text.
    constexpr const char* usage_text =
        "arg1: tensor operation (gemm: GEMM)\n"
        "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
        "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
        "                     1: A[m, k] * B[n, k] = C[m, n];\n"
        "                     2: A[k, m] * B[k, n] = C[m, n];\n"
        "                     3: A[k, m] * B[n, k] = C[m, n])\n"
        "arg4: verification (0: no; 1: yes)\n"
        "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
        "arg6: print tensor value (0: no; 1: yes)\n"
        "arg7: time kernel (0: no, 1: yes)\n"
        "arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n";

    // std::endl appends a final blank line and flushes, as before.
    std::cout << usage_text << std::endl;
}

Chao Liu's avatar
Chao Liu committed
43
44
int profile_gemm(int argc, char* argv[])
{
Chao Liu's avatar
Chao Liu committed
45
    if(argc != 14)
Chao Liu's avatar
Chao Liu committed
46
    {
Chao Liu's avatar
Chao Liu committed
47
        print_helper_msg();
Chao Liu's avatar
Chao Liu committed
48
49
50
        exit(1);
    }

Chao Liu's avatar
Chao Liu committed
51
52
    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
Chao Liu's avatar
Chao Liu committed
53
54
55
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
JD's avatar
JD committed
56
    const bool time_kernel     = std::stoi(argv[7]);
Chao Liu's avatar
Chao Liu committed
57
58
59
60
61
62
63
64
65

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);

Chao Liu's avatar
Chao Liu committed
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
    using F32   = float;
    using F16   = ck::half_t;
    using BF16  = ck::bhalf_t;
    using INT8  = int8_t;
    using INT32 = int32_t;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    auto profile = [&](auto a_type,
                       auto b_type,
                       auto acc_type,
                       auto c_type,
                       auto a_layout,
                       auto b_layout,
                       auto c_layout) {
        using ADataType   = decltype(a_type);
        using BDataType   = decltype(b_type);
        using AccDataType = decltype(acc_type);
        using CDataType   = decltype(c_type);

        using ALayout = decltype(a_layout);
        using BLayout = decltype(b_layout);
        using CLayout = decltype(c_layout);

        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

        bool pass =
            ck::profiler::profile_gemm_impl<ADataType,
                                            BDataType,
                                            AccDataType,
                                            CDataType,
                                            ALayout,
                                            BLayout,
                                            CLayout>(do_verification,
                                                     init_method,
                                                     do_log,
                                                     time_kernel,
                                                     M,
                                                     N,
                                                     K,
                                                     (StrideA < 0) ? DefaultStrideA : StrideA,
                                                     (StrideB < 0) ? DefaultStrideB : StrideB,
                                                     (StrideC < 0) ? DefaultStrideC : StrideC);

        return pass ? 0 : 1;
    };

    if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
Chao Liu's avatar
Chao Liu committed
117
    {
Chao Liu's avatar
Chao Liu committed
118
        return profile(Row{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{});
Chao Liu's avatar
Chao Liu committed
119
    }
Chao Liu's avatar
Chao Liu committed
120
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
Chao Liu's avatar
Chao Liu committed
121
    {
Chao Liu's avatar
Chao Liu committed
122
        return profile(Row{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{});
Chao Liu's avatar
Chao Liu committed
123
    }
Chao Liu's avatar
Chao Liu committed
124
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
Chao Liu's avatar
Chao Liu committed
125
    {
Chao Liu's avatar
Chao Liu committed
126
        return profile(Col{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{});
Chao Liu's avatar
Chao Liu committed
127
    }
Chao Liu's avatar
Chao Liu committed
128
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
Chao Liu's avatar
Chao Liu committed
129
    {
Chao Liu's avatar
Chao Liu committed
130
        return profile(Col{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{});
Chao Liu's avatar
Chao Liu committed
131
    }
Chao Liu's avatar
Chao Liu committed
132
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
Chao Liu's avatar
Chao Liu committed
133
    {
Chao Liu's avatar
Chao Liu committed
134
        return profile(Row{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{});
Chao Liu's avatar
Chao Liu committed
135
    }
Chao Liu's avatar
Chao Liu committed
136
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
Chao Liu's avatar
Chao Liu committed
137
    {
Chao Liu's avatar
Chao Liu committed
138
        return profile(Row{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{});
Chao Liu's avatar
Chao Liu committed
139
    }
Chao Liu's avatar
Chao Liu committed
140
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
Chao Liu's avatar
Chao Liu committed
141
    {
Chao Liu's avatar
Chao Liu committed
142
        return profile(Col{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{});
Chao Liu's avatar
Chao Liu committed
143
    }
Chao Liu's avatar
Chao Liu committed
144
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
Chao Liu's avatar
Chao Liu committed
145
    {
Chao Liu's avatar
Chao Liu committed
146
        return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{});
Chao Liu's avatar
Chao Liu committed
147
    }
Chao Liu's avatar
Chao Liu committed
148
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
149
    {
Chao Liu's avatar
Chao Liu committed
150
        return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
151
    }
Chao Liu's avatar
Chao Liu committed
152
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
153
    {
Chao Liu's avatar
Chao Liu committed
154
        return profile(Row{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
155
    }
Chao Liu's avatar
Chao Liu committed
156
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
157
    {
Chao Liu's avatar
Chao Liu committed
158
        return profile(Col{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
159
    }
Chao Liu's avatar
Chao Liu committed
160
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
161
    {
Chao Liu's avatar
Chao Liu committed
162
        return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
163
    }
Chao Liu's avatar
Chao Liu committed
164
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
165
    {
Chao Liu's avatar
Chao Liu committed
166
        return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
167
    }
Chao Liu's avatar
Chao Liu committed
168
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
169
    {
Chao Liu's avatar
Chao Liu committed
170
        return profile(Row{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
171
    }
Chao Liu's avatar
Chao Liu committed
172
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
173
    {
Chao Liu's avatar
Chao Liu committed
174
        return profile(Col{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
175
    }
Chao Liu's avatar
Chao Liu committed
176
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
177
    {
Chao Liu's avatar
Chao Liu committed
178
        return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
179
    }
Chao Liu's avatar
Chao Liu committed
180
181
    else
    {
Chao Liu's avatar
Chao Liu committed
182
        std::cout << "this data_type & layout is not implemented" << std::endl;
Chao Liu's avatar
Chao Liu committed
183

Chao Liu's avatar
Chao Liu committed
184
185
        return 1;
    }
Chao Liu's avatar
Chao Liu committed
186
}