Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3db77bc4
"test/vscode:/vscode.git/clone" did not exist on "15ddd84322d431912b522ab0aed0b91056d4ef1c"
Unverified
Commit
3db77bc4
authored
Jan 21, 2025
by
Mateusz Ozga
Committed by
GitHub
Jan 21, 2025
Browse files
Simplify static_cast if-lands (#1828)
parent
3c93d3c4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
93 deletions
+63
-93
include/ck_tile/core/utility/type_traits.hpp
include/ck_tile/core/utility/type_traits.hpp
+18
-0
include/ck_tile/host/check_err.hpp
include/ck_tile/host/check_err.hpp
+18
-36
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp
...k_tile/host/convolution_host_tensor_descriptor_helper.hpp
+27
-57
No files found.
include/ck_tile/core/utility/type_traits.hpp
View file @
3db77bc4
...
...
@@ -109,4 +109,22 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
#pragma clang diagnostic pop
}
/// @brief Compile-time membership test: is `CompareTo` exactly one of `Rest...`?
///
/// `is_any_of<T, Ts...>::value` is `true` when `std::is_same_v<T, U>` holds for at
/// least one `U` in `Ts...`, and `false` otherwise. The comparison is the exact
/// type identity performed by `std::is_same`, so cv-qualifiers and references are
/// significant (`const int` does not match `int`).
///
/// An empty candidate list yields `false`: a unary fold over `||` with an empty
/// pack evaluates to `false`.
///
/// Implemented as a single C++17 fold expression rather than a recursive chain of
/// partial specializations, so only one class template instantiation is needed
/// per query regardless of how many candidate types are listed.
///
/// @tparam CompareTo The type being looked up.
/// @tparam Rest      The candidate types to compare against (may be empty).
template <typename CompareTo, typename... Rest>
struct is_any_of : std::bool_constant<(std::is_same_v<CompareTo, Rest> || ...)>
{
};
}
// namespace ck_tile
include/ck_tile/host/check_err.hpp
View file @
3db77bc4
...
...
@@ -28,14 +28,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
std
::
is_same_v
<
ComputeDataType
,
F8
>
||
std
::
is_same_v
<
ComputeDataType
,
F16
>
||
std
::
is_same_v
<
ComputeDataType
,
BF16
>
||
std
::
is_same_v
<
ComputeDataType
,
F32
>
||
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
,
static_assert
(
is_any_of
<
ComputeDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
double
compute_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
)
if
constexpr
(
is_any_of
<
ComputeDataType
,
I8
,
I32
,
int
>::
value
)
{
return
0
;
}
...
...
@@ -44,14 +41,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error
=
std
::
pow
(
2
,
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
std
::
is_same_v
<
OutDataType
,
F8
>
||
std
::
is_same_v
<
OutDataType
,
F16
>
||
std
::
is_same_v
<
OutDataType
,
BF16
>
||
std
::
is_same_v
<
OutDataType
,
F32
>
||
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
,
static_assert
(
is_any_of
<
OutDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
double
output_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
)
if
constexpr
(
is_any_of
<
OutDataType
,
I8
,
I32
,
int
>::
value
)
{
return
0
;
}
...
...
@@ -61,14 +55,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
std
::
is_same_v
<
AccDataType
,
F8
>
||
std
::
is_same_v
<
AccDataType
,
F16
>
||
std
::
is_same_v
<
AccDataType
,
BF16
>
||
std
::
is_same_v
<
AccDataType
,
F32
>
||
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
,
static_assert
(
is_any_of
<
AccDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
double
acc_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
)
if
constexpr
(
is_any_of
<
AccDataType
,
I8
,
I32
,
int
>::
value
)
{
return
0
;
}
...
...
@@ -89,15 +80,12 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
std
::
is_same_v
<
ComputeDataType
,
F8
>
||
std
::
is_same_v
<
ComputeDataType
,
F16
>
||
std
::
is_same_v
<
ComputeDataType
,
BF16
>
||
std
::
is_same_v
<
ComputeDataType
,
F32
>
||
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
,
static_assert
(
is_any_of
<
ComputeDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
double
compute_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
)
if
constexpr
(
is_any_of
<
ComputeDataType
,
I8
,
I32
,
int
>::
value
)
{
return
0
;
}
...
...
@@ -106,14 +94,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error
=
std
::
pow
(
2
,
expo
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
std
::
is_same_v
<
OutDataType
,
F8
>
||
std
::
is_same_v
<
OutDataType
,
F16
>
||
std
::
is_same_v
<
OutDataType
,
BF16
>
||
std
::
is_same_v
<
OutDataType
,
F32
>
||
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
,
static_assert
(
is_any_of
<
OutDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
double
output_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
)
if
constexpr
(
is_any_of
<
OutDataType
,
I8
,
I32
,
int
>::
value
)
{
return
0
;
}
...
...
@@ -123,14 +108,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
std
::
is_same_v
<
AccDataType
,
F8
>
||
std
::
is_same_v
<
AccDataType
,
F16
>
||
std
::
is_same_v
<
AccDataType
,
BF16
>
||
std
::
is_same_v
<
AccDataType
,
F32
>
||
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
,
static_assert
(
is_any_of
<
AccDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
double
acc_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
)
if
constexpr
(
is_any_of
<
AccDataType
,
I8
,
I32
,
int
>::
value
)
{
return
0
;
}
...
...
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp
View file @
3db77bc4
...
...
@@ -14,57 +14,41 @@ namespace detail {
template
<
typename
OldLayout
>
CK_TILE_HOST
std
::
vector
<
std
::
size_t
>
get_layout_transpose_gnchw_to_old
()
{
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKW
>
)
using
namespace
ck_tile
::
tensor_layout
::
convolution
;
if
constexpr
(
is_any_of
<
OldLayout
,
GNCW
,
GKCX
,
GNKW
>::
value
)
{
return
{
0
,
1
,
2
,
3
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCHW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCYX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKHW
>
)
else
if
constexpr
(
is_any_of
<
OldLayout
,
GNCHW
,
GKCYX
,
GNKHW
>::
value
)
{
return
{
0
,
1
,
2
,
3
,
4
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCDHW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCZYX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKDHW
>
)
else
if
constexpr
(
is_any_of
<
OldLayout
,
GNCDHW
,
GKCZYX
,
GNKDHW
>::
value
)
{
return
{
0
,
1
,
2
,
3
,
4
,
5
};
}
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWK
>
)
if
constexpr
(
is_any_of
<
OldLayout
,
GNWC
,
GKXC
,
GNWK
>::
value
)
{
return
{
0
,
1
,
3
,
2
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKYXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWK
>
)
else
if
constexpr
(
is_any_of
<
OldLayout
,
GNHWC
,
GKYXC
,
GNHWK
>::
value
)
{
return
{
0
,
1
,
4
,
2
,
3
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKZYXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWK
>
)
else
if
constexpr
(
is_any_of
<
OldLayout
,
GNDHWC
,
GKZYXC
,
GNDHWK
>::
value
)
{
return
{
0
,
1
,
5
,
2
,
3
,
4
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGK
>
)
else
if
constexpr
(
is_any_of
<
OldLayout
,
NWGC
,
KXGC
,
NWGK
>::
value
)
{
return
{
2
,
0
,
3
,
1
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGK
>
)
else
if
constexpr
(
is_any_of
<
OldLayout
,
NHWGC
,
KYXGC
,
NHWGK
>::
value
)
{
return
{
3
,
0
,
4
,
1
,
2
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGK
>
)
else
if
constexpr
(
is_any_of
<
OldLayout
,
NDHWGC
,
KZYXGC
,
NDHWGK
>::
value
)
{
return
{
4
,
0
,
5
,
1
,
2
,
3
};
}
...
...
@@ -83,11 +67,11 @@ template <typename InLayout>
CK_TILE_HOST
HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
using
namespace
ck_tile
::
tensor_layout
::
convolution
;
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCW
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCHW
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCDHW
>
)
if
constexpr
(
is_any_of
<
InLayout
,
GNCW
,
GNCHW
,
GNCDHW
>::
value
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
...
...
@@ -97,9 +81,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWC
>
)
else
if
constexpr
(
is_any_of
<
InLayout
,
GNWC
,
GNHWC
,
GNDHWC
>::
value
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
...
...
@@ -109,9 +91,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGC
>
)
else
if
constexpr
(
is_any_of
<
InLayout
,
NWGC
,
NHWGC
,
NDHWGC
>::
value
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
...
...
@@ -139,11 +119,11 @@ template <typename WeiLayout>
CK_TILE_HOST
HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
using
namespace
ck_tile
::
tensor_layout
::
convolution
;
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXC
>
)
if
constexpr
(
is_any_of
<
WeiLayout
,
KXC
,
KYXC
,
KZYXC
>::
value
)
{
if
(
param
.
G_
!=
1
)
{
...
...
@@ -157,9 +137,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCX
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCYX
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCZYX
>
)
else
if
constexpr
(
is_any_of
<
WeiLayout
,
GKCX
,
GKCYX
,
GKCZYX
>::
value
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
),
...
...
@@ -169,9 +147,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKYXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKZYXC
>
)
else
if
constexpr
(
is_any_of
<
WeiLayout
,
GKXC
,
GKYXC
,
GKZYXC
>::
value
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
),
...
...
@@ -181,9 +157,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXGC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXGC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXGC
>
)
else
if
constexpr
(
is_any_of
<
WeiLayout
,
KXGC
,
KYXGC
,
KZYXGC
>::
value
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
...
...
@@ -211,11 +185,11 @@ template <typename OutLayout>
CK_TILE_HOST
HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
using
namespace
ck_tile
::
tensor_layout
::
convolution
;
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKW
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKHW
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKDHW
>
)
if
constexpr
(
is_any_of
<
OutLayout
,
GNKW
,
GNKHW
,
GNKDHW
>::
value
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
...
...
@@ -226,9 +200,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
// separate from legacy code above
else
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWK
>
)
else
if
constexpr
(
is_any_of
<
OutLayout
,
GNWK
,
GNHWK
,
GNDHWK
>::
value
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
...
...
@@ -238,9 +210,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGK
>
)
else
if
constexpr
(
is_any_of
<
OutLayout
,
NWGK
,
NHWGK
,
NDHWGK
>::
value
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment