convolution_parameter.hpp 3.14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <numeric>
#include <ostream>
#include <string>
#include <vector>

#include "ck/ck.hpp"

namespace ck {
namespace utils {
namespace conv {

struct ConvParam
{
    ConvParam();
    ConvParam(ck::index_t n_dim,
              ck::index_t group_count,
              ck::index_t n_batch,
              ck::index_t n_out_channels,
              ck::index_t n_in_channels,
              const std::vector<ck::index_t>& filters_len,
              const std::vector<ck::index_t>& input_len,
              const std::vector<ck::index_t>& strides,
              const std::vector<ck::index_t>& dilations,
              const std::vector<ck::index_t>& left_pads,
              const std::vector<ck::index_t>& right_pads);

    ck::index_t num_dim_spatial_;
    ck::index_t G_;
    ck::index_t N_;
    ck::index_t K_;
    ck::index_t C_;

    std::vector<ck::index_t> filter_spatial_lengths_;
    std::vector<ck::index_t> input_spatial_lengths_;
    std::vector<ck::index_t> output_spatial_lengths_;

    std::vector<ck::index_t> conv_filter_strides_;
    std::vector<ck::index_t> conv_filter_dilations_;

    std::vector<ck::index_t> input_left_pads_;
    std::vector<ck::index_t> input_right_pads_;

    std::vector<ck::index_t> GetOutputSpatialLengths() const;

    std::size_t GetFlops() const;

    template <typename InDataType, typename WeiDataType, typename OutDataType>
    std::size_t GetByte() const
    {
        // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
        // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
        // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
        return sizeof(InDataType) *
                   (G_ * N_ * C_ *
                    std::accumulate(std::begin(input_spatial_lengths_),
                                    std::begin(input_spatial_lengths_) + num_dim_spatial_,
                                    static_cast<std::size_t>(1),
                                    std::multiplies<std::size_t>())) +
               sizeof(WeiDataType) *
                   (G_ * K_ * C_ *
                    std::accumulate(std::begin(filter_spatial_lengths_),
                                    std::begin(filter_spatial_lengths_) + num_dim_spatial_,
                                    static_cast<std::size_t>(1),
                                    std::multiplies<std::size_t>())) +
               sizeof(OutDataType) * (G_ * N_ * K_ *
                                      std::accumulate(std::begin(output_spatial_lengths_),
                                                      std::end(output_spatial_lengths_),
                                                      static_cast<std::size_t>(1),
                                                      std::multiplies<std::size_t>()));
    }
};

std::string get_conv_param_parser_helper_msg();

ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]);

} // namespace conv
} // namespace utils
} // namespace ck

// Stream-insertion operator for ck::utils::conv::ConvParam; the definition is out of line.
// NOTE(review): this is declared in the global namespace rather than in ck::utils::conv,
// so argument-dependent lookup will not find it. `os << p` written inside a namespace that
// declares its own operator<< overloads may therefore not see this one. Moving it into
// ck::utils::conv requires moving the out-of-line definition as well.
auto operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p) -> std::ostream&;