Commit a75b6800 authored by ltqin's avatar ltqin
Browse files

change wrw to bwd wgt

parent e17c0d80
#ifndef DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP #ifndef DEVICE_CONV2D_BWD_WGT_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP #define DEVICE_CONV2D_BWD_WGT_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -52,10 +52,13 @@ template <typename InDataType, ...@@ -52,10 +52,13 @@ template <typename InDataType,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> : public DeviceConvBwdWgt<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{ {
using DeviceOp = DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; using DeviceOp =
DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using ADataType = OutDataType; using ADataType = OutDataType;
using BDataType = InDataType; using BDataType = InDataType;
...@@ -691,7 +694,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -691,7 +694,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" str << "DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
#ifndef DEVICE_CONV_WRW_HPP #ifndef DEVICE_CONV_BWD_WGT_HPP
#define DEVICE_CONV_WRW_HPP #define DEVICE_CONV_BWD_WGT_HPP
#include <iostream> #include <iostream>
#include "device_base.hpp" #include "device_base.hpp"
...@@ -11,7 +11,7 @@ namespace device { ...@@ -11,7 +11,7 @@ namespace device {
template <typename InElementwiseOperation, template <typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation>
struct DeviceConvWrw : public BaseOperator struct DeviceConvBwdWgt : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, MakeArgumentPointer(const void* p_in,
...@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator ...@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator
template <typename InElementwiseOperation, template <typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation>
using DeviceConvWrwPtr = std::unique_ptr< using DeviceConvBwdWgtPtr = std::unique_ptr<
DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>; DeviceConvBwdWgt<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
# Instructions for ```conv2d_wrw_xdl``` Example # Instructions for ```conv2d_bwd_wgt_xdl``` Example
## Docker script ## Docker script
```bash ```bash
...@@ -13,7 +13,7 @@ rocm/tensorflow:rocm4.3.1-tf2.6-dev \ ...@@ -13,7 +13,7 @@ rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash /bin/bash
``` ```
## Build ```conv2d_wrw_xdl``` ## Build ```conv2d_bwd_wgt_xdl```
```bash ```bash
mkdir build && cd build mkdir build && cd build
``` ```
...@@ -30,17 +30,17 @@ cmake \ ...@@ -30,17 +30,17 @@ cmake \
``` ```
```bash ```bash
make -j conv2d_wrw_xdl make -j conv2d_bwd_wgt_xdl
``` ```
## Run ```conv2d_wrw_xdl``` ## Run ```conv2d_bwd_wgt_xdl```
```bash ```bash
#arg1: verification (0=no, 1=yes) #arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value) #arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1) #arg3: run kernel # of times (>1)
#arg4: is show log (0=no, 1=yes) #arg4: is show log (0=no, 1=yes)
#arg5 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx, split-k #arg5 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx, split-k
./example/conv2d_fwd_xdl 0 1 5 0 4 ./example/conv2d_bwd_wgt_xdl 0 1 5 0 4
``` ```
Result Result
......
...@@ -32,8 +32,8 @@ using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -32,8 +32,8 @@ using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
// clang-format off // clang-format off
using DeviceConvWrWInstance = ck::tensor_operation::device:: using DeviceConvBwdWgtInstance = ck::tensor_operation::device::
DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
InDataType, // InDataType InDataType, // InDataType
WeiDataType, // WeiDataType WeiDataType, // WeiDataType
OutDataType, // OutDataType OutDataType, // OutDataType
...@@ -70,8 +70,12 @@ using DeviceConvWrWInstance = ck::tensor_operation::device:: ...@@ -70,8 +70,12 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on // clang-format on
using ReferenceConvWrwInstance = ck::tensor_operation::host:: using ReferenceConvBwdWgtInstance = ck::tensor_operation::host::ReferenceConvBwdWgt<InDataType,
ReferenceConvWrw<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>; WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -211,7 +215,7 @@ int main(int argc, char* argv[]) ...@@ -211,7 +215,7 @@ int main(int argc, char* argv[])
wei_device_buf.ToDevice(wei_k_c_y_x_device_result.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x_device_result.mData.data());
// do GEMM // do GEMM
auto conv = DeviceConvWrWInstance{}; auto conv = DeviceConvBwdWgtInstance{};
auto invoker = conv.MakeInvoker(); auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
...@@ -256,7 +260,7 @@ int main(int argc, char* argv[]) ...@@ -256,7 +260,7 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
auto ref_conv = ReferenceConvWrwInstance{}; auto ref_conv = ReferenceConvBwdWgtInstance{};
auto ref_invoker = ref_conv.MakeInvoker(); auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
......
...@@ -24,7 +24,7 @@ set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fw ...@@ -24,7 +24,7 @@ set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fw
set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp) set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp)
set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp) set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp)
set(CONV2D_FWD_XDL_INT8_SOURCE 9_conv2d_fwd_xdl_int8/conv2d_fwd_xdl_int8.cpp) set(CONV2D_FWD_XDL_INT8_SOURCE 9_conv2d_fwd_xdl_int8/conv2d_fwd_xdl_int8.cpp)
set(CONV2D_WRW_XDL_SOURCE 13_conv2d_backward_weight_xdl/main.cpp) set(CONV2D_BWD_WGT_XDL_SOURCE 13_conv2d_backward_weight_xdl/main.cpp)
set(CONV3D_FWD_XDL_SOURCE 10_conv3d_fwd_xdl/conv3d_fwd_xdl.cpp) set(CONV3D_FWD_XDL_SOURCE 10_conv3d_fwd_xdl/conv3d_fwd_xdl.cpp)
set(CONVND_FWD_XDL_SOURCE 11_convnd_fwd_xdl/convnd_fwd_xdl.cpp) set(CONVND_FWD_XDL_SOURCE 11_convnd_fwd_xdl/convnd_fwd_xdl.cpp)
set(CONV2D_BWD_DATA_XDL_SOURCE 12_conv2d_bwd_data_xdl/conv2d_bwd_data_xdl.cpp) set(CONV2D_BWD_DATA_XDL_SOURCE 12_conv2d_bwd_data_xdl/conv2d_bwd_data_xdl.cpp)
...@@ -42,7 +42,7 @@ add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURC ...@@ -42,7 +42,7 @@ add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURC
add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE}) add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE})
add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE}) add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE})
add_executable(conv2d_fwd_xdl_int8 ${CONV2D_FWD_XDL_INT8_SOURCE}) add_executable(conv2d_fwd_xdl_int8 ${CONV2D_FWD_XDL_INT8_SOURCE})
add_executable(conv2d_wrw_xdl ${CONV2D_WRW_XDL_SOURCE}) add_executable(conv2d_bwd_wgt_xdl ${CONV2D_BWD_WGT_XDL_SOURCE})
add_executable(conv3d_fwd_xdl ${CONV3D_FWD_XDL_SOURCE}) add_executable(conv3d_fwd_xdl ${CONV3D_FWD_XDL_SOURCE})
add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE})
add_executable(conv2d_bwd_data_xdl ${CONV2D_BWD_DATA_XDL_SOURCE}) add_executable(conv2d_bwd_data_xdl ${CONV2D_BWD_DATA_XDL_SOURCE})
...@@ -60,7 +60,7 @@ target_link_libraries(conv2d_fwd_xdl_bias_relu_add PRIVATE host_tensor) ...@@ -60,7 +60,7 @@ target_link_libraries(conv2d_fwd_xdl_bias_relu_add PRIVATE host_tensor)
target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor)
target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor) target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor)
target_link_libraries(conv2d_fwd_xdl_int8 PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_int8 PRIVATE host_tensor)
target_link_libraries(conv2d_wrw_xdl PRIVATE host_tensor) target_link_libraries(conv2d_bwd_wgt_xdl PRIVATE host_tensor)
target_link_libraries(conv3d_fwd_xdl PRIVATE host_tensor) target_link_libraries(conv3d_fwd_xdl PRIVATE host_tensor)
target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor) target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor)
target_link_libraries(conv2d_bwd_data_xdl PRIVATE host_tensor) target_link_libraries(conv2d_bwd_data_xdl PRIVATE host_tensor)
......
#ifndef REFERENCE_CONV_WRW_HPP #ifndef REFERENCE_CONV_BWD_WGT_HPP
#define REFERENCE_CONV_WRW_HPP #define REFERENCE_CONV_BWD_WGT_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -17,7 +17,7 @@ template <typename InDataType, ...@@ -17,7 +17,7 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation>
struct ReferenceConvWrw : public device::BaseOperator struct ReferenceConvBwdWgt : public device::BaseOperator
{ {
// Argument // Argument
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
...@@ -62,7 +62,7 @@ struct ReferenceConvWrw : public device::BaseOperator ...@@ -62,7 +62,7 @@ struct ReferenceConvWrw : public device::BaseOperator
// Invoker // Invoker
struct Invoker : public device::BaseInvoker struct Invoker : public device::BaseInvoker
{ {
using Argument = ReferenceConvWrw::Argument; using Argument = ReferenceConvBwdWgt::Argument;
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment