Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4d2172a9
Commit
4d2172a9
authored
Mar 01, 2022
by
ltqin
Browse files
raw not split version
parent
669df2d3
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
97 additions
and
363 deletions
+97
-363
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+85
-347
example/14_conv2d_backward_weight_xdl/main.cpp
example/14_conv2d_backward_weight_xdl/main.cpp
+12
-16
No files found.
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
4d2172a9
This diff is collapsed.
Click to expand it.
example/14_conv2d_backward_weight_xdl/main.cpp
View file @
4d2172a9
...
...
@@ -31,9 +31,6 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
// clang-format off
using
DeviceConvWrWInstance
=
ck
::
tensor_operation
::
device
::
DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
...
...
@@ -44,24 +41,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvFwdDefault
,
// ConvForwardSpecialization
256
,
// BlockSize
128
,
// MPerBlock
256
,
// NPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
4
,
// NXdlPerWave
S
<
4
,
6
4
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
2
,
// NXdlPerWave
S
<
4
,
1
6
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
0
,
2
,
1
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
1
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
4
,
6
4
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
4
,
1
6
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
...
...
@@ -176,9 +172,9 @@ int main(int argc, char* argv[])
Tensor
<
InDataType
>
in_n_c_hi_wi
(
f_host_tensor_descriptor
(
N
,
C
,
Hi
,
Wi
,
InLayout
{}));
Tensor
<
WeiDataType
>
wei_k_c_y_x_host_result
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
,
WeiLayout
{}));
Tensor
<
WeiDataType
>
wei_k_c_y_x_device_result
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
,
WeiLayout
{}));
Tensor
<
OutDataType
>
out_n_k_ho_wo
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
,
OutLayout
{}));
Tensor
<
WeiDataType
>
wei_k_c_y_x_device_result
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
,
WeiLayout
{}));
Tensor
<
OutDataType
>
out_n_k_ho_wo
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
,
OutLayout
{}));
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x_host_result
.
mDesc
<<
std
::
endl
;
...
...
@@ -197,9 +193,9 @@ int main(int argc, char* argv[])
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo
.
mDesc
.
GetElementSpace
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
out_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment