Commit a75b6800 authored by ltqin's avatar ltqin
Browse files

change wrw to bwd wgt

parent e17c0d80
#ifndef DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
#ifndef DEVICE_CONV2D_BWD_WGT_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_BWD_WGT_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
#include <iostream>
#include <sstream>
......@@ -52,10 +52,13 @@ template <typename InDataType,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
struct DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvBwdWgt<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp = DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using DeviceOp =
DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using ADataType = OutDataType;
using BDataType = InDataType;
......@@ -691,7 +694,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
auto str = std::stringstream();
// clang-format off
str << "DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
str << "DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
#ifndef DEVICE_CONV_WRW_HPP
#define DEVICE_CONV_WRW_HPP
#ifndef DEVICE_CONV_BWD_WGT_HPP
#define DEVICE_CONV_BWD_WGT_HPP
#include <iostream>
#include "device_base.hpp"
......@@ -11,7 +11,7 @@ namespace device {
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvWrw : public BaseOperator
struct DeviceConvBwdWgt : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in,
......@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvWrwPtr = std::unique_ptr<
DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
using DeviceConvBwdWgtPtr = std::unique_ptr<
DeviceConvBwdWgt<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
......
# Instructions for ```conv2d_wrw_xdl``` Example
# Instructions for ```conv2d_bwd_wgt_xdl``` Example
## Docker script
```bash
......@@ -13,7 +13,7 @@ rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash
```
## Build ```conv2d_wrw_xdl```
## Build ```conv2d_bwd_wgt_xdl```
```bash
mkdir build && cd build
```
......@@ -30,17 +30,17 @@ cmake \
```
```bash
make -j conv2d_wrw_xdl
make -j conv2d_bwd_wgt_xdl
```
## Run ```conv2d_wrw_xdl```
## Run ```conv2d_bwd_wgt_xdl```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
#arg4: show log (0=no, 1=yes)
#arg5 to 20: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx, split-k
./example/conv2d_fwd_xdl 0 1 5 0 4
./example/conv2d_bwd_wgt_xdl 0 1 5 0 4
```
Result
......
......@@ -32,8 +32,8 @@ using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
using DeviceConvWrWInstance = ck::tensor_operation::device::
DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
using DeviceConvBwdWgtInstance = ck::tensor_operation::device::
DeviceConv2dBwdWgtXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
......@@ -70,8 +70,12 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
using ReferenceConvWrwInstance = ck::tensor_operation::host::
ReferenceConvWrw<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>;
using ReferenceConvBwdWgtInstance = ck::tensor_operation::host::ReferenceConvBwdWgt<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
int main(int argc, char* argv[])
{
......@@ -211,7 +215,7 @@ int main(int argc, char* argv[])
wei_device_buf.ToDevice(wei_k_c_y_x_device_result.mData.data());
// do GEMM
auto conv = DeviceConvWrWInstance{};
auto conv = DeviceConvBwdWgtInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
......@@ -256,7 +260,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
auto ref_conv = ReferenceConvWrwInstance{};
auto ref_conv = ReferenceConvBwdWgtInstance{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
......
......@@ -24,7 +24,7 @@ set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fw
set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp)
set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp)
set(CONV2D_FWD_XDL_INT8_SOURCE 9_conv2d_fwd_xdl_int8/conv2d_fwd_xdl_int8.cpp)
set(CONV2D_WRW_XDL_SOURCE 13_conv2d_backward_weight_xdl/main.cpp)
set(CONV2D_BWD_WGT_XDL_SOURCE 13_conv2d_backward_weight_xdl/main.cpp)
set(CONV3D_FWD_XDL_SOURCE 10_conv3d_fwd_xdl/conv3d_fwd_xdl.cpp)
set(CONVND_FWD_XDL_SOURCE 11_convnd_fwd_xdl/convnd_fwd_xdl.cpp)
set(CONV2D_BWD_DATA_XDL_SOURCE 12_conv2d_bwd_data_xdl/conv2d_bwd_data_xdl.cpp)
......@@ -42,7 +42,7 @@ add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURC
add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE})
add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE})
add_executable(conv2d_fwd_xdl_int8 ${CONV2D_FWD_XDL_INT8_SOURCE})
add_executable(conv2d_wrw_xdl ${CONV2D_WRW_XDL_SOURCE})
add_executable(conv2d_bwd_wgt_xdl ${CONV2D_BWD_WGT_XDL_SOURCE})
add_executable(conv3d_fwd_xdl ${CONV3D_FWD_XDL_SOURCE})
add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE})
add_executable(conv2d_bwd_data_xdl ${CONV2D_BWD_DATA_XDL_SOURCE})
......@@ -60,7 +60,7 @@ target_link_libraries(conv2d_fwd_xdl_bias_relu_add PRIVATE host_tensor)
target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor)
target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor)
target_link_libraries(conv2d_fwd_xdl_int8 PRIVATE host_tensor)
target_link_libraries(conv2d_wrw_xdl PRIVATE host_tensor)
target_link_libraries(conv2d_bwd_wgt_xdl PRIVATE host_tensor)
target_link_libraries(conv3d_fwd_xdl PRIVATE host_tensor)
target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor)
target_link_libraries(conv2d_bwd_data_xdl PRIVATE host_tensor)
......
#ifndef REFERENCE_CONV_WRW_HPP
#define REFERENCE_CONV_WRW_HPP
#ifndef REFERENCE_CONV_BWD_WGT_HPP
#define REFERENCE_CONV_BWD_WGT_HPP
#include <iostream>
#include <sstream>
......@@ -17,7 +17,7 @@ template <typename InDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct ReferenceConvWrw : public device::BaseOperator
struct ReferenceConvBwdWgt : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
......@@ -62,7 +62,7 @@ struct ReferenceConvWrw : public device::BaseOperator
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceConvWrw::Argument;
using Argument = ReferenceConvBwdWgt::Argument;
float Run(const Argument& arg)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment