profile_batched_gemm.cpp 7.42 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <string>

#include "profiler/profile_batched_gemm_impl.hpp"
#include "profiler_operation_registry.hpp"
zjing14's avatar
zjing14 committed
12

Chao Liu's avatar
Chao Liu committed
13
/// Matrix layout combination for the batched GEMM, selected by command-line arg3.
///
/// The enumerator values are the integers accepted on the command line (arg3 is cast
/// directly to this type), so they are spelled out explicitly and must not be reordered.
enum struct GemmMatrixLayout
{
    MK_KN_MN = 0, // A[g, m, k] * B[g, k, n] = C[g, m, n]
    MK_NK_MN = 1, // A[g, m, k] * B[g, n, k] = C[g, m, n]
    KM_KN_MN = 2, // A[g, k, m] * B[g, k, n] = C[g, m, n]
    KM_NK_MN = 3, // A[g, k, m] * B[g, n, k] = C[g, m, n]
};

Chao Liu's avatar
Chao Liu committed
21
/// Element type combination (A, B, C) for the batched GEMM, selected by command-line arg2.
///
/// The enumerator values are the integers accepted on the command line (arg2 is cast
/// directly to this type), so they are spelled out explicitly and must not be reordered.
enum struct GemmDataType
{
    F32_F32_F32    = 0, // fp32
    F16_F16_F16    = 1, // fp16
    BF16_BF16_BF16 = 2, // bf16
    INT8_INT8_INT8 = 3, // int8
};

int profile_batched_gemm(int argc, char* argv[])
{
Chao Liu's avatar
Chao Liu committed
31
    if(argc != 18)
zjing14's avatar
zjing14 committed
32
    {
Chao Liu's avatar
Chao Liu committed
33
        // clang-format off
zjing14's avatar
zjing14 committed
34
        printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n");
Jianfeng Yan's avatar
Jianfeng Yan committed
35
        printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n");
zjing14's avatar
zjing14 committed
36
37
38
39
40
41
        printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n");
        printf("                     1: A[g, m, k] * B[g, n, k] = C[g, m, n];\n");
        printf("                     2: A[g, k, m] * B[g, k, n] = C[g, m, n];\n");
        printf("                     3: A[g, k, m] * B[g, n, k] = C[g, m, n])\n");
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
JD's avatar
JD committed
42
43
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: time kernel (0=n0, 1=yes)\n");
Chao Liu's avatar
Chao Liu committed
44
45
        printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n");
        // clang-format on
zjing14's avatar
zjing14 committed
46
47
48
        exit(1);
    }

Chao Liu's avatar
Chao Liu committed
49
50
    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
zjing14's avatar
zjing14 committed
51
52
53
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
JD's avatar
JD committed
54
    const bool time_kernel     = std::stoi(argv[7]);
zjing14's avatar
zjing14 committed
55
56
57
58
59
60
61
62
63

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);

Chao Liu's avatar
Chao Liu committed
64
65
66
67
68
    const int BatchStrideA = std::stoi(argv[14]);
    const int BatchStrideB = std::stoi(argv[15]);
    const int BatchStrideC = std::stoi(argv[16]);

    const int BatchCount = std::stoi(argv[17]);
zjing14's avatar
zjing14 committed
69

Chao Liu's avatar
Chao Liu committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
    using F32  = float;
    using F16  = ck::half_t;
    using BF16 = ck::bhalf_t;
    using INT8 = int8_t;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    auto profile = [&](auto a_type,
                       auto b_type,
                       auto c_type,
                       auto a_layout,
                       auto b_layout,
                       auto c_layout) {
        using ADataType = decltype(a_type);
        using BDataType = decltype(b_type);
        using CDataType = decltype(c_type);

        using ALayout = decltype(a_layout);
        using BLayout = decltype(b_layout);
        using CLayout = decltype(c_layout);

        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

96
97
98
99
        const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA;
        const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB;
        const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC;

Chao Liu's avatar
Chao Liu committed
100
101
102
103
104
105
106
        const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
        const int DefaultBatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
        const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;

        const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA;
        const int BatchStrideB_ = (BatchStrideB < 0) ? DefaultBatchStrideB : BatchStrideB;
        const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC;
107

Chao Liu's avatar
Chao Liu committed
108
109
110
111
112
113
114
115
116
        bool pass = ck::profiler::
            profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
                do_verification,
                init_method,
                do_log,
                time_kernel,
                M,
                N,
                K,
Chao Liu's avatar
Chao Liu committed
117
118
119
                BatchStrideA_,
                BatchStrideB_,
                BatchStrideC_,
120
121
122
                StrideA_,
                StrideB_,
                StrideC_,
Chao Liu's avatar
Chao Liu committed
123
124
125
126
127
128
                BatchCount);

        return pass ? 0 : 1;
    };

    if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
zjing14's avatar
zjing14 committed
129
    {
Chao Liu's avatar
Chao Liu committed
130
        return profile(F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
zjing14's avatar
zjing14 committed
131
    }
Chao Liu's avatar
Chao Liu committed
132
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
zjing14's avatar
zjing14 committed
133
    {
Chao Liu's avatar
Chao Liu committed
134
        return profile(F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
zjing14's avatar
zjing14 committed
135
    }
Chao Liu's avatar
Chao Liu committed
136
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
zjing14's avatar
zjing14 committed
137
    {
Chao Liu's avatar
Chao Liu committed
138
        return profile(F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
zjing14's avatar
zjing14 committed
139
    }
Chao Liu's avatar
Chao Liu committed
140
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
zjing14's avatar
zjing14 committed
141
    {
Chao Liu's avatar
Chao Liu committed
142
        return profile(F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
zjing14's avatar
zjing14 committed
143
    }
Chao Liu's avatar
Chao Liu committed
144
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
145
    {
Chao Liu's avatar
Chao Liu committed
146
        return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
147
    }
Chao Liu's avatar
Chao Liu committed
148
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
149
    {
Chao Liu's avatar
Chao Liu committed
150
        return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
151
    }
Chao Liu's avatar
Chao Liu committed
152
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
153
    {
Chao Liu's avatar
Chao Liu committed
154
        return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
155
    }
Chao Liu's avatar
Chao Liu committed
156
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
157
    {
Chao Liu's avatar
Chao Liu committed
158
        return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
159
    }
Chao Liu's avatar
Chao Liu committed
160
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
161
    {
Chao Liu's avatar
Chao Liu committed
162
        return profile(BF16{}, BF16{}, BF16{}, Row{}, Row{}, Row{});
163
    }
Chao Liu's avatar
Chao Liu committed
164
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
165
    {
Chao Liu's avatar
Chao Liu committed
166
        return profile(BF16{}, BF16{}, BF16{}, Row{}, Col{}, Row{});
167
    }
Chao Liu's avatar
Chao Liu committed
168
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
169
    {
Chao Liu's avatar
Chao Liu committed
170
        return profile(BF16{}, BF16{}, BF16{}, Col{}, Row{}, Row{});
171
    }
Chao Liu's avatar
Chao Liu committed
172
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
173
    {
Chao Liu's avatar
Chao Liu committed
174
        return profile(BF16{}, BF16{}, BF16{}, Col{}, Col{}, Row{});
175
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
176
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
177
    {
Chao Liu's avatar
Chao Liu committed
178
        return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{});
179
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
180
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
181
    {
Chao Liu's avatar
Chao Liu committed
182
        return profile(INT8{}, INT8{}, INT8{}, Row{}, Col{}, Row{});
183
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
184
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
185
    {
Chao Liu's avatar
Chao Liu committed
186
        return profile(INT8{}, INT8{}, INT8{}, Col{}, Row{}, Row{});
187
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
188
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
189
    {
Chao Liu's avatar
Chao Liu committed
190
        return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{});
191
    }
zjing14's avatar
zjing14 committed
192
193
    else
    {
Chao Liu's avatar
Chao Liu committed
194
        std::cout << "this data_type & layout is not implemented" << std::endl;
zjing14's avatar
zjing14 committed
195

Chao Liu's avatar
Chao Liu committed
196
197
        return 1;
    }
zjing14's avatar
zjing14 committed
198
}
199

200
// Hands the operation name "batched_gemm", the description "Batched GEMM" and the entry
// point profile_batched_gemm to REGISTER_PROFILER_OPERATION.
// NOTE(review): the macro is presumably provided by profiler_operation_registry.hpp and
// makes the operation selectable via arg1 — confirm its exact semantics there.
REGISTER_PROFILER_OPERATION("batched_gemm", "Batched GEMM", profile_batched_gemm);