Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7d6fa69a
Commit
7d6fa69a
authored
Jul 27, 2023
by
Bartlomiej Kocot
Committed by
Bartłomiej Kocot
Jul 28, 2023
Browse files
Fix strides in client examples
parent
85bddfbc
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
16 additions
and
12 deletions
+16
-12
client_example/11_grouped_conv_bwd_weight/common.hpp
client_example/11_grouped_conv_bwd_weight/common.hpp
+1
-0
client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
+3
-3
client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
+3
-3
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
+3
-3
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
...rouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
+3
-3
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
...d_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
+3
-0
No files found.
client_example/11_grouped_conv_bwd_weight/common.hpp
View file @
7d6fa69a
...
...
@@ -231,6 +231,7 @@ bool run_grouped_conv_bwd_weight(
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
weights_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
View file @
7d6fa69a
...
...
@@ -25,9 +25,9 @@ static constexpr ck::index_t Wo = 28;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_spatial_lengths
{
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_spatial_lengths
{
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_spatial_lengths
{
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
X
*
C
,
X
*
C
,
C
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Wi
*
C
,
Wi
*
C
,
1
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
X
*
C
,
X
*
C
,
1
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Wo
*
K
,
Wo
*
K
,
1
,
K
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
};
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
View file @
7d6fa69a
...
...
@@ -29,11 +29,11 @@ static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_spatial_lengths
{
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_spatial_lengths
{
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
N
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
1
,
Wi
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
Y
*
X
*
C
,
Y
*
X
*
C
,
X
*
C
,
C
,
1
};
K
*
Y
*
X
*
C
,
Y
*
X
*
C
,
1
,
X
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
N
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
1
,
Wo
*
K
,
K
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
};
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
View file @
7d6fa69a
...
...
@@ -32,11 +32,11 @@ static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_spatial_lengths
{
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_spatial_lengths
{
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
1
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
Z
*
Y
*
X
*
C
,
Z
*
Y
*
X
*
C
,
Y
*
X
*
C
,
X
*
C
,
C
,
1
};
K
*
Z
*
Y
*
X
*
C
,
Z
*
Y
*
X
*
C
,
1
,
Y
*
X
*
C
,
X
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
1
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
,
1
};
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
View file @
7d6fa69a
...
...
@@ -32,11 +32,11 @@ static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_spatial_lengths
{
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_spatial_lengths
{
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
1
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
Z
*
Y
*
X
*
C
,
Z
*
Y
*
X
*
C
,
Y
*
X
*
C
,
X
*
C
,
C
,
1
};
K
*
Z
*
Y
*
X
*
C
,
Z
*
Y
*
X
*
C
,
1
,
Y
*
X
*
C
,
X
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
1
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
,
1
};
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
View file @
7d6fa69a
...
...
@@ -74,6 +74,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
weights_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
...
...
@@ -86,6 +87,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
range_copy
(
conv_param
.
filter_spatial_lengths_
,
begin
(
filter_spatial_lengths
));
range_copy
(
conv_param
.
output_spatial_lengths_
,
begin
(
output_spatial_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
begin
(
input_strides
));
range_copy
(
wei_g_k_c_xs_desc
.
GetStrides
(),
begin
(
weights_strides
));
range_copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
begin
(
output_strides
));
range_copy
(
conv_param
.
conv_filter_strides_
,
begin
(
conv_filter_strides
));
range_copy
(
conv_param
.
conv_filter_dilations_
,
begin
(
conv_filter_dilations
));
...
...
@@ -105,6 +107,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
weights_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment