// File: profile_contraction_bilinear.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <vector>

#include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
#include "profiler_operation_registry.hpp"

// Operation identifier and human-readable description. Both are spliced into
// the usage text and passed to REGISTER_PROFILER_OPERATION at the end of this file.
#define OP_NAME "contraction_bilinear"
#define OP_DESC "CONTRACTION+Bilinear"

/// Prints the command-line usage of the contraction+bilinear profiler to stdout.
///
/// The stride order documented for arg17..arg32 follows the order in which
/// profile_contraction_bilinear() reads them from argv: A, B, E, then D.
static void print_helper_msg()
{
    std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
              << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
              << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
              << "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
                 "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
              << "                     1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
                 "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
              << "                     2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + "
                 "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
              << "                     3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
                 "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
              << "arg5: verification (0: no; 1: yes)\n"
              << "arg6: initialization (0: no init; 1: integer value; 2: decimal "
              << "value)\n"
              << "arg7: print tensor value (0: no; 1: yes)\n"
              << "arg8: time kernel (0: no, 1: yes)\n"
              << "arg9: alpha\n"
              << "arg10: beta\n"
              << "arg11 to 16: M0, M1, N0, N1, K0, K1\n"
              // Order matches the argv parsing in profile_contraction_bilinear().
              << "arg17 to 32: Strides for A, B, E and D (skip for default)\n"
              << std::endl;
}

int profile_contraction_bilinear(int argc, char* argv[])
{
44
    const bool default_strides = argc == 17;
45

46
    if(argc != 33 && argc != 17)
47
48
49
50
51
52
    {
        print_helper_msg();
        exit(1);
    }

    const auto data_type          = static_cast<ContractionDataType>(std::stoi(argv[2]));
53
54
55
56
57
58
59
60
    const auto compute_data_type  = static_cast<ContractionComputeDataType>(std::stoi(argv[3]));
    const auto layout             = static_cast<ContractionMatrixLayout>(std::stoi(argv[4]));
    const bool do_verification    = std::stoi(argv[5]);
    const ck::index_t init_method = std::stoi(argv[6]);
    const bool do_log             = std::stoi(argv[7]);
    const bool time_kernel        = std::stoi(argv[8]);
    const float alpha             = std::stof(argv[9]);
    const float beta              = std::stof(argv[10]);
61
62
63
64

    std::vector<ck::index_t> M;
    std::vector<ck::index_t> N;
    std::vector<ck::index_t> K;
65
    const ck::index_t dims_arg_num = 11;
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
    collect_index_params(argv, M, dims_arg_num, 2);
    collect_index_params(argv, N, dims_arg_num + 2, 2);
    collect_index_params(argv, K, dims_arg_num + 4, 2);

    std::vector<ck::index_t> StridesA;
    std::vector<ck::index_t> StridesB;
    std::vector<ck::index_t> StridesE;
    std::vector<ck::index_t> StridesD;
    if(!default_strides)
    {
        collect_index_params(argv, StridesA, dims_arg_num + 6, 4);
        collect_index_params(argv, StridesB, dims_arg_num + 10, 4);
        collect_index_params(argv, StridesE, dims_arg_num + 14, 4);
        collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
    }

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
    using F16  = ck::half_t;
    using BF16 = ck::bhalf_t;
    using F32  = float;
    using F64  = double;

    auto profile =
        [&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) {
            using ALayout   = decltype(a_layout);
            using BLayout   = decltype(b_layout);
            using CDELayout = decltype(cde_layout);

            using DataType        = decltype(type);
            using ComputeDataType = decltype(compute_type);

            if(default_strides)
            {
                assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
                assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
                assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
                assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
            }
            bool pass = ck::profiler::profile_contraction_impl<ALayout,
                                                               BLayout,
                                                               CDELayout,
                                                               DataType,
                                                               ComputeDataType,
                                                               ck::Tuple<DataType>,
                                                               Bilinear>(do_verification,
                                                                         init_method,
                                                                         do_log,
                                                                         time_kernel,
                                                                         Bilinear{alpha, beta},
                                                                         M,
                                                                         N,
                                                                         K,
                                                                         StridesA,
                                                                         StridesB,
                                                                         StridesE,
                                                                         StridesD);

            return pass;
        };

    auto run_profile_for_datatype = [&](auto type, auto compute_type) {
        if(layout == ContractionMatrixLayout::MK_KN_MN_MN)
127
        {
128
            return profile(Row{}, Row{}, Row{}, type, compute_type);
129
        }
130
131
132
133
134
135
136
137
138
139
140
141
142
        else if(layout == ContractionMatrixLayout::MK_NK_MN_MN)
        {
            return profile(Row{}, Col{}, Row{}, type, compute_type);
        }
        else if(layout == ContractionMatrixLayout::KM_KN_MN_MN)
        {
            return profile(Col{}, Row{}, Row{}, type, compute_type);
        }
        else if(layout == ContractionMatrixLayout::KM_NK_MN_MN)
        {
            return profile(Col{}, Col{}, Row{}, type, compute_type);
        }
        return false;
143
144
    };

145
    if(data_type == ContractionDataType::F32_F32_F32_F32)
146
    {
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
        if(compute_data_type == ContractionComputeDataType::F32)
        {
            return run_profile_for_datatype(F32{}, F32{});
        }
        else if(compute_data_type == ContractionComputeDataType::F16)
        {
            return run_profile_for_datatype(F32{}, F16{});
        }
        else if(compute_data_type == ContractionComputeDataType::BF16)
        {
            return run_profile_for_datatype(F32{}, BF16{});
        }
        else
        {
            std::cout << "Incorrect combination of data type and compute data type." << std::endl;
            return 1;
        }
164
    }
165
    else if(data_type == ContractionDataType::F64_F64_F64_F64)
166
    {
167
168
169
170
171
172
173
174
175
176
177
178
179
        if(compute_data_type == ContractionComputeDataType::F64)
        {
            return run_profile_for_datatype(F64{}, F64{});
        }
        else if(compute_data_type == ContractionComputeDataType::F32)
        {
            return run_profile_for_datatype(F64{}, F32{});
        }
        else
        {
            std::cout << "Incorrect combination of data type and compute data type." << std::endl;
            return 1;
        }
180
    }
181
    else if(data_type == ContractionDataType::F16_F16_F16_F16)
182
    {
183
184
185
186
187
188
189
190
191
        if(compute_data_type == ContractionComputeDataType::F32)
        {
            return run_profile_for_datatype(F16{}, F32{});
        }
        else
        {
            std::cout << "Incorrect combination of data type and compute data type." << std::endl;
            return 1;
        }
192
    }
193
    else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16)
194
    {
195
196
197
198
199
200
201
202
203
        if(compute_data_type == ContractionComputeDataType::F32)
        {
            return run_profile_for_datatype(BF16{}, F32{});
        }
        else
        {
            std::cout << "Incorrect combination of data type and compute data type." << std::endl;
            return 1;
        }
204
    }
205
    return 1;
206
207
208
}

// Associates profile_contraction_bilinear with OP_NAME / OP_DESC.
// NOTE(review): the macro is defined outside this file (presumably in
// profiler_operation_registry.hpp) — confirm how registered operations are dispatched.
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear);