"configs/vscode:/vscode.git/clone" did not exist on "32ab994d76353f7a34ae772984a5f9ee97da6b7e"
profile_batched_gemm.cpp 6.77 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <limits>
#include <numeric>
#include <string>

#include "profiler/include/profile_batched_gemm_impl.hpp"
zjing14's avatar
zjing14 committed
11

Chao Liu's avatar
Chao Liu committed
12
// Memory layouts of A, B and C for the batched GEMM C[g] = A[g] * B[g].
// Selected on the command line by integer (argv[3]); the explicit values pin
// that encoding so reordering the enumerators cannot silently change the CLI.
// Naming is <A dims>_<B dims>_<C dims> in storage order: MK / KN are dispatched
// as row-major A / B, KM / NK as column-major A / B; C is always row-major [m, n].
enum struct GemmMatrixLayout
{
    MK_KN_MN = 0, // A[g, m, k] * B[g, k, n] = C[g, m, n]
    MK_NK_MN = 1, // A[g, m, k] * B[g, n, k] = C[g, m, n]
    KM_KN_MN = 2, // A[g, k, m] * B[g, k, n] = C[g, m, n]
    KM_NK_MN = 3, // A[g, k, m] * B[g, n, k] = C[g, m, n]
};

Chao Liu's avatar
Chao Liu committed
20
// Element types of A, B and C (all three share one type in every variant).
// Selected on the command line by integer (argv[2]); the explicit values pin
// that encoding so reordering the enumerators cannot silently change the CLI.
enum struct GemmDataType
{
    F32_F32_F32    = 0, // fp32
    F16_F16_F16    = 1, // fp16
    BF16_BF16_BF16 = 2, // bf16
    INT8_INT8_INT8 = 3, // int8
};

int profile_batched_gemm(int argc, char* argv[])
{
Chao Liu's avatar
Chao Liu committed
30
    if(argc != 15)
zjing14's avatar
zjing14 committed
31
32
    {
        printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n");
Jianfeng Yan's avatar
Jianfeng Yan committed
33
        printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n");
zjing14's avatar
zjing14 committed
34
35
36
37
38
39
        printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n");
        printf("                     1: A[g, m, k] * B[g, n, k] = C[g, m, n];\n");
        printf("                     2: A[g, k, m] * B[g, k, n] = C[g, m, n];\n");
        printf("                     3: A[g, k, m] * B[g, n, k] = C[g, m, n])\n");
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
JD's avatar
JD committed
40
41
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: time kernel (0=n0, 1=yes)\n");
zjing14's avatar
zjing14 committed
42
43
44
45
        printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n");
        exit(1);
    }

Chao Liu's avatar
Chao Liu committed
46
47
    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
zjing14's avatar
zjing14 committed
48
49
50
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
JD's avatar
JD committed
51
    const bool time_kernel     = std::stoi(argv[7]);
zjing14's avatar
zjing14 committed
52
53
54
55
56
57
58
59
60
61
62

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);

    const int BatchCount = std::stoi(argv[14]);

Chao Liu's avatar
Chao Liu committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
    using F32  = float;
    using F16  = ck::half_t;
    using BF16 = ck::bhalf_t;
    using INT8 = int8_t;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    auto profile = [&](auto a_type,
                       auto b_type,
                       auto c_type,
                       auto a_layout,
                       auto b_layout,
                       auto c_layout) {
        using ADataType = decltype(a_type);
        using BDataType = decltype(b_type);
        using CDataType = decltype(c_type);

        using ALayout = decltype(a_layout);
        using BLayout = decltype(b_layout);
        using CLayout = decltype(c_layout);

        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

89
90
91
92
93
94
95
96
        const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA;
        const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB;
        const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC;

        const int BatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
        const int BatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
        const int BatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;

Chao Liu's avatar
Chao Liu committed
97
98
99
100
101
102
103
104
105
        bool pass = ck::profiler::
            profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
                do_verification,
                init_method,
                do_log,
                time_kernel,
                M,
                N,
                K,
106
107
108
109
110
111
                BatchStrideA,
                BatchStrideB,
                BatchStrideC,
                StrideA_,
                StrideB_,
                StrideC_,
Chao Liu's avatar
Chao Liu committed
112
113
114
115
116
117
                BatchCount);

        return pass ? 0 : 1;
    };

    if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
zjing14's avatar
zjing14 committed
118
    {
Chao Liu's avatar
Chao Liu committed
119
        return profile(F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
zjing14's avatar
zjing14 committed
120
    }
Chao Liu's avatar
Chao Liu committed
121
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
zjing14's avatar
zjing14 committed
122
    {
Chao Liu's avatar
Chao Liu committed
123
        return profile(F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
zjing14's avatar
zjing14 committed
124
    }
Chao Liu's avatar
Chao Liu committed
125
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
zjing14's avatar
zjing14 committed
126
    {
Chao Liu's avatar
Chao Liu committed
127
        return profile(F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
zjing14's avatar
zjing14 committed
128
    }
Chao Liu's avatar
Chao Liu committed
129
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
zjing14's avatar
zjing14 committed
130
    {
Chao Liu's avatar
Chao Liu committed
131
        return profile(F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
zjing14's avatar
zjing14 committed
132
    }
Chao Liu's avatar
Chao Liu committed
133
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
134
    {
Chao Liu's avatar
Chao Liu committed
135
        return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
136
    }
Chao Liu's avatar
Chao Liu committed
137
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
138
    {
Chao Liu's avatar
Chao Liu committed
139
        return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
140
    }
Chao Liu's avatar
Chao Liu committed
141
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
142
    {
Chao Liu's avatar
Chao Liu committed
143
        return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
144
    }
Chao Liu's avatar
Chao Liu committed
145
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
146
    {
Chao Liu's avatar
Chao Liu committed
147
        return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
148
    }
Chao Liu's avatar
Chao Liu committed
149
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
150
    {
Chao Liu's avatar
Chao Liu committed
151
        return profile(BF16{}, BF16{}, BF16{}, Row{}, Row{}, Row{});
152
    }
Chao Liu's avatar
Chao Liu committed
153
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
154
    {
Chao Liu's avatar
Chao Liu committed
155
        return profile(BF16{}, BF16{}, BF16{}, Row{}, Col{}, Row{});
156
    }
Chao Liu's avatar
Chao Liu committed
157
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
158
    {
Chao Liu's avatar
Chao Liu committed
159
        return profile(BF16{}, BF16{}, BF16{}, Col{}, Row{}, Row{});
160
    }
Chao Liu's avatar
Chao Liu committed
161
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
162
    {
Chao Liu's avatar
Chao Liu committed
163
        return profile(BF16{}, BF16{}, BF16{}, Col{}, Col{}, Row{});
164
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
165
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
166
    {
Chao Liu's avatar
Chao Liu committed
167
        return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{});
168
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
169
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
170
    {
Chao Liu's avatar
Chao Liu committed
171
        return profile(INT8{}, INT8{}, INT8{}, Row{}, Col{}, Row{});
172
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
173
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
174
    {
Chao Liu's avatar
Chao Liu committed
175
        return profile(INT8{}, INT8{}, INT8{}, Col{}, Row{}, Row{});
176
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
177
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
178
    {
Chao Liu's avatar
Chao Liu committed
179
        return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{});
180
    }
zjing14's avatar
zjing14 committed
181
182
    else
    {
Chao Liu's avatar
Chao Liu committed
183
        std::cout << "this data_type & layout is not implemented" << std::endl;
zjing14's avatar
zjing14 committed
184

Chao Liu's avatar
Chao Liu committed
185
186
        return 1;
    }
zjing14's avatar
zjing14 committed
187
}