Commit a22c7cf5 authored by Chao Liu

refactor

parent 0530fd66
@@ -4,7 +4,7 @@
 #include "convnd_fwd_common.hpp"
 #include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
-#include "ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp"
 using InDataType  = ck::half_t;
 using WeiDataType = ck::half_t;
@@ -67,48 +67,51 @@ static constexpr auto ConvSpec =
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

 template <ck::index_t NDimSpatial>
-using DeviceConvNDFwdInstance =
-    ck::tensor_operation::device::DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle<
-        NDimSpatial,
-        InDataType,
-        WeiDataType,
-        AccDataType,
-        CShuffleDataType,
-        ck::Tuple<>,
-        OutDataType,
-        InElementOp,    // Input Elementwise Operation
-        WeiElementOp,   // Weights Elementwise Operation
-        OutElementOp,   // Output Elementwise Operation
-        ConvSpec,       // ConvForwardSpecialization
-        GemmSpec,       // GemmSpecialization
-        1,              //
-        256,            // BlockSize
-        128,            // MPerBlock
-        256,            // NPerBlock
-        32,             // KPerBlock
-        8,              // K1
-        32,             // MPerXdl
-        32,             // NPerXdl
-        2,              // MXdlPerWave
-        4,              // NXdlPerWave
-        S<4, 64, 1>,    // ABlockTransferThreadClusterLengths_K0_M_K1
-        S<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
-        S<1, 0, 2>,     // ABlockTransferSrcAccessOrder
-        2,              // ABlockTransferSrcVectorDim
-        8,              // ABlockTransferSrcScalarPerVector
-        8,              // ABlockTransferDstScalarPerVector_K1
-        1,              // ABlockLdsExtraM
-        S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_K0_N_K1
-        S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
-        S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
-        2,              // BBlockTransferSrcVectorDim
-        8,              // BBlockTransferSrcScalarPerVector
-        8,              // BBlockTransferDstScalarPerVector_K1
-        1,              // BBlockLdsExtraN
-        1,
-        1,
-        S<1, 32, 1, 8>,
-        8>;
+using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMultipleD_Xdl_CShuffle<
+    NDimSpatial,
+    // ck::tensor_layout::convolution::NWC,
+    // ck::tensor_layout::convolution::KXC,
+    // ck::tensor_layout::convolution::NWK,
+    // ck::Tuple<>,
+    InDataType,
+    WeiDataType,
+    AccDataType,
+    CShuffleDataType,
+    ck::Tuple<>,
+    OutDataType,
+    InElementOp,
+    WeiElementOp,
+    OutElementOp,
+    ConvSpec,       // ConvForwardSpecialization
+    GemmSpec,       // GemmSpecialization
+    1,              //
+    256,            // BlockSize
+    128,            // MPerBlock
+    256,            // NPerBlock
+    32,             // KPerBlock
+    8,              // K1
+    32,             // MPerXdl
+    32,             // NPerXdl
+    2,              // MXdlPerWave
+    4,              // NXdlPerWave
+    S<4, 64, 1>,    // ABlockTransferThreadClusterLengths_K0_M_K1
+    S<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
+    S<1, 0, 2>,     // ABlockTransferSrcAccessOrder
+    2,              // ABlockTransferSrcVectorDim
+    8,              // ABlockTransferSrcScalarPerVector
+    8,              // ABlockTransferDstScalarPerVector_K1
+    1,              // ABlockLdsExtraM
+    S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_K0_N_K1
+    S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
+    S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
+    2,              // BBlockTransferSrcVectorDim
+    8,              // BBlockTransferSrcScalarPerVector
+    8,              // BBlockTransferDstScalarPerVector_K1
+    1,              // BBlockLdsExtraN
+    1,
+    1,
+    S<1, 32, 1, 8>,
+    8>;
 #endif

 int main(int argc, char* argv[])
...
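As a quick aid for readers tuning the tile parameters above, the block tile is self-consistent under the usual xdlops decomposition. The sketch below is an illustration only and is not part of the commit; it assumes a 64-lane wavefront and the relation MPerBlock = MWaves * MXdlPerWave * MPerXdl (and likewise for N), which this diff does not spell out.

// Tile-size sanity check for the DeviceConvNDFwdInstance parameters above.
// Assumptions (not stated in the diff): 64-lane wavefront, standard xdlops tiling
// where each wave covers MXdlPerWave*MPerXdl rows and NXdlPerWave*NPerXdl columns.
constexpr int BlockSize   = 256;
constexpr int WaveSize    = 64;
constexpr int MPerBlock   = 128;
constexpr int NPerBlock   = 256;
constexpr int MPerXdl     = 32;
constexpr int NPerXdl     = 32;
constexpr int MXdlPerWave = 2;
constexpr int NXdlPerWave = 4;

constexpr int NumWaves = BlockSize / WaveSize;                // 4 waves per block
constexpr int MWaves   = MPerBlock / (MXdlPerWave * MPerXdl); // 2
constexpr int NWaves   = NPerBlock / (NXdlPerWave * NPerXdl); // 2
static_assert(MWaves * NWaves == NumWaves, "waves must exactly tile the 128x256 block tile");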
@@ -23,7 +23,8 @@ namespace device {
 template <ck::index_t NDimSpatial,
           typename ALayout,
           typename BLayout,
-          typename DELayout,
+          typename DsLayout,
+          typename ELayout,
           typename ADataType,
           typename BDataType,
           typename DsDataType,
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Grouped Convolution Forward
// input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedConvFwdMultipleD : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        const std::vector<ck::index_t>& a_g_n_c_wis_lengths,
                        const std::vector<ck::index_t>& a_g_n_c_wis_strides,
                        const std::vector<ck::index_t>& b_g_k_c_xs_lengths,
                        const std::vector<ck::index_t>& b_g_k_c_xs_strides,
                        std::array<std::vector<ck::index_t>, NumDTensor> ds_g_n_k_wos_lengths,
                        std::array<std::vector<ck::index_t>, NumDTensor> ds_g_n_k_wos_strides,
                        const std::vector<ck::index_t>& e_g_n_k_wos_lengths,
                        const std::vector<ck::index_t>& e_g_n_k_wos_strides,
                        const std::vector<ck::index_t>& conv_filter_strides,
                        const std::vector<ck::index_t>& conv_filter_dilations,
                        const std::vector<ck::index_t>& input_left_pads,
                        const std::vector<ck::index_t>& input_right_pads,
                        const AElementwiseOperation& a_element_op,
                        const BElementwiseOperation& b_element_op,
                        const CDEElementwiseOperation& cde_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
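To make the data-flow comment in this new interface concrete, here is a small host-side reference sketch of the 1D case with a single fused D tensor. It is an illustration only and not part of the commit; the tensor shapes follow the a_g_n_c_wis / b_g_k_c_xs / ds_g_n_k_wos / e_g_n_k_wos naming of MakeArgumentPointer, the function name reference_grouped_conv1d_fwd is made up for this example, and the elementwise operations are plain callables.

// Reference 1D grouped conv fwd with one fused D tensor (illustration only).
// e[g, n, k, wo] = cde_op( sum_{c, x} a_op(a[g, n, c, wo*stride + x*dilation - pad])
//                                     * b_op(b[g, k, c, x]),
//                          d0[g, n, k, wo] )
#include <cstddef>
#include <vector>

template <typename AOp, typename BOp, typename CDEOp>
void reference_grouped_conv1d_fwd(const std::vector<float>& a,  // [G, N, C, Wi]
                                  const std::vector<float>& b,  // [G, K, C, X]
                                  const std::vector<float>& d0, // [G, N, K, Wo]
                                  std::vector<float>& e,        // [G, N, K, Wo]
                                  int G, int N, int K, int C, int Wi, int Wo, int X,
                                  int stride, int dilation, int left_pad,
                                  AOp a_op, BOp b_op, CDEOp cde_op)
{
    for(int g = 0; g < G; ++g)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                for(int wo = 0; wo < Wo; ++wo)
                {
                    float acc = 0.f; // C = a_op(A) * b_op(B), accumulated over C and X
                    for(int c = 0; c < C; ++c)
                        for(int x = 0; x < X; ++x)
                        {
                            const int wi = wo * stride + x * dilation - left_pad;
                            if(wi < 0 || wi >= Wi)
                                continue; // zero padding outside the input window
                            const std::size_t a_idx = ((std::size_t(g) * N + n) * C + c) * Wi + wi;
                            const std::size_t b_idx = ((std::size_t(g) * K + k) * C + c) * X + x;
                            acc += a_op(a[a_idx]) * b_op(b[b_idx]);
                        }
                    const std::size_t e_idx = ((std::size_t(g) * N + n) * K + k) * Wo + wo;
                    e[e_idx] = cde_op(acc, d0[e_idx]); // E = cde_op(C, D0)
                }
}

Passing cde_op = [](float c, float d0) { return c + d0; } reproduces a fused residual/bias-style add, which is the kind of D-tensor fusion this interface is meant to expose through DsDataType and CDEElementwiseOperation.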
@@ -13,11 +13,10 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/device_utility/device_prop.hpp"
 #include "ck/device_utility/kernel_launch.hpp"
@@ -110,6 +109,10 @@ __global__ void
 // out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
 //
 template <index_t NDimSpatial,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename ELayout,
           typename ADataType,
           typename BDataType,
           typename AccDataType,
@@ -150,31 +153,21 @@ template <index_t NDimSpatial,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
-    : public DeviceConvFwdMultipleD<
-          NDimSpatial,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::NWC,
-                                        ck::tensor_layout::convolution::NHWC,
-                                        ck::tensor_layout::convolution::NDHWC>>,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::KXC,
-                                        ck::tensor_layout::convolution::KYXC,
-                                        ck::tensor_layout::convolution::KZYXC>>,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::NWK,
-                                        ck::tensor_layout::convolution::NHWK,
-                                        ck::tensor_layout::convolution::NDHWK>>,
-          ADataType,
-          BDataType,
-          DsDataType,
-          EDataType,
-          AElementwiseOperation,
-          BElementwiseOperation,
-          CDEElementwiseOperation>
+struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
+    : public DeviceGroupedConvFwdMultipleD<NDimSpatial,
+                                           ALayout,
+                                           BLayout,
+                                           DsLayout,
+                                           ELayout,
+                                           ADataType,
+                                           BDataType,
+                                           DsDataType,
+                                           EDataType,
+                                           AElementwiseOperation,
+                                           BElementwiseOperation,
+                                           CDEElementwiseOperation>
 {
-    using DeviceOp = DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle;
+    using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle;

     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -189,6 +182,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     static constexpr auto matrix_padder =
         MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

+    template <typename std::enable_if<ALayout, bool>::type = false>
     static auto GetWeightTensorDescriptor(index_t GemmNRaw, index_t GemmKRaw)
     {
         const auto wei_k_yxc_grid_desc =
@@ -1076,7 +1070,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle"
+        str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
...