profile_normalization.cpp 6.32 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cctype>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include "profiler/include/profile_normalization_impl.hpp"

using ck::index_t;
using ck::profiler::NormDataType;
using ck::profiler::NormType;

struct ArgParser
{
rocking5566's avatar
rocking5566 committed
16
    std::unordered_map<std::string, NormType> norm_dict = {{"batchnorm", NormType::BATCHNORM},
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
                                                           {"softmax", NormType::SOFTMAX}};

    std::unordered_map<std::string, std::vector<int>> long_opts = {
        {"length", {}}, {"stride", {}}, {"reduce", {}}, {"alpha", {}}, {"beta", {}}};

    bool parse_opt(int argc, char* argv[], const std::string& key, int i)
    {
        if(std::string("--") + key == argv[i])
        {
            int pos = i;
            while(++i < argc && argv[i][0] != '-') {}
            int end = i;
            for(int j = pos + 1; j < end; j++)
            {
                long_opts[key].push_back(std::stoi(argv[j]));
            }
            return true;
        }
        return false;
    }

    void operator()(int argc, char* argv[])
    {
        for(auto& kv : long_opts)
        {
            for(int i = 1; i < argc; i++)
            {
                if(parse_opt(argc, argv, kv.first, i))
                    break;
            }
        }
    }
};

// Prints the usage of the normalization profiler to stdout.
// The positional arguments are read from argv[1]..argv[6] by profile_normalization; the
// "--" options are collected by ArgParser as integers (so --alpha/--beta take integer values).
void print_help()
{
    std::cout << "arg1: tensor operation (batchnorm/softmax)\n"
              << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
              << "arg3: verification (0: no; 1: yes)\n"
              << "arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n"
              << "arg5: print tensor value (0: no; 1: yes)\n"
              << "arg6: time kernel (0: no; 1: yes)\n"
              << "--length: tensor extents (e.g, --length 8 4 256) \n"
              << "--stride: tensor strides (e.g, --stride 1024 256 1)\n"
              << "--reduce: to-reduce dimensions (e.g, --reduce 2)\n"
              << "--alpha: alpha scaling value\n"
              << "--beta: beta scaling value\n"
              << std::endl;
}

int profile_normalization(int argc, char* argv[])
{
    if(argc <= 2)
    {
        print_help();
        return 0;
    }

    ArgParser arg_parser;

    // short unnamed options
    const NormType norm_type     = arg_parser.norm_dict[argv[1]];
    const NormDataType data_type = static_cast<NormDataType>(std::stoi(argv[2]));
    const bool do_verification   = std::stoi(argv[3]);
    const int init_method        = std::stoi(argv[4]);
    const bool do_log            = std::stoi(argv[5]);
    const bool time_kernel       = std::stoi(argv[6]);

    // parse the long options
    arg_parser(argc, argv);
    const std::vector<index_t> length = arg_parser.long_opts["length"];
    const std::vector<index_t> stride = arg_parser.long_opts["stride"];
    const std::vector<index_t> reduce = arg_parser.long_opts["reduce"];
    const index_t alpha =
        arg_parser.long_opts["alpha"].empty() ? 1 : arg_parser.long_opts["alpha"][0];
    const index_t beta = arg_parser.long_opts["beta"].empty() ? 0 : arg_parser.long_opts["beta"][0];

Adam Osewski's avatar
Adam Osewski committed
94
    if(length.size() == 3)
95
    {
Adam Osewski's avatar
Adam Osewski committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
        if(data_type == NormDataType::F16_F16)
        {
            ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t, 3>(
                do_verification,
                init_method,
                do_log,
                time_kernel,
                length,
                stride,
                reduce,
                float(alpha),
                float(beta),
                norm_type);
        }
        else if(data_type == NormDataType::F32_F32)
        {
            ck::profiler::profile_normalization_impl<float, float, float, 3>(do_verification,
                                                                             init_method,
                                                                             do_log,
                                                                             time_kernel,
                                                                             length,
                                                                             stride,
                                                                             reduce,
                                                                             float(alpha),
                                                                             float(beta),
                                                                             norm_type);
        }
        else
        {
            throw std::runtime_error("not implemented yet");
        }
127
    }
Adam Osewski's avatar
Adam Osewski committed
128
    else if(length.size() == 4)
129
    {
Adam Osewski's avatar
Adam Osewski committed
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
        if(data_type == NormDataType::F16_F16)
        {
            ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t, 4>(
                do_verification,
                init_method,
                do_log,
                time_kernel,
                length,
                stride,
                reduce,
                float(alpha),
                float(beta),
                norm_type);
        }
        else if(data_type == NormDataType::F32_F32)
        {
            ck::profiler::profile_normalization_impl<float, float, float, 4>(do_verification,
                                                                             init_method,
                                                                             do_log,
                                                                             time_kernel,
                                                                             length,
                                                                             stride,
                                                                             reduce,
                                                                             float(alpha),
                                                                             float(beta),
                                                                             norm_type);
        }
        else
        {
            throw std::runtime_error("not implemented yet");
        }
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
    }
    else
    {
        throw std::runtime_error("not implemented yet");
    }

    return 0;
}

// hijack main() for quick debugging
// int main(int argc, char* argv[])
// {
//     profile_normalization(argc, argv);
//     return 0;
// }