Commit 6f341d80 authored by Chao Liu's avatar Chao Liu
Browse files

misc fixes; add 1x1 specialization

parent b8aeb85b
...@@ -160,7 +160,7 @@ enum InMemoryDataOperationEnum_t ...@@ -160,7 +160,7 @@ enum InMemoryDataOperationEnum_t
enum ActivTypeEnum_t enum ActivTypeEnum_t
{ {
None = 0, None,
LeakyRelu, LeakyRelu,
Sigmoid Sigmoid
}; };
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "config.hpp" #include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" #include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -19,25 +20,34 @@ using AddRelu = ck::tensor_operation::element_wise::AddRelu; ...@@ -19,25 +20,34 @@ using AddRelu = ck::tensor_operation::element_wise::AddRelu;
static constexpr auto InMemoryAtomicAdd = ck::InMemoryDataOperationEnum_t::AtomicAdd; static constexpr auto InMemoryAtomicAdd = ck::InMemoryDataOperationEnum_t::AtomicAdd;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances = std::tuple< using device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances = std::tuple<
// clang-format off // clang-format off
//##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2> DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>
// clang-format on // clang-format on
>; >;
......
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_instance {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple<
// clang-format off
//#######################################################################| InData| WeiData| OutData| AccData| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#######################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#######################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#######################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& device_conv_instances)
{
using DeviceConvs = device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances;
const auto device_convs = DeviceConvs{};
ck::static_for<0, std::tuple_size_v<DeviceConvs>, 1>{}([&](auto i) {
using Conv = remove_cvref_t<decltype(std::get<i>(device_convs))>;
auto conv = Conv{};
device_conv_instances.push_back(std::make_unique<Conv>(conv));
});
}
} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_instance {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple<
// clang-format off
//##########################################################################| InData| WeiData| OutData| AccData| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& device_conv_instances)
{
using DeviceConvs = device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances;
const auto device_convs = DeviceConvs{};
ck::static_for<0, std::tuple_size_v<DeviceConvs>, 1>{}([&](auto i) {
using Conv = remove_cvref_t<decltype(std::get<i>(device_convs))>;
auto conv = Conv{};
device_conv_instances.push_back(std::make_unique<Conv>(conv));
});
}
} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION
#define CONVOLUTION_FORWARD_SPECIALIZATION
namespace ck {
namespace tensor_operation {
namespace device {
enum ConvolutionForwardSpecialization_t
{
Default,
Filter1x1Pad0,
Filter1x1Stride1Pad0,
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_conv_fwd_bias_activation_add.hpp" #include "device_conv_fwd_bias_activation_add.hpp"
#include "convolution_forward_specialization.hpp"
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -26,6 +27,7 @@ template < ...@@ -26,6 +27,7 @@ template <
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -128,97 +130,216 @@ struct ...@@ -128,97 +130,216 @@ struct
const index_t GemmK0 = GemmK / GemmK1Number; const index_t GemmK0 = GemmK / GemmK1Number;
// A: input tensor if constexpr(ConvForwardSpecialization ==
const auto in_n_hi_wi_c_grid_desc = ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); {
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_gemmmraw_gemmk_grid_desc =
in_n_hi_wi_c_grid_desc, make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_pad_transform(Wi, InLeftPadW, InRightPadW), in_gemmmraw_gemmk_grid_desc,
make_pass_through_transform(C)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_right_pad_transform(GemmMRaw, GemmMPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, // B: weight tensor
make_tuple( const auto wei_gemmn_gemmk_grid_desc =
make_pass_through_transform(N), make_naive_tensor_descriptor_packed(make_tuple(K, C));
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_pass_through_transform(C)), wei_gemmn_gemmk_grid_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_pass_through_transform(GemmN)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
const auto in_gemmk_gemmmraw_grid_desc = make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), // C: output tensor
make_merge_transform(make_tuple(N, Ho, Wo))), const auto out_gemmmraw_gemmn_grid_desc =
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmm_gemmn_grid_desc =
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
in_gemmk_gemmmraw_grid_desc, make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), make_pass_through_transform(GemmN)),
make_pass_through_transform(GemmMRaw)), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C0: bias tensor: assume a contiguous vector
const auto in_gemmk0_gemmm_gemmk1_grid_desc = const auto bias_grid_desc_gemmm_gemmn =
transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc, make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmMRaw, GemmMPad), // C1: residual tensor: assume same layout as output tensor
make_pass_through_transform(GemmK1Number)), const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
// B: weight tensor out_gemmm_gemmn_grid_desc,
const auto wei_k_yxc_grid_desc = bias_grid_desc_gemmm_gemmn,
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); resi_grid_desc_gemmm_gemmn);
}
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( else if constexpr(ConvForwardSpecialization ==
wei_k_yxc_grid_desc, ConvolutionForwardSpecialization_t::Filter1x1Pad0)
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), {
make_tuple(Sequence<0>{}, Sequence<1>{}), // A: input tensor
make_tuple(Sequence<1>{}, Sequence<0>{})); const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_grid_desc, const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), in_n_hi_wi_c_grid_desc,
make_pass_through_transform(GemmN)), make_tuple(make_pass_through_transform(N),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
// C: output tensor make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
const auto out_nhowo_k_grid_desc = make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
const auto out_gemmmraw_gemmn_grid_desc = transform_tensor_descriptor( in_n_ho_wo_c_grid_desc,
out_nhowo_k_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto out_gemmm_gemmn_grid_desc = const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), make_tuple(make_pass_through_transform(GemmK0),
make_pass_through_transform(GemmN)), make_right_pad_transform(GemmMRaw, GemmMPad),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// C0: bias tensor: assume a contiguous vector
const auto bias_grid_desc_gemmm_gemmn = // B: weight tensor
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); const auto wei_gemmn_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, C));
// C1: residual tensor: assume same layout as output tensor
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_gemmn_gemmk_grid_desc,
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
wei_gemmk0_gemmn_gemmk1_grid_desc, make_pass_through_transform(GemmN)),
out_gemmm_gemmn_grid_desc, make_tuple(Sequence<1>{}, Sequence<0>{}),
bias_grid_desc_gemmm_gemmn, make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
resi_grid_desc_gemmm_gemmn);
// C: output tensor
const auto out_gemmmraw_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmm_gemmn_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C0: bias tensor: assume a contiguous vector
const auto bias_grid_desc_gemmm_gemmn =
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
// C1: residual tensor: assume same layout as output tensor
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
bias_grid_desc_gemmm_gemmn,
resi_grid_desc_gemmm_gemmn);
}
else
{
// A: input tensor
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmmraw_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmMRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B: weight tensor
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
wei_k_yxc_grid_desc,
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor
const auto out_nhowo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmmraw_gemmn_grid_desc =
transform_tensor_descriptor(out_nhowo_k_grid_desc,
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmm_gemmn_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C0: bias tensor: assume a contiguous vector
const auto bias_grid_desc_gemmm_gemmn =
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
// C1: residual tensor: assume same layout as output tensor
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
bias_grid_desc_gemmm_gemmn,
resi_grid_desc_gemmm_gemmn);
}
} }
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
...@@ -315,7 +436,11 @@ struct ...@@ -315,7 +436,11 @@ struct
N01_{N01}, N01_{N01},
in_element_op_{in_element_op}, in_element_op_{in_element_op},
wei_element_op_{wei_element_op}, wei_element_op_{wei_element_op},
out_element_op_{out_element_op} out_element_op_{out_element_op},
filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{ {
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
...@@ -383,6 +508,11 @@ struct ...@@ -383,6 +508,11 @@ struct
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_; OutElementwiseOperation out_element_op_;
// for checking IsSupportedArgument()
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
}; };
// Invoker // Invoker
...@@ -535,6 +665,30 @@ struct ...@@ -535,6 +665,30 @@ struct
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_conv_fwd_bias_activation.hpp" #include "device_conv_fwd_bias_activation.hpp"
#include "convolution_forward_specialization.hpp"
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -27,6 +28,7 @@ template < ...@@ -27,6 +28,7 @@ template <
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
InMemoryDataOperationEnum_t OutGlobalMemoryDataOperation, InMemoryDataOperationEnum_t OutGlobalMemoryDataOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -127,93 +129,204 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -127,93 +129,204 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
const index_t GemmK0 = GemmK / GemmK1Number; const index_t GemmK0 = GemmK / GemmK1Number;
// A: input tensor if constexpr(ConvForwardSpecialization ==
const auto in_n_hi_wi_c_grid_desc = ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); {
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_gemmmraw_gemmk_grid_desc =
in_n_hi_wi_c_grid_desc, make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_pad_transform(Wi, InLeftPadW, InRightPadW), in_gemmmraw_gemmk_grid_desc,
make_pass_through_transform(C)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_right_pad_transform(GemmMRaw, GemmMPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, // B: weight tensor
make_tuple( const auto wei_gemmn_gemmk_grid_desc =
make_pass_through_transform(N), make_naive_tensor_descriptor_packed(make_tuple(K, C));
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_pass_through_transform(C)), wei_gemmn_gemmk_grid_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_pass_through_transform(GemmN)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
const auto in_gemmk_gemmmraw_grid_desc = make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), // C: output tensor
make_merge_transform(make_tuple(N, Ho, Wo))), const auto out_gemmmraw_gemmn_grid_desc =
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmm_gemmn_grid_desc =
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
in_gemmk_gemmmraw_grid_desc, make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), make_pass_through_transform(GemmN)),
make_pass_through_transform(GemmMRaw)), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C0: bias tensor: assume a contiguous vector
const auto in_gemmk0_gemmm_gemmk1_grid_desc = const auto bias_grid_desc_gemmm_gemmn =
transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc, make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmMRaw, GemmMPad), return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
make_pass_through_transform(GemmK1Number)), wei_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), out_gemmm_gemmn_grid_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); bias_grid_desc_gemmm_gemmn);
}
// B: weight tensor else if constexpr(ConvForwardSpecialization ==
const auto wei_k_yxc_grid_desc = ConvolutionForwardSpecialization_t::Filter1x1Pad0)
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); {
// A: input tensor
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( const auto in_n_hi_wi_c_grid_desc =
wei_k_yxc_grid_desc, make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
make_tuple(Sequence<1>{}, Sequence<0>{})); in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
wei_gemmk_gemmn_grid_desc, make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), make_pass_through_transform(C)),
make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
// C: output tensor in_n_ho_wo_c_grid_desc,
const auto out_nhowo_k_grid_desc = make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
const auto out_gemmmraw_gemmn_grid_desc = transform_tensor_descriptor( make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
out_nhowo_k_grid_desc,
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_tuple(Sequence<0>{}, Sequence<1>{}), in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmMRaw, GemmMPad),
const auto out_gemmm_gemmn_grid_desc = make_pass_through_transform(GemmK1Number)),
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), // B: weight tensor
make_tuple(Sequence<0>{}, Sequence<1>{})); const auto wei_gemmn_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, C));
// C0: bias tensor: assume a contiguous vector
const auto bias_grid_desc_gemmm_gemmn = const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); wei_gemmn_gemmk_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, make_pass_through_transform(GemmN)),
wei_gemmk0_gemmn_gemmk1_grid_desc, make_tuple(Sequence<1>{}, Sequence<0>{}),
out_gemmm_gemmn_grid_desc, make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
bias_grid_desc_gemmm_gemmn);
// C: output tensor
const auto out_gemmmraw_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmm_gemmn_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C0: bias tensor: assume a contiguous vector
const auto bias_grid_desc_gemmm_gemmn =
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
bias_grid_desc_gemmm_gemmn);
}
else
{
// A: input tensor
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmmraw_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmMRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B: weight tensor
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
wei_k_yxc_grid_desc,
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor
const auto out_nhowo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmmraw_gemmn_grid_desc =
transform_tensor_descriptor(out_nhowo_k_grid_desc,
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmm_gemmn_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C0: bias tensor: assume a contiguous vector
const auto bias_grid_desc_gemmm_gemmn =
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
bias_grid_desc_gemmm_gemmn);
}
} }
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
...@@ -304,7 +417,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -304,7 +417,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
N01_{N01}, N01_{N01},
in_element_op_{in_element_op}, in_element_op_{in_element_op},
wei_element_op_{wei_element_op}, wei_element_op_{wei_element_op},
out_element_op_{out_element_op} out_element_op_{out_element_op},
filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{ {
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
...@@ -340,7 +457,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -340,7 +457,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
} }
} }
// private:
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
...@@ -361,6 +477,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -361,6 +477,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_; OutElementwiseOperation out_element_op_;
// for checking IsSupportedArgument()
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
}; };
// Invoker // Invoker
...@@ -500,6 +621,30 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -500,6 +621,30 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_conv_fwd.hpp" #include "device_conv_fwd.hpp"
#include "convolution_forward_specialization.hpp"
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -25,6 +26,7 @@ template < ...@@ -25,6 +26,7 @@ template <
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -121,88 +123,189 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -121,88 +123,189 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
const index_t GemmK0 = GemmK / GemmK1Number; const index_t GemmK0 = GemmK / GemmK1Number;
// A: input tensor if constexpr(ConvForwardSpecialization ==
const auto in_n_hi_wi_c_grid_desc = ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); {
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_gemmmraw_gemmk_grid_desc =
in_n_hi_wi_c_grid_desc, make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_pad_transform(Wi, InLeftPadW, InRightPadW), in_gemmmraw_gemmk_grid_desc,
make_pass_through_transform(C)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_right_pad_transform(GemmMRaw, GemmMPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, // B: weight tensor
make_tuple( const auto wei_gemmn_gemmk_grid_desc =
make_pass_through_transform(N), make_naive_tensor_descriptor_packed(make_tuple(K, C));
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_pass_through_transform(C)), wei_gemmn_gemmk_grid_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_pass_through_transform(GemmN)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
const auto in_gemmk_gemmmraw_grid_desc = make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), // C: output tensor
make_merge_transform(make_tuple(N, Ho, Wo))), const auto out_gemmmraw_gemmn_grid_desc =
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmm_gemmn_grid_desc =
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
in_gemmk_gemmmraw_grid_desc, make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), make_pass_through_transform(GemmN)),
make_pass_through_transform(GemmMRaw)), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
const auto in_gemmk0_gemmm_gemmk1_grid_desc = wei_gemmk0_gemmn_gemmk1_grid_desc,
transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc);
make_tuple(make_pass_through_transform(GemmK0), }
make_right_pad_transform(GemmMRaw, GemmMPad), else if constexpr(ConvForwardSpecialization ==
make_pass_through_transform(GemmK1Number)), ConvolutionForwardSpecialization_t::Filter1x1Pad0)
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), {
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); // A: input tensor
const auto in_n_hi_wi_c_grid_desc =
// B: weight tensor make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( make_tuple(make_pass_through_transform(N),
wei_k_yxc_grid_desc, make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_grid_desc, const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), in_n_ho_wo_c_grid_desc,
make_pass_through_transform(GemmN)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor
const auto out_nhowo_k_grid_desc = const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
const auto out_gemmmraw_gemmn_grid_desc = transform_tensor_descriptor( make_right_pad_transform(GemmMRaw, GemmMPad),
out_nhowo_k_grid_desc, make_pass_through_transform(GemmK1Number)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B: weight tensor
const auto out_gemmm_gemmn_grid_desc = const auto wei_gemmn_gemmk_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, make_naive_tensor_descriptor_packed(make_tuple(K, C));
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)), const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_tuple(Sequence<0>{}, Sequence<1>{}), wei_gemmn_gemmk_grid_desc,
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, make_tuple(Sequence<1>{}, Sequence<0>{}),
wei_gemmk0_gemmn_gemmk1_grid_desc, make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
out_gemmm_gemmn_grid_desc);
// C: output tensor
const auto out_gemmmraw_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmm_gemmn_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
else
{
// A: input tensor
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmmraw_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmMRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B: weight tensor
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
wei_k_yxc_grid_desc,
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor
const auto out_nhowo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmmraw_gemmn_grid_desc =
transform_tensor_descriptor(out_nhowo_k_grid_desc,
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmm_gemmn_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} }
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
...@@ -287,7 +390,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -287,7 +390,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
N01_{N01}, N01_{N01},
in_element_op_{in_element_op}, in_element_op_{in_element_op},
wei_element_op_{wei_element_op}, wei_element_op_{wei_element_op},
out_element_op_{out_element_op} out_element_op_{out_element_op},
filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{ {
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
...@@ -317,7 +424,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -317,7 +424,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
} }
} }
// private:
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
...@@ -333,6 +439,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -333,6 +439,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_; OutElementwiseOperation out_element_op_;
// for checking IsSupportedArgument()
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
}; };
// Invoker // Invoker
...@@ -481,6 +592,30 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -481,6 +592,30 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_conv_fwd.hpp" #include "device_conv_fwd.hpp"
#include "convolution_forward_specialization.hpp"
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -24,6 +25,7 @@ template <typename InDataType, ...@@ -24,6 +25,7 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -116,88 +118,189 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -116,88 +118,189 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const index_t GemmK0 = GemmK / GemmK1Number; const index_t GemmK0 = GemmK / GemmK1Number;
// A: input tensor if constexpr(ConvForwardSpecialization ==
const auto in_n_hi_wi_c_grid_desc = ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); {
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_gemmmraw_gemmk_grid_desc =
in_n_hi_wi_c_grid_desc, make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH), const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_pad_transform(Wi, InLeftPadW, InRightPadW), in_gemmmraw_gemmk_grid_desc,
make_pass_through_transform(C)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_right_pad_transform(GemmMRaw, GemmMPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, // B: weight tensor
make_tuple( const auto wei_gemmn_gemmk_grid_desc =
make_pass_through_transform(N), make_naive_tensor_descriptor_packed(make_tuple(K, C));
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_pass_through_transform(C)), wei_gemmn_gemmk_grid_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_pass_through_transform(GemmN)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
const auto in_gemmk_gemmmraw_grid_desc = make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), // C: output tensor
make_merge_transform(make_tuple(N, Ho, Wo))), const auto out_gemmmraw_gemmn_grid_desc =
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmm_gemmn_grid_desc =
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
in_gemmk_gemmmraw_grid_desc, make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), make_pass_through_transform(GemmN)),
make_pass_through_transform(GemmMRaw)), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
const auto in_gemmk0_gemmm_gemmk1_grid_desc = wei_gemmk0_gemmn_gemmk1_grid_desc,
transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc);
make_tuple(make_pass_through_transform(GemmK0), }
make_right_pad_transform(GemmMRaw, GemmMPad), else if constexpr(ConvForwardSpecialization ==
make_pass_through_transform(GemmK1Number)), ConvolutionForwardSpecialization_t::Filter1x1Pad0)
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), {
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); // A: input tensor
const auto in_n_hi_wi_c_grid_desc =
// B: weight tensor make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( make_tuple(make_pass_through_transform(N),
wei_k_yxc_grid_desc, make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_grid_desc, const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), in_n_ho_wo_c_grid_desc,
make_pass_through_transform(GemmN)), make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor
const auto out_nhowo_k_grid_desc = const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
const auto out_gemmmraw_gemmn_grid_desc = transform_tensor_descriptor( make_right_pad_transform(GemmMRaw, GemmMPad),
out_nhowo_k_grid_desc, make_pass_through_transform(GemmK1Number)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B: weight tensor
const auto out_gemmm_gemmn_grid_desc = const auto wei_gemmn_gemmk_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, make_naive_tensor_descriptor_packed(make_tuple(K, C));
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)), const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_tuple(Sequence<0>{}, Sequence<1>{}), wei_gemmn_gemmk_grid_desc,
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, make_tuple(Sequence<1>{}, Sequence<0>{}),
wei_gemmk0_gemmn_gemmk1_grid_desc, make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
out_gemmm_gemmn_grid_desc);
// C: output tensor
const auto out_gemmmraw_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmm_gemmn_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
else
{
// A: input tensor
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmmraw_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk_gemmmraw_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmMRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B: weight tensor
const auto wei_k_yxc_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
wei_k_yxc_grid_desc,
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor
const auto out_nhowo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmmraw_gemmn_grid_desc =
transform_tensor_descriptor(out_nhowo_k_grid_desc,
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmm_gemmn_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} }
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
...@@ -281,7 +384,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -281,7 +384,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
N01_{N01}, N01_{N01},
in_element_op_{in_element_op}, in_element_op_{in_element_op},
wei_element_op_{wei_element_op}, wei_element_op_{wei_element_op},
out_element_op_{out_element_op} out_element_op_{out_element_op},
filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{ {
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
...@@ -324,6 +431,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -324,6 +431,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_; OutElementwiseOperation out_element_op_;
// for checking IsSupportedArgument()
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
}; };
// Invoker // Invoker
...@@ -444,6 +556,30 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -444,6 +556,30 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP
#define CK_DEVICE_OPERATION_INSTANCE_HPP
#include <stdlib.h>
namespace ck {
namespace tensor_operation {
namespace device {
template <typename OpInstance, typename NewOpInstances>
void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op_instances,
const NewOpInstances& new_op_instances)
{
ck::static_for<0, std::tuple_size_v<NewOpInstances>, 1>{}([&](auto i) {
const auto new_op_instance = std::get<i>(new_op_instances);
using NewOpInstance = remove_cvref_t<decltype(new_op_instance)>;
op_instances.push_back(std::make_unique<NewOpInstance>(new_op_instance));
});
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -30,14 +30,17 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -30,14 +30,17 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
using DeviceConvFwdInstance = using DeviceConvFwdInstance =
ck::tensor_operation::device::DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ck::tensor_operation::device::DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// clang-format off // clang-format off
//##| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //##| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; <InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// clang-format on // clang-format on
template <typename TIn, template <typename TIn,
......
...@@ -30,14 +30,17 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -30,14 +30,17 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
using DeviceConvFwdInstance = ck::tensor_operation::device:: using DeviceConvFwdInstance = ck::tensor_operation::device::
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// clang-format off // clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // | InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| // | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| // | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; <InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>;
// clang-format on // clang-format on
template <typename TIn, template <typename TIn,
......
...@@ -32,15 +32,18 @@ using OutElementOp = ck::tensor_operation::element_wise::AddRelu; ...@@ -32,15 +32,18 @@ using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
static constexpr auto MemorySet = ck::InMemoryDataOperationEnum_t::Set; static constexpr auto MemorySet = ck::InMemoryDataOperationEnum_t::Set;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
// clang-format off // clang-format off
using DeviceConvFwdInstance = ck::tensor_operation::device:: using DeviceConvFwdInstance = ck::tensor_operation::device::
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// clang-format off // clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // | InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| // | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| DataOperation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| // | | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, MemorySet, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; <InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, MemorySet, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>;
// clang-format on // clang-format on
template <typename TIn, template <typename TIn,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment