profile_batched_gemm.cpp 6.32 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
#include <cstdint>
zjing14's avatar
zjing14 committed
5
6
7
8
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
Chao Liu's avatar
Chao Liu committed
9
10

#include "profiler/include/profile_batched_gemm_impl.hpp"
zjing14's avatar
zjing14 committed
11

Chao Liu's avatar
Chao Liu committed
12
// Memory layout of the A, B and C operands of a batched GEMM, selected on the
// command line by its numeric value (arg3):
//   MK_KN_MN = 0 : A[g, m, k] * B[g, k, n] = C[g, m, n]
//   MK_NK_MN = 1 : A[g, m, k] * B[g, n, k] = C[g, m, n]
//   KM_KN_MN = 2 : A[g, k, m] * B[g, k, n] = C[g, m, n]
//   KM_NK_MN = 3 : A[g, k, m] * B[g, n, k] = C[g, m, n]
enum class GemmMatrixLayout
{
    MK_KN_MN = 0,
    MK_NK_MN = 1,
    KM_KN_MN = 2,
    KM_NK_MN = 3,
};

Chao Liu's avatar
Chao Liu committed
20
// Element types of the A, B and C operands, selected on the command line by
// its numeric value (arg2): 0 = fp32, 1 = fp16, 2 = bf16, 3 = int8.
// In every variant A, B and C share the same element type.
enum class GemmDataType
{
    F32_F32_F32    = 0,
    F16_F16_F16    = 1,
    BF16_BF16_BF16 = 2,
    INT8_INT8_INT8 = 3,
};

int profile_batched_gemm(int argc, char* argv[])
{
Chao Liu's avatar
Chao Liu committed
30
    if(argc != 15)
zjing14's avatar
zjing14 committed
31
32
    {
        printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n");
Jianfeng Yan's avatar
Jianfeng Yan committed
33
        printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n");
zjing14's avatar
zjing14 committed
34
35
36
37
38
39
        printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n");
        printf("                     1: A[g, m, k] * B[g, n, k] = C[g, m, n];\n");
        printf("                     2: A[g, k, m] * B[g, k, n] = C[g, m, n];\n");
        printf("                     3: A[g, k, m] * B[g, n, k] = C[g, m, n])\n");
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
JD's avatar
JD committed
40
41
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: time kernel (0=n0, 1=yes)\n");
zjing14's avatar
zjing14 committed
42
43
44
45
        printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n");
        exit(1);
    }

Chao Liu's avatar
Chao Liu committed
46
47
    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
zjing14's avatar
zjing14 committed
48
49
50
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
JD's avatar
JD committed
51
    const bool time_kernel     = std::stoi(argv[7]);
zjing14's avatar
zjing14 committed
52
53
54
55
56
57
58
59
60
61
62

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);

    const int BatchCount = std::stoi(argv[14]);

Chao Liu's avatar
Chao Liu committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
    using F32  = float;
    using F16  = ck::half_t;
    using BF16 = ck::bhalf_t;
    using INT8 = int8_t;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    auto profile = [&](auto a_type,
                       auto b_type,
                       auto c_type,
                       auto a_layout,
                       auto b_layout,
                       auto c_layout) {
        using ADataType = decltype(a_type);
        using BDataType = decltype(b_type);
        using CDataType = decltype(c_type);

        using ALayout = decltype(a_layout);
        using BLayout = decltype(b_layout);
        using CLayout = decltype(c_layout);

        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

        bool pass = ck::profiler::
            profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
                do_verification,
                init_method,
                do_log,
                time_kernel,
                M,
                N,
                K,
                (StrideA < 0) ? DefaultStrideA : StrideA,
                (StrideB < 0) ? DefaultStrideB : StrideB,
                (StrideC < 0) ? DefaultStrideC : StrideC,
                BatchCount);

        return pass ? 0 : 1;
    };

    if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
zjing14's avatar
zjing14 committed
107
    {
Chao Liu's avatar
Chao Liu committed
108
        return profile(F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
zjing14's avatar
zjing14 committed
109
    }
Chao Liu's avatar
Chao Liu committed
110
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
zjing14's avatar
zjing14 committed
111
    {
Chao Liu's avatar
Chao Liu committed
112
        return profile(F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
zjing14's avatar
zjing14 committed
113
    }
Chao Liu's avatar
Chao Liu committed
114
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
zjing14's avatar
zjing14 committed
115
    {
Chao Liu's avatar
Chao Liu committed
116
        return profile(F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
zjing14's avatar
zjing14 committed
117
    }
Chao Liu's avatar
Chao Liu committed
118
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
zjing14's avatar
zjing14 committed
119
    {
Chao Liu's avatar
Chao Liu committed
120
        return profile(F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
zjing14's avatar
zjing14 committed
121
    }
Chao Liu's avatar
Chao Liu committed
122
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
123
    {
Chao Liu's avatar
Chao Liu committed
124
        return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
125
    }
Chao Liu's avatar
Chao Liu committed
126
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
127
    {
Chao Liu's avatar
Chao Liu committed
128
        return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
129
    }
Chao Liu's avatar
Chao Liu committed
130
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
131
    {
Chao Liu's avatar
Chao Liu committed
132
        return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
133
    }
Chao Liu's avatar
Chao Liu committed
134
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
Jianfeng Yan's avatar
Jianfeng Yan committed
135
    {
Chao Liu's avatar
Chao Liu committed
136
        return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
Jianfeng Yan's avatar
Jianfeng Yan committed
137
    }
Chao Liu's avatar
Chao Liu committed
138
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
139
    {
Chao Liu's avatar
Chao Liu committed
140
        return profile(BF16{}, BF16{}, BF16{}, Row{}, Row{}, Row{});
141
    }
Chao Liu's avatar
Chao Liu committed
142
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
143
    {
Chao Liu's avatar
Chao Liu committed
144
        return profile(BF16{}, BF16{}, BF16{}, Row{}, Col{}, Row{});
145
    }
Chao Liu's avatar
Chao Liu committed
146
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
147
    {
Chao Liu's avatar
Chao Liu committed
148
        return profile(BF16{}, BF16{}, BF16{}, Col{}, Row{}, Row{});
149
    }
Chao Liu's avatar
Chao Liu committed
150
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
151
    {
Chao Liu's avatar
Chao Liu committed
152
        return profile(BF16{}, BF16{}, BF16{}, Col{}, Col{}, Row{});
153
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
154
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
155
    {
Chao Liu's avatar
Chao Liu committed
156
        return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{});
157
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
158
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
159
    {
Chao Liu's avatar
Chao Liu committed
160
        return profile(INT8{}, INT8{}, INT8{}, Row{}, Col{}, Row{});
161
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
162
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
163
    {
Chao Liu's avatar
Chao Liu committed
164
        return profile(INT8{}, INT8{}, INT8{}, Col{}, Row{}, Row{});
165
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
166
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
167
    {
Chao Liu's avatar
Chao Liu committed
168
        return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{});
169
    }
zjing14's avatar
zjing14 committed
170
171
    else
    {
Chao Liu's avatar
Chao Liu committed
172
        std::cout << "this data_type & layout is not implemented" << std::endl;
zjing14's avatar
zjing14 committed
173

Chao Liu's avatar
Chao Liu committed
174
175
        return 1;
    }
zjing14's avatar
zjing14 committed
176
}