"include/ck/utility/math.hpp" did not exist on "7d09790a0ac2be6c150d25654fdd9d05d392b34f"
convolution_parameter.hpp 3.88 KB
Newer Older
1
// SPDX-License-Identifier: MIT
2
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4
5
6
7
8
9
10
11
12

#pragma once

#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>

#include "ck/ck.hpp"

13
14
#include "ck/library/utility/numeric.hpp"

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
namespace ck {
namespace utils {
namespace conv {

struct ConvParam
{
    ConvParam();
    ConvParam(ck::index_t n_dim,
              ck::index_t group_count,
              ck::index_t n_batch,
              ck::index_t n_out_channels,
              ck::index_t n_in_channels,
              const std::vector<ck::index_t>& filters_len,
              const std::vector<ck::index_t>& input_len,
              const std::vector<ck::index_t>& strides,
              const std::vector<ck::index_t>& dilations,
              const std::vector<ck::index_t>& left_pads,
              const std::vector<ck::index_t>& right_pads);

34
35
36
37
38
39
40
41
42
43
44
    ConvParam(ck::long_index_t n_dim,
              ck::long_index_t group_count,
              ck::long_index_t n_batch,
              ck::long_index_t n_out_channels,
              ck::long_index_t n_in_channels,
              const std::vector<ck::long_index_t>& filters_len,
              const std::vector<ck::long_index_t>& input_len,
              const std::vector<ck::long_index_t>& strides,
              const std::vector<ck::long_index_t>& dilations,
              const std::vector<ck::long_index_t>& left_pads,
              const std::vector<ck::long_index_t>& right_pads);
45

46
47
48
49
50
    ck::long_index_t num_dim_spatial_;
    ck::long_index_t G_;
    ck::long_index_t N_;
    ck::long_index_t K_;
    ck::long_index_t C_;
51

52
53
54
    std::vector<ck::long_index_t> filter_spatial_lengths_;
    std::vector<ck::long_index_t> input_spatial_lengths_;
    std::vector<ck::long_index_t> output_spatial_lengths_;
55

56
57
    std::vector<ck::long_index_t> conv_filter_strides_;
    std::vector<ck::long_index_t> conv_filter_dilations_;
58

59
60
61
62
    std::vector<ck::long_index_t> input_left_pads_;
    std::vector<ck::long_index_t> input_right_pads_;

    std::vector<ck::long_index_t> GetOutputSpatialLengths() const;
63
64
65

    std::size_t GetFlops() const;

66
67
    template <typename InDataType>
    std::size_t GetInputByte() const
68
69
    {
        // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
70
71
        return sizeof(InDataType) *
               (G_ * N_ * C_ *
72
73
                ck::accumulate_n<std::size_t>(
                    std::begin(input_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()));
74
75
76
77
78
    }

    template <typename WeiDataType>
    std::size_t GetWeightByte() const
    {
79
        // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
80
81
        return sizeof(WeiDataType) *
               (G_ * K_ * C_ *
82
83
                ck::accumulate_n<std::size_t>(
                    std::begin(filter_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()));
84
85
86
87
88
    }

    template <typename OutDataType>
    std::size_t GetOutputByte() const
    {
89
        // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
90
        return sizeof(OutDataType) * (G_ * N_ * K_ *
91
92
93
94
95
                                      std::accumulate(std::begin(output_spatial_lengths_),
                                                      std::end(output_spatial_lengths_),
                                                      static_cast<std::size_t>(1),
                                                      std::multiplies<std::size_t>()));
    }
96
97
98
99
100
101
102

    template <typename InDataType, typename WeiDataType, typename OutDataType>
    std::size_t GetByte() const
    {
        return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
               GetOutputByte<OutDataType>();
    }
103
104
105
106
107
108
109
110
111
112
113
};

/// Returns a help/usage message for the conv-param argument parser.
/// Defined out of line; the exact text lives there.
std::string get_conv_param_parser_helper_msg();

/// Builds a ConvParam from command-line style arguments. Defined out of line.
/// @param num_dim_spatial number of spatial dimensions to parse for.
/// @param arg_idx         index into argv at which parsing starts (presumably
///                        the first conv-parameter argument — confirm in the .cpp).
/// @param argv            argument vector; `char* const*` is the adjusted form
///                        of the array parameter `char* const argv[]`.
ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const* argv);

} // namespace conv
} // namespace utils
} // namespace ck

std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p);