// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include #include #include #include "ck/ck.hpp" namespace ck { namespace utils { namespace conv { struct ConvParam { ConvParam(); ConvParam(ck::index_t n_dim, ck::index_t group_count, ck::index_t n_batch, ck::index_t n_out_channels, ck::index_t n_in_channels, const std::vector& filters_len, const std::vector& input_len, const std::vector& strides, const std::vector& dilations, const std::vector& left_pads, const std::vector& right_pads); ck::index_t num_dim_spatial_; ck::index_t G_; ck::index_t N_; ck::index_t K_; ck::index_t C_; std::vector filter_spatial_lengths_; std::vector input_spatial_lengths_; std::vector output_spatial_lengths_; std::vector conv_filter_strides_; std::vector conv_filter_dilations_; std::vector input_left_pads_; std::vector input_right_pads_; std::vector GetOutputSpatialLengths() const; std::size_t GetFlops() const; template std::size_t GetByte() const { // sizeof(InDataType) * (G * N * C * ) + // sizeof(WeiDataType) * (G * K * C * ) + // sizeof(OutDataType) * (G * N * K * ); return sizeof(InDataType) * (G_ * N_ * C_ * std::accumulate(std::begin(input_spatial_lengths_), std::begin(input_spatial_lengths_) + num_dim_spatial_, static_cast(1), std::multiplies())) + sizeof(WeiDataType) * (G_ * K_ * C_ * std::accumulate(std::begin(filter_spatial_lengths_), std::begin(filter_spatial_lengths_) + num_dim_spatial_, static_cast(1), std::multiplies())) + sizeof(OutDataType) * (G_ * N_ * K_ * std::accumulate(std::begin(output_spatial_lengths_), std::end(output_spatial_lengths_), static_cast(1), std::multiplies())); } }; std::string get_conv_param_parser_helper_msg(); ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]); } // namespace conv } // namespace utils } // namespace ck std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p);