// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdio>
#include <cstdlib>
#include <exception>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <string>

#include "profiler/include/profile_gemm_impl.hpp"

/// Index order of the A, B and C matrices, selected by command-line arg3.
///
/// Each enumerator name lists the index order of A, B and C in turn; the
/// numeric value is the integer accepted on the command line and cast to this
/// enum in profile_gemm, so the values must stay in sync with the usage text.
enum struct GemmMatrixLayout
{
    MK_KN_MN = 0, // A[m, k] * B[k, n] = C[m, n]
    MK_NK_MN = 1, // A[m, k] * B[n, k] = C[m, n]
    KM_KN_MN = 2, // A[k, m] * B[k, n] = C[m, n]
    KM_NK_MN = 3, // A[k, m] * B[n, k] = C[m, n]
};

/// Element-type combination of the A, B and C matrices, selected by
/// command-line arg2.
///
/// The numeric value is the integer accepted on the command line and cast to
/// this enum in profile_gemm, so the values must stay in sync with the usage
/// text.
enum struct GemmDataType
{
    F32_F32_F32    = 0, // fp32
    F16_F16_F16    = 1, // fp16
    BF16_BF16_BF16 = 2, // bf16
    INT8_INT8_INT8 = 3, // int8
};

int profile_gemm(int argc, char* argv[])
{
Chao Liu's avatar
Chao Liu committed
29
    if(argc != 14)
Chao Liu's avatar
Chao Liu committed
30
31
    {
        printf("arg1: tensor operation (gemm: GEMM)\n");
32
        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
Chao Liu's avatar
Chao Liu committed
33
34
        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
ltqin's avatar
ltqin committed
35
36
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
Chao Liu's avatar
Chao Liu committed
37
38
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
JD's avatar
JD committed
39
        printf("arg6: print tensor value (0: no; 1: yes)\n");
Chao Liu's avatar
Chao Liu committed
40
        printf("arg7: time kernel (0=no, 1=yes)\n");
Chao Liu's avatar
Chao Liu committed
41
42
43
44
        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
        exit(1);
    }

Chao Liu's avatar
Chao Liu committed
45
46
    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
Chao Liu's avatar
Chao Liu committed
47
48
49
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
JD's avatar
JD committed
50
    const bool time_kernel     = std::stoi(argv[7]);
Chao Liu's avatar
Chao Liu committed
51
52
53
54
55
56
57
58
59

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);

Chao Liu's avatar
Chao Liu committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
    using F32   = float;
    using F16   = ck::half_t;
    using BF16  = ck::bhalf_t;
    using INT8  = int8_t;
    using INT32 = int32_t;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    auto profile = [&](auto a_type,
                       auto b_type,
                       auto acc_type,
                       auto c_type,
                       auto a_layout,
                       auto b_layout,
                       auto c_layout) {
        using ADataType   = decltype(a_type);
        using BDataType   = decltype(b_type);
        using AccDataType = decltype(acc_type);
        using CDataType   = decltype(c_type);

        using ALayout = decltype(a_layout);
        using BLayout = decltype(b_layout);
        using CLayout = decltype(c_layout);

        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

        bool pass =
            ck::profiler::profile_gemm_impl<ADataType,
                                            BDataType,
                                            AccDataType,
                                            CDataType,
                                            ALayout,
                                            BLayout,
                                            CLayout>(do_verification,
                                                     init_method,
                                                     do_log,
                                                     time_kernel,
                                                     M,
                                                     N,
                                                     K,
                                                     (StrideA < 0) ? DefaultStrideA : StrideA,
                                                     (StrideB < 0) ? DefaultStrideB : StrideB,
                                                     (StrideC < 0) ? DefaultStrideC : StrideC);

        return pass ? 0 : 1;
    };

    if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
Chao Liu's avatar
Chao Liu committed
111
    {
Chao Liu's avatar
Chao Liu committed
112
        return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
113
    }
Chao Liu's avatar
Chao Liu committed
114
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
Chao Liu's avatar
Chao Liu committed
115
    {
Chao Liu's avatar
Chao Liu committed
116
        return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
117
    }
Chao Liu's avatar
Chao Liu committed
118
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
Chao Liu's avatar
Chao Liu committed
119
    {
Chao Liu's avatar
Chao Liu committed
120
        return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
121
    }
Chao Liu's avatar
Chao Liu committed
122
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
Chao Liu's avatar
Chao Liu committed
123
    {
Chao Liu's avatar
Chao Liu committed
124
        return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
125
    }
Chao Liu's avatar
Chao Liu committed
126
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
Chao Liu's avatar
Chao Liu committed
127
    {
Chao Liu's avatar
Chao Liu committed
128
        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
129
    }
Chao Liu's avatar
Chao Liu committed
130
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
Chao Liu's avatar
Chao Liu committed
131
    {
Chao Liu's avatar
Chao Liu committed
132
        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
133
    }
Chao Liu's avatar
Chao Liu committed
134
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
Chao Liu's avatar
Chao Liu committed
135
    {
Chao Liu's avatar
Chao Liu committed
136
        return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
Chao Liu's avatar
Chao Liu committed
137
    }
Chao Liu's avatar
Chao Liu committed
138
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
Chao Liu's avatar
Chao Liu committed
139
    {
Chao Liu's avatar
Chao Liu committed
140
        return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
Chao Liu's avatar
Chao Liu committed
141
    }
Chao Liu's avatar
Chao Liu committed
142
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
143
    {
Chao Liu's avatar
Chao Liu committed
144
        return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Row{}, Row{});
145
    }
Chao Liu's avatar
Chao Liu committed
146
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
147
    {
Chao Liu's avatar
Chao Liu committed
148
        return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
149
    }
Chao Liu's avatar
Chao Liu committed
150
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
151
    {
Chao Liu's avatar
Chao Liu committed
152
        return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
153
    }
Chao Liu's avatar
Chao Liu committed
154
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
155
    {
Chao Liu's avatar
Chao Liu committed
156
        return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Col{}, Row{});
157
    }
Chao Liu's avatar
Chao Liu committed
158
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
159
    {
Chao Liu's avatar
Chao Liu committed
160
        return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Row{}, Row{});
161
    }
Chao Liu's avatar
Chao Liu committed
162
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
163
    {
Chao Liu's avatar
Chao Liu committed
164
        return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Col{}, Row{});
165
    }
Chao Liu's avatar
Chao Liu committed
166
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
167
    {
Chao Liu's avatar
Chao Liu committed
168
        return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Row{}, Row{});
169
    }
Chao Liu's avatar
Chao Liu committed
170
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
171
    {
Chao Liu's avatar
Chao Liu committed
172
        return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Col{}, Row{});
173
    }
Chao Liu's avatar
Chao Liu committed
174
175
    else
    {
Chao Liu's avatar
Chao Liu committed
176
        std::cout << "this data_type & layout is not implemented" << std::endl;
Chao Liu's avatar
Chao Liu committed
177

Chao Liu's avatar
Chao Liu committed
178
179
        return 1;
    }
Chao Liu's avatar
Chao Liu committed
180
}