convolution_parameter.hpp 3.15 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <cstdlib>
#include <functional>
#include <iosfwd>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>

#include "ck/ck.hpp"

namespace ck {
Chao Liu's avatar
clean  
Chao Liu committed
14
15
namespace utils {
namespace conv {
Chao Liu's avatar
Chao Liu committed
16

Chao Liu's avatar
clean  
Chao Liu committed
17
struct ConvParam
Chao Liu's avatar
Chao Liu committed
18
{
Chao Liu's avatar
clean  
Chao Liu committed
19
20
    ConvParam();
    ConvParam(ck::index_t n_dim,
Chao Liu's avatar
add G  
Chao Liu committed
21
              ck::index_t group_count,
Chao Liu's avatar
clean  
Chao Liu committed
22
23
24
25
26
27
28
29
30
              ck::index_t n_batch,
              ck::index_t n_out_channels,
              ck::index_t n_in_channels,
              const std::vector<ck::index_t>& filters_len,
              const std::vector<ck::index_t>& input_len,
              const std::vector<ck::index_t>& strides,
              const std::vector<ck::index_t>& dilations,
              const std::vector<ck::index_t>& left_pads,
              const std::vector<ck::index_t>& right_pads);
Chao Liu's avatar
Chao Liu committed
31
32

    ck::index_t num_dim_spatial_;
Chao Liu's avatar
add G  
Chao Liu committed
33
    ck::index_t G_;
Chao Liu's avatar
Chao Liu committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
    ck::index_t N_;
    ck::index_t K_;
    ck::index_t C_;

    std::vector<ck::index_t> filter_spatial_lengths_;
    std::vector<ck::index_t> input_spatial_lengths_;
    std::vector<ck::index_t> output_spatial_lengths_;

    std::vector<ck::index_t> conv_filter_strides_;
    std::vector<ck::index_t> conv_filter_dilations_;

    std::vector<ck::index_t> input_left_pads_;
    std::vector<ck::index_t> input_right_pads_;

    std::vector<ck::index_t> GetOutputSpatialLengths() const;

    std::size_t GetFlops() const;

    template <typename InDataType, typename WeiDataType, typename OutDataType>
    std::size_t GetByte() const
    {
Chao Liu's avatar
add G  
Chao Liu committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
        // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
        // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
        // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
        return sizeof(InDataType) *
                   (G_ * N_ * C_ *
                    std::accumulate(std::begin(input_spatial_lengths_),
                                    std::begin(input_spatial_lengths_) + num_dim_spatial_,
                                    static_cast<std::size_t>(1),
                                    std::multiplies<std::size_t>())) +
               sizeof(WeiDataType) *
                   (G_ * K_ * C_ *
                    std::accumulate(std::begin(filter_spatial_lengths_),
                                    std::begin(filter_spatial_lengths_) + num_dim_spatial_,
                                    static_cast<std::size_t>(1),
                                    std::multiplies<std::size_t>())) +
               sizeof(OutDataType) * (G_ * N_ * K_ *
Chao Liu's avatar
Chao Liu committed
71
72
73
74
75
76
77
                                      std::accumulate(std::begin(output_spatial_lengths_),
                                                      std::end(output_spatial_lengths_),
                                                      static_cast<std::size_t>(1),
                                                      std::multiplies<std::size_t>()));
    }
};

Chao Liu's avatar
clean  
Chao Liu committed
78
79
} // namespace conv
} // namespace utils
Chao Liu's avatar
Chao Liu committed
80
81
} // namespace ck

Chao Liu's avatar
clean  
Chao Liu committed
82
/// Writes a description of `p` to `os` and returns `os`.
/// NOTE(review): declared in the global namespace rather than in
/// ck::utils::conv. Argument-dependent lookup does not search the global
/// namespace, so code inside a namespace that declares its own operator<<
/// may not find this overload.
auto operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p) -> std::ostream&;

/// Returns a help message for the convolution-parameter command-line parser
/// (presumably describing the arguments parse_conv_param expects — TODO
/// confirm against the definition).
auto get_conv_param_parser_helper_msg() -> std::string;

/// Builds a ConvParam with `num_dim_spatial` spatial dimensions from `argv`,
/// presumably reading arguments starting at index `arg_idx` (inferred from
/// the parameter names — verify against the definition).
auto parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
    -> ck::utils::conv::ConvParam;