Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
669df2d3
Commit
669df2d3
authored
Mar 01, 2022
by
ltqin
Browse files
start device
parent
cfc80c01
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
16 deletions
+18
-16
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+15
-13
example/14_conv2d_backward_weight_xdl/main.cpp
example/14_conv2d_backward_weight_xdl/main.cpp
+3
-3
No files found.
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
669df2d3
...
...
@@ -54,14 +54,14 @@ template <
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv2dWr
w
Xdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
struct
DeviceConv2dWr
W
Xdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvWrw
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
DeviceOp
=
DeviceConv2dWr
w
Xdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
;
using
DeviceOp
=
DeviceConv2dWr
W
Xdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
;
using
ADataType
=
InDataType
;
using
BDataType
=
Wei
DataType
;
using
CDataType
=
Out
DataType
;
using
BDataType
=
Out
DataType
;
using
CDataType
=
Wei
DataType
;
// TODO make A/B datatype different
using
ABDataType
=
InDataType
;
...
...
@@ -432,10 +432,11 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
OutElementwiseOperation
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
K0PerBlock
*
K1
,
K1
,
// AK1
K1
,
// BK1
MPerXdl
,
NPerXdl
,
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
...
...
@@ -491,8 +492,8 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
M01_
{
M01
},
N01_
{
N01
},
in_element_op_
{
in_element_op
},
wei_element_op_
{
out
_element_op
},
out_element_op_
{
wei
_element_op
},
wei_element_op_
{
wei
_element_op
},
out_element_op_
{
out
_element_op
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_C_
{
C
},
...
...
@@ -525,7 +526,8 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
}
}
...
...
@@ -538,7 +540,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
;
typename
GridwiseGemm
::
Block2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
Default
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
InElementwiseOperation
in_element_op_
;
...
...
@@ -628,7 +630,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
Block2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
Default
Block2CTileMap
>
,
true
>
;
ave_time
=
launch_and_time_kernel
(
...
...
@@ -662,7 +664,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
Block2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
Default
Block2CTileMap
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
...
...
@@ -838,7 +840,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceConv2d
Fwd
Xdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
str
<<
"DeviceConv2d
WrW
Xdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
example/14_conv2d_backward_weight_xdl/main.cpp
View file @
669df2d3
...
...
@@ -35,8 +35,8 @@ static constexpr auto ConvFwdDefault =
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
// clang-format off
using
DeviceConv
Fwd
Instance
=
ck
::
tensor_operation
::
device
::
DeviceConv2dWr
w
Xdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
using
DeviceConv
WrW
Instance
=
ck
::
tensor_operation
::
device
::
DeviceConv2dWr
W
Xdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
OutDataType
,
// OutDataType
...
...
@@ -205,7 +205,7 @@ int main(int argc, char* argv[])
out_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
// do GEMM
auto
conv
=
DeviceConv
Fwd
Instance
{};
auto
conv
=
DeviceConv
WrW
Instance
{};
auto
invoker
=
conv
.
MakeInvoker
();
auto
argument
=
conv
.
MakeArgument
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment