gaoqiong / composable_kernel · Commits

Commit a75b6800, authored Mar 07, 2022 by ltqin
Parent: e17c0d80

change wrw to bwd wgt

This commit renames the backward-weight (weight-gradient) convolution identifiers from "wrw"/WrW to "bwd_wgt"/BwdWgt across the device operation, the host reference operation, the example, and the build scripts.

Showing 6 changed files with 36 additions and 29 deletions (+36, -29):
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp (+9, -6)
device_operation/include/device_conv_backward_weight.hpp (+5, -5)
example/13_conv2d_backward_weight_xdl/README.md (+5, -5)
example/13_conv2d_backward_weight_xdl/main.cpp (+10, -6)
example/CMakeLists.txt (+3, -3)
reference_operation/include/reference_conv_backward_weight.hpp (+4, -4)
device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp (+9, -6)

-#ifndef DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
+#ifndef DEVICE_CONV2D_BWD_WGT_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
-#define DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
+#define DEVICE_CONV2D_BWD_WGT_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
 #include <iostream>
 #include <sstream>

@@ -52,10 +52,13 @@ template <typename InDataType,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
-struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
-    : public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
+struct DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
+    : public DeviceConvBwdWgt<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
 {
-    using DeviceOp = DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
+    using DeviceOp = DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;

     using ADataType = OutDataType;
     using BDataType = InDataType;

@@ -691,7 +694,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
+        str << "DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
device_operation/include/device_conv_backward_weight.hpp (+5, -5)

-#ifndef DEVICE_CONV_WRW_HPP
+#ifndef DEVICE_CONV_BWD_WGT_HPP
-#define DEVICE_CONV_WRW_HPP
+#define DEVICE_CONV_BWD_WGT_HPP
 #include <iostream>
 #include "device_base.hpp"

@@ -11,7 +11,7 @@ namespace device {
 template <typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation>
-struct DeviceConvWrw : public BaseOperator
+struct DeviceConvBwdWgt : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_in,

@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator
 template <typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation>
-using DeviceConvWrwPtr = std::unique_ptr<
-    DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
+using DeviceConvBwdWgtPtr = std::unique_ptr<
+    DeviceConvBwdWgt<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
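For context (not part of the commit): the renamed DeviceConvBwdWgt interface and its DeviceConvBwdWgtPtr alias are the polymorphic handle that callers hold. The sketch below is a minimal guess at how client code might use the alias; the include paths and the element_wise header name are assumptions, while DeviceConvBwdWgt, DeviceConvBwdWgtPtr and PassThrough do appear in the diff.

```cpp
// Minimal sketch (assumption, not code from this commit).
#include <vector>

#include "device_conv_backward_weight.hpp" // DeviceConvBwdWgt / DeviceConvBwdWgtPtr
#include "element_wise_operation.hpp"      // assumed header providing PassThrough

namespace ew = ck::tensor_operation::element_wise;

// One handle type for every backward-weight device op that applies
// pass-through elementwise operations to input, weight and output.
using ConvBwdWgtPtr =
    ck::tensor_operation::device::DeviceConvBwdWgtPtr<ew::PassThrough,
                                                      ew::PassThrough,
                                                      ew::PassThrough>;

// A registry filled with concrete DeviceConv2dBwdWgtXdl_* instances could be
// iterated like this, e.g. to time each kernel for a given problem size.
void profile_all(const std::vector<ConvBwdWgtPtr>& instances)
{
    for(const auto& op : instances)
    {
        // Each instance implements the virtual MakeArgumentPointer(...)
        // declared on DeviceConvBwdWgt (full parameter list omitted here).
        (void)op;
    }
}
```

The alias keeps the tuning-parameter-heavy concrete type out of caller signatures, which appears to be why it is defined next to the interface.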
example/13_conv2d_backward_weight_xdl/README.md (+5, -5)

-# Instructions for ```conv2d_wrw_xdl``` Example
+# Instructions for ```conv2d_bwd_wgt_xdl``` Example
 ## Docker script
 ```bash

@@ -13,7 +13,7 @@ rocm/tensorflow:rocm4.3.1-tf2.6-dev \
 /bin/bash
 ```
-## Build ```conv2d_wrw_xdl```
+## Build ```conv2d_bwd_wgt_xdl```
 ```bash
 mkdir build && cd build
 ```

@@ -30,17 +30,17 @@ cmake \
 ```
 ```bash
-make -j conv2d_wrw_xdl
+make -j conv2d_bwd_wgt_xdl
 ```
-## Run ```conv2d_wrw_xdl```
+## Run ```conv2d_bwd_wgt_xdl```
 ```bash
 #arg1: verification (0=no, 1=yes)
 #arg2: initialization (0=no init, 1=integer value, 2=decimal value)
 #arg3: run kernel # of times (>1)
 #arg4: is show log (0=no, 1=yes)
 #arg5 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx, split-k
-./example/conv2d_fwd_xdl 0 1 5 0 4
+./example/conv2d_bwd_wgt_xdl 0 1 5 0 4
 ```
 Result
example/13_conv2d_backward_weight_xdl/main.cpp (+10, -6)

@@ -32,8 +32,8 @@ using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
 using OutElementOp = ck::tensor_operation::element_wise::PassThrough;

 // clang-format off
-using DeviceConvWrWInstance = ck::tensor_operation::device::
-    DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
+using DeviceConvBwdWgtInstance = ck::tensor_operation::device::
+    DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
         InDataType,   // InDataType
         WeiDataType,  // WeiDataType
         OutDataType,  // OutDataType

@@ -70,8 +70,12 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
         8>;           // CBlockTransferScalarPerVector_NWaveNPerXdl
 // clang-format on

-using ReferenceConvWrwInstance = ck::tensor_operation::host::
-    ReferenceConvWrw<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>;
+using ReferenceConvBwdWgtInstance = ck::tensor_operation::host::
+    ReferenceConvBwdWgt<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>;

 int main(int argc, char* argv[])
 {

@@ -211,7 +215,7 @@ int main(int argc, char* argv[])
     wei_device_buf.ToDevice(wei_k_c_y_x_device_result.mData.data());

     // do GEMM
-    auto conv     = DeviceConvWrWInstance{};
+    auto conv     = DeviceConvBwdWgtInstance{};
     auto invoker  = conv.MakeInvoker();
     auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                                       static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),

@@ -256,7 +260,7 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        auto ref_conv     = ReferenceConvWrwInstance{};
+        auto ref_conv     = ReferenceConvBwdWgtInstance{};
         auto ref_invoker  = ref_conv.MakeInvoker();
         auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
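The driver code in main.cpp is untouched apart from the two alias names. The sketch below is a toy, self-contained mirror of the MakeArgument / MakeInvoker / Run pattern the example drives (hypothetical types, not the composable_kernel API), illustrating why only the `using` declarations had to change.

```cpp
// Toy mirror of the device-op driving pattern (hypothetical types).
#include <iostream>

struct ToyConvBwdWgt
{
    struct Argument { int n, k, c; };
    struct Invoker
    {
        float Run(const Argument& arg) const
        {
            // A real invoker launches the kernel and returns elapsed time.
            std::cout << "run N=" << arg.n << " K=" << arg.k << " C=" << arg.c << "\n";
            return 0.0f;
        }
    };
    Argument MakeArgument(int n, int k, int c) const { return {n, k, c}; }
    Invoker  MakeInvoker() const { return {}; }
};

// Renaming the alias (WrW -> BwdWgt) leaves the driver below untouched.
using DeviceConvBwdWgtInstance = ToyConvBwdWgt;

int main()
{
    auto conv     = DeviceConvBwdWgtInstance{};
    auto invoker  = conv.MakeInvoker();
    auto argument = conv.MakeArgument(128, 256, 192);
    invoker.Run(argument);
}
```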
example/CMakeLists.txt (+3, -3)

@@ -24,7 +24,7 @@ set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fw
 set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp)
 set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp)
 set(CONV2D_FWD_XDL_INT8_SOURCE 9_conv2d_fwd_xdl_int8/conv2d_fwd_xdl_int8.cpp)
-set(CONV2D_WRW_XDL_SOURCE 13_conv2d_backward_weight_xdl/main.cpp)
+set(CONV2D_BWD_WGT_XDL_SOURCE 13_conv2d_backward_weight_xdl/main.cpp)
 set(CONV3D_FWD_XDL_SOURCE 10_conv3d_fwd_xdl/conv3d_fwd_xdl.cpp)
 set(CONVND_FWD_XDL_SOURCE 11_convnd_fwd_xdl/convnd_fwd_xdl.cpp)
 set(CONV2D_BWD_DATA_XDL_SOURCE 12_conv2d_bwd_data_xdl/conv2d_bwd_data_xdl.cpp)

@@ -42,7 +42,7 @@ add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURC
 add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE})
 add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE})
 add_executable(conv2d_fwd_xdl_int8 ${CONV2D_FWD_XDL_INT8_SOURCE})
-add_executable(conv2d_wrw_xdl ${CONV2D_WRW_XDL_SOURCE})
+add_executable(conv2d_bwd_wgt_xdl ${CONV2D_BWD_WGT_XDL_SOURCE})
 add_executable(conv3d_fwd_xdl ${CONV3D_FWD_XDL_SOURCE})
 add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE})
 add_executable(conv2d_bwd_data_xdl ${CONV2D_BWD_DATA_XDL_SOURCE})

@@ -60,7 +60,7 @@ target_link_libraries(conv2d_fwd_xdl_bias_relu_add PRIVATE host_tensor)
 target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor)
 target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor)
 target_link_libraries(conv2d_fwd_xdl_int8 PRIVATE host_tensor)
-target_link_libraries(conv2d_wrw_xdl PRIVATE host_tensor)
+target_link_libraries(conv2d_bwd_wgt_xdl PRIVATE host_tensor)
 target_link_libraries(conv3d_fwd_xdl PRIVATE host_tensor)
 target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor)
 target_link_libraries(conv2d_bwd_data_xdl PRIVATE host_tensor)
reference_operation/include/reference_conv_backward_weight.hpp (+4, -4)

-#ifndef REFERENCE_CONV_WRW_HPP
+#ifndef REFERENCE_CONV_BWD_WGT_HPP
-#define REFERENCE_CONV_WRW_HPP
+#define REFERENCE_CONV_BWD_WGT_HPP
 #include <iostream>
 #include <sstream>

@@ -17,7 +17,7 @@ template <typename InDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation>
-struct ReferenceConvWrw : public device::BaseOperator
+struct ReferenceConvBwdWgt : public device::BaseOperator
 {
     // Argument
     struct Argument : public device::BaseArgument

@@ -62,7 +62,7 @@ struct ReferenceConvWrw : public device::BaseOperator
     // Invoker
     struct Invoker : public device::BaseInvoker
     {
-        using Argument = ReferenceConvWrw::Argument;
+        using Argument = ReferenceConvBwdWgt::Argument;

        float Run(const Argument& arg)
        {
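For intuition about what the renamed reference operator verifies, the sketch below is a naive host-side backward-weight (weight-gradient) convolution. It is inferred from the operator's name and standard convolution math, not taken from this repository, and it uses NCHW/KCYX indexing for brevity, whereas the actual op works on NHWC/KYXC tensors as the file names indicate.

```cpp
// Naive weight-gradient convolution (illustrative only; layouts and the
// function name are assumptions).
#include <vector>

// wgt[k][c][y][x] = sum over n, ho, wo of out[n][k][ho][wo] * in[n][c][hi][wi]
// with hi = ho*stride_h + y*dilation_h - pad_h (and similarly for wi).
void naive_conv_bwd_weight(std::vector<float>& wgt,       // K*C*Y*X
                           const std::vector<float>& in,  // N*C*Hi*Wi
                           const std::vector<float>& out, // N*K*Ho*Wo
                           int N, int K, int C, int Y, int X,
                           int Hi, int Wi, int Ho, int Wo,
                           int stride_h, int stride_w,
                           int dilation_h, int dilation_w,
                           int pad_h, int pad_w)
{
    for(int k = 0; k < K; ++k)
        for(int c = 0; c < C; ++c)
            for(int y = 0; y < Y; ++y)
                for(int x = 0; x < X; ++x)
                {
                    float acc = 0.f;
                    for(int n = 0; n < N; ++n)
                        for(int ho = 0; ho < Ho; ++ho)
                            for(int wo = 0; wo < Wo; ++wo)
                            {
                                const int hi = ho * stride_h + y * dilation_h - pad_h;
                                const int wi = wo * stride_w + x * dilation_w - pad_w;
                                if(hi < 0 || hi >= Hi || wi < 0 || wi >= Wi)
                                    continue; // input position outside padding window
                                acc += out[((n * K + k) * Ho + ho) * Wo + wo] *
                                       in[((n * C + c) * Hi + hi) * Wi + wi];
                            }
                    wgt[((k * C + c) * Y + y) * X + x] = acc;
                }
}
```

Viewed as a GEMM, the output gradient supplies one matrix operand and the input the other, which is consistent with the device op above aliasing ADataType to OutDataType and BDataType to InDataType.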