// profile_contraction_bilinear.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <vector>

#include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
#include "profiler_operation_registry.hpp"

#define OP_NAME "contraction_bilinear"
#define OP_DESC "CONTRACTION+Bilinear"

/// Prints the command-line usage of the contraction_bilinear profiler operation to stdout.
///
/// The "argN" labels are the argv indices read by profile_contraction_bilinear(). The stride
/// groups are listed in the order that function parses them: A, B, E, then D.
static void print_helper_msg()
{
    std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
              << "arg2: data type (0: fp32; 1: f64)\n"
              << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
                 "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
              << "                     1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
                 "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
              << "                     2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + "
                 "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
              << "                     3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
                 "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
              << "arg4: verification (0: no; 1: yes)\n"
              << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
              << "arg6: print tensor value (0: no; 1: yes)\n"
              << "arg7: time kernel (0: no, 1: yes)\n"
              << "arg8 and arg9: alpha and beta\n"
              << "arg10 to 15: M0, M1, N0, N1, K0, K1\n"
              // Order matches the parser: argv[16..19]=A, [20..23]=B, [24..27]=E, [28..31]=D.
              << "arg16 to 31: Strides for A, B, E and D (skip for default)\n"
              << std::endl;
}

int profile_contraction_bilinear(int argc, char* argv[])
{
42
    const bool default_strides = argc == 16;
43

44
    if(argc != 32 && argc != 16)
45
46
47
48
49
50
    {
        print_helper_msg();
        exit(1);
    }

    const auto data_type          = static_cast<ContractionDataType>(std::stoi(argv[2]));
51
52
53
54
55
56
57
    const auto layout             = static_cast<ContractionMatrixLayout>(std::stoi(argv[3]));
    const bool do_verification    = std::stoi(argv[4]);
    const ck::index_t init_method = std::stoi(argv[5]);
    const bool do_log             = std::stoi(argv[6]);
    const bool time_kernel        = std::stoi(argv[7]);
    const float alpha             = std::stof(argv[8]);
    const float beta              = std::stof(argv[9]);
58
59
60
61

    std::vector<ck::index_t> M;
    std::vector<ck::index_t> N;
    std::vector<ck::index_t> K;
62
    const ck::index_t dims_arg_num = 10;
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
    collect_index_params(argv, M, dims_arg_num, 2);
    collect_index_params(argv, N, dims_arg_num + 2, 2);
    collect_index_params(argv, K, dims_arg_num + 4, 2);

    std::vector<ck::index_t> StridesA;
    std::vector<ck::index_t> StridesB;
    std::vector<ck::index_t> StridesE;
    std::vector<ck::index_t> StridesD;
    if(!default_strides)
    {
        collect_index_params(argv, StridesA, dims_arg_num + 6, 4);
        collect_index_params(argv, StridesB, dims_arg_num + 10, 4);
        collect_index_params(argv, StridesE, dims_arg_num + 14, 4);
        collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
    }

79
80
81
82
83
84
85
86
87
88
89
    using F32 = float;
    using F64 = double;

    auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) {
        using ALayout   = decltype(a_layout);
        using BLayout   = decltype(b_layout);
        using CDELayout = decltype(cde_layout);

        using DataType = decltype(type);

        if(default_strides)
90
        {
91
92
93
94
            assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
            assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
            assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
            assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
95
        }
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
        bool pass = ck::profiler::profile_contraction_impl<ALayout,
                                                           BLayout,
                                                           CDELayout,
                                                           DataType,
                                                           ck::Tuple<DataType>,
                                                           Bilinear>(do_verification,
                                                                     init_method,
                                                                     do_log,
                                                                     time_kernel,
                                                                     Bilinear{alpha, beta},
                                                                     M,
                                                                     N,
                                                                     K,
                                                                     StridesA,
                                                                     StridesB,
                                                                     StridesE,
                                                                     StridesD);

        return pass;
115
116
    };

117
118
    if(data_type == ContractionDataType::F32_F32_F32_F32 &&
       layout == ContractionMatrixLayout::MK_KN_MN_MN)
119
    {
120
        return profile(Row{}, Row{}, Row{}, F32{});
121
    }
122
123
    else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
            layout == ContractionMatrixLayout::MK_NK_MN_MN)
124
    {
125
        return profile(Row{}, Col{}, Row{}, F32{});
126
    }
127
128
    else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
            layout == ContractionMatrixLayout::KM_KN_MN_MN)
129
    {
130
        return profile(Col{}, Row{}, Row{}, F32{});
131
    }
132
133
    else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
            layout == ContractionMatrixLayout::KM_NK_MN_MN)
134
    {
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
        return profile(Col{}, Col{}, Row{}, F32{});
    }
    else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
            layout == ContractionMatrixLayout::MK_KN_MN_MN)
    {
        return profile(Row{}, Row{}, Row{}, F64{});
    }
    else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
            layout == ContractionMatrixLayout::MK_NK_MN_MN)
    {
        return profile(Row{}, Col{}, Row{}, F64{});
    }
    else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
            layout == ContractionMatrixLayout::KM_KN_MN_MN)
    {
        return profile(Col{}, Row{}, Row{}, F64{});
    }
    else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
            layout == ContractionMatrixLayout::KM_NK_MN_MN)
    {
        return profile(Col{}, Col{}, Row{}, F64{});
    }
    else
    {
        std::cout << "this data_type & layout is not implemented" << std::endl;

        return 1;
162
163
164
165
    }
}

// Passes the operation name, its description and the entry point defined in this file to
// REGISTER_PROFILER_OPERATION. NOTE(review): the macro is not defined in this file —
// presumably it comes from profiler_operation_registry.hpp and makes the operation
// selectable by name; confirm there.
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear);