Commit 3db77bc4 (unverified), authored by Mateusz Ozga, committed by GitHub
Browse files

Simplify static_cast if-lands (#1828)

parent 3c93d3c4
...@@ -109,4 +109,22 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x) ...@@ -109,4 +109,22 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
#pragma clang diagnostic pop #pragma clang diagnostic pop
} }
// Type trait: checks whether `CompareTo` is exactly the same type as at least
// one of the types listed in `Rest...`.
//
// Template parameters:
//   CompareTo - the type being looked up.
//   Rest      - the candidate types to compare against (may be empty).
//
// Result:
//   `is_any_of<CompareTo, Rest...>::value` is `true` when
//   `std::is_same<CompareTo, T>::value` holds for some `T` in `Rest...`, and
//   `false` otherwise. An empty candidate list yields `false`.
//
// Implemented with std::disjunction so that no recursive partial
// specializations have to be maintained by hand; std::disjunction<> with an
// empty pack derives from std::false_type, which covers the "no candidates"
// case. The comparison is exact: cv-qualifiers and references are not stripped.
template <typename CompareTo, typename... Rest>
struct is_any_of : std::disjunction<std::is_same<CompareTo, Rest>...>
{
};
} // namespace ck_tile } // namespace ck_tile
...@@ -28,14 +28,11 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -28,14 +28,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> || static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
std::is_same_v<ComputeDataType, BF16> ||
std::is_same_v<ComputeDataType, F32> || std::is_same_v<ComputeDataType, I8> ||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"); "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0; double compute_error = 0;
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> || if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
std::is_same_v<ComputeDataType, int>)
{ {
return 0; return 0;
} }
...@@ -44,14 +41,11 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -44,14 +41,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5; compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
} }
static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> || static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
std::is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the relative threshold!"); "Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0; double output_error = 0;
if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> || if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
std::is_same_v<OutDataType, int>)
{ {
return 0; return 0;
} }
...@@ -61,14 +55,11 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -61,14 +55,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
} }
double midway_error = std::max(compute_error, output_error); double midway_error = std::max(compute_error, output_error);
static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> || static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
std::is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the relative threshold!"); "Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0; double acc_error = 0;
if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> || if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
std::is_same_v<AccDataType, int>)
{ {
return 0; return 0;
} }
...@@ -89,15 +80,12 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -89,15 +80,12 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> || static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
std::is_same_v<ComputeDataType, BF16> ||
std::is_same_v<ComputeDataType, F32> || std::is_same_v<ComputeDataType, I8> ||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num)); auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0; double compute_error = 0;
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> || if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
std::is_same_v<ComputeDataType, int>)
{ {
return 0; return 0;
} }
...@@ -106,14 +94,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -106,14 +94,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5; compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
} }
static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> || static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
std::is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"); "Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0; double output_error = 0;
if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> || if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
std::is_same_v<OutDataType, int>)
{ {
return 0; return 0;
} }
...@@ -123,14 +108,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -123,14 +108,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
} }
double midway_error = std::max(compute_error, output_error); double midway_error = std::max(compute_error, output_error);
static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> || static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
std::is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"); "Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0; double acc_error = 0;
if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> || if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
std::is_same_v<AccDataType, int>)
{ {
return 0; return 0;
} }
......
...@@ -14,57 +14,41 @@ namespace detail { ...@@ -14,57 +14,41 @@ namespace detail {
template <typename OldLayout> template <typename OldLayout>
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old() CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{ {
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> || using namespace ck_tile::tensor_layout::convolution;
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>) if constexpr(is_any_of<OldLayout, GNCW, GKCX, GNKW>::value)
{ {
return {0, 1, 2, 3}; return {0, 1, 2, 3};
} }
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> || else if constexpr(is_any_of<OldLayout, GNCHW, GKCYX, GNKHW>::value)
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
{ {
return {0, 1, 2, 3, 4}; return {0, 1, 2, 3, 4};
} }
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> || else if constexpr(is_any_of<OldLayout, GNCDHW, GKCZYX, GNKDHW>::value)
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
{ {
return {0, 1, 2, 3, 4, 5}; return {0, 1, 2, 3, 4, 5};
} }
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> || if constexpr(is_any_of<OldLayout, GNWC, GKXC, GNWK>::value)
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
{ {
return {0, 1, 3, 2}; return {0, 1, 3, 2};
} }
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> || else if constexpr(is_any_of<OldLayout, GNHWC, GKYXC, GNHWK>::value)
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
{ {
return {0, 1, 4, 2, 3}; return {0, 1, 4, 2, 3};
} }
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> || else if constexpr(is_any_of<OldLayout, GNDHWC, GKZYXC, GNDHWK>::value)
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
{ {
return {0, 1, 5, 2, 3, 4}; return {0, 1, 5, 2, 3, 4};
} }
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> || else if constexpr(is_any_of<OldLayout, NWGC, KXGC, NWGK>::value)
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
{ {
return {2, 0, 3, 1}; return {2, 0, 3, 1};
} }
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> || else if constexpr(is_any_of<OldLayout, NHWGC, KYXGC, NHWGK>::value)
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
{ {
return {3, 0, 4, 1, 2}; return {3, 0, 4, 1, 2};
} }
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> || else if constexpr(is_any_of<OldLayout, NDHWGC, KZYXGC, NDHWGK>::value)
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
{ {
return {4, 0, 5, 1, 2, 3}; return {4, 0, 5, 1, 2, 3};
} }
...@@ -83,11 +67,11 @@ template <typename InLayout> ...@@ -83,11 +67,11 @@ template <typename InLayout>
CK_TILE_HOST HostTensorDescriptor CK_TILE_HOST HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param) make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
{ {
using namespace ck_tile::tensor_layout::convolution;
std::vector<std::size_t> physical_lengths; std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> || if constexpr(is_any_of<InLayout, GNCW, GNCHW, GNCDHW>::value)
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
{ {
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_), physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_), static_cast<std::size_t>(param.N_),
...@@ -97,9 +81,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara ...@@ -97,9 +81,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
param.input_spatial_lengths_.begin(), param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_); param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> || else if constexpr(is_any_of<InLayout, GNWC, GNHWC, GNDHWC>::value)
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
{ {
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_), physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_), static_cast<std::size_t>(param.N_),
...@@ -109,9 +91,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara ...@@ -109,9 +91,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
param.input_spatial_lengths_.begin(), param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_); param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> || else if constexpr(is_any_of<InLayout, NWGC, NHWGC, NDHWGC>::value)
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
{ {
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_), physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_), static_cast<std::size_t>(param.G_),
...@@ -139,11 +119,11 @@ template <typename WeiLayout> ...@@ -139,11 +119,11 @@ template <typename WeiLayout>
CK_TILE_HOST HostTensorDescriptor CK_TILE_HOST HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param) make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
{ {
using namespace ck_tile::tensor_layout::convolution;
std::vector<std::size_t> physical_lengths; std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> || if constexpr(is_any_of<WeiLayout, KXC, KYXC, KZYXC>::value)
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
{ {
if(param.G_ != 1) if(param.G_ != 1)
{ {
...@@ -157,9 +137,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara ...@@ -157,9 +137,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param.filter_spatial_lengths_.begin(), param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> || else if constexpr(is_any_of<WeiLayout, GKCX, GKCYX, GKCZYX>::value)
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
{ {
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_), physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_), static_cast<std::size_t>(param.K_),
...@@ -169,9 +147,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara ...@@ -169,9 +147,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param.filter_spatial_lengths_.begin(), param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> || else if constexpr(is_any_of<WeiLayout, GKXC, GKYXC, GKZYXC>::value)
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
{ {
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_), physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_), static_cast<std::size_t>(param.K_),
...@@ -181,9 +157,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara ...@@ -181,9 +157,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param.filter_spatial_lengths_.begin(), param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> || else if constexpr(is_any_of<WeiLayout, KXGC, KYXGC, KZYXGC>::value)
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
{ {
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_), physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.G_), static_cast<std::size_t>(param.G_),
...@@ -211,11 +185,11 @@ template <typename OutLayout> ...@@ -211,11 +185,11 @@ template <typename OutLayout>
CK_TILE_HOST HostTensorDescriptor CK_TILE_HOST HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param) make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
{ {
using namespace ck_tile::tensor_layout::convolution;
std::vector<std::size_t> physical_lengths; std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> || if constexpr(is_any_of<OutLayout, GNKW, GNKHW, GNKDHW>::value)
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
{ {
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_), physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_), static_cast<std::size_t>(param.N_),
...@@ -226,9 +200,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar ...@@ -226,9 +200,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
param.output_spatial_lengths_.begin() + param.num_dim_spatial_); param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
// separate from legacy code above // separate from legacy code above
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> || else if constexpr(is_any_of<OutLayout, GNWK, GNHWK, GNDHWK>::value)
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
{ {
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_), physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_), static_cast<std::size_t>(param.N_),
...@@ -238,9 +210,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar ...@@ -238,9 +210,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
param.output_spatial_lengths_.begin(), param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_); param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> || else if constexpr(is_any_of<OutLayout, NWGK, NHWGK, NDHWGK>::value)
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
{ {
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_), physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_), static_cast<std::size_t>(param.G_),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment