Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
add7421f
Commit
add7421f
authored
Apr 28, 2022
by
wangshaojie6
Browse files
remove gemmk0 pad for output
parent
e6b32ffe
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
7 deletions
+17
-7
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+17
-7
No files found.
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
add7421f
...
...
@@ -14,6 +14,7 @@
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
#define SPLITN0_N1 1
#define GEMMK0PAD_FOR_OUT 0
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -162,27 +163,27 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const
index_t
GemmK0Pad
=
GemmKBatch
*
GemmK0S
;
const
auto
out_n_ho_wo_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Ho
,
Wo
,
K
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Ho
*
Wo
,
K
));
const
auto
out_n0_ho_wo_k_n1_grid_desc
=
transform_tensor_descriptor
(
out_n_ho_wo_k_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
N0
,
N1Number
)),
make_pass_through_transform
(
Ho
),
make_pass_through_transform
(
Wo
),
make_pass_through_transform
(
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}
,
Sequence
<
3
>
{}
),
make_tuple
(
Sequence
<
0
,
4
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}
,
Sequence
<
3
>
{}
)
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
3
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{})
);
const
auto
out_gemmk0total_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n0_ho_wo_k_n1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
Ho
,
Wo
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
Ho
*
Wo
)),
make_pass_through_transform
(
K
),
make_pass_through_transform
(
N1Number
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{})
);
#if GEMMK0PAD_FOR_OUT
const
auto
out_gemmk0pad_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk0total_gemmm_gemmk1_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmK0Total
,
GemmK0Pad
-
GemmK0Total
),
...
...
@@ -198,6 +199,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_pass_through_transform
(
N1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
#else
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk0total_gemmm_gemmk1_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
)),
make_pass_through_transform
(
GemmM
),
make_pass_through_transform
(
N1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
#endif
#endif
// B: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment