Commit e5e7adbd authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Extract local types as template parameters

parent 5e28dcda
......@@ -82,8 +82,11 @@ template <typename InDataType,
index_t NPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t InBlockLdsExtraW>
struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
index_t InBlockLdsExtraW,
typename InBlockTransferThreadClusterLengths,
typename InBlockTransferThreadClusterArrangeOrder>
struct DevicePermute
: detail::DevicePermuteBase<DevicePermute<InDataType,
OutDataType,
ElementwiseOperation,
NumDim,
......@@ -91,7 +94,9 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW>>
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder>>
{
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
......@@ -142,7 +147,9 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW>;
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder>;
struct Argument : public BaseArgument
{
......
......@@ -99,7 +99,9 @@ template <typename InGridDesc,
index_t NPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t InBlockLdsExtraW>
index_t InBlockLdsExtraW,
typename InBlockTransferThreadClusterLengths,
typename InBlockTransferThreadClusterArrangeOrder>
struct GridwisePermute
{
static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
......@@ -203,8 +205,6 @@ struct GridwisePermute
static_cast<InDataType*>(p_shared), in_block_desc.GetElementSpaceSize());
using SliceLengths = Sequence<1, HPerBlock, WPerBlock>;
using ABlockTransferThreadClusterLengths = Sequence<1, 16, BlockSize / 16>;
using ABlockTransferThreadClusterArrangeOrder = Sequence<0, 1, 2>;
using ABlockTransferAccessOrder = Sequence<0, 1, 2>;
constexpr index_t ABlockTransferSrcVectorDim = 2;
......@@ -222,8 +222,8 @@ struct GridwisePermute
PassThrough,
InMemoryDataOperationEnum::Set,
SliceLengths,
ABlockTransferThreadClusterLengths,
ABlockTransferThreadClusterArrangeOrder,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
InDataType,
InDataType,
decltype(in_grid_desc_n_h_w),
......@@ -261,8 +261,8 @@ struct GridwisePermute
PassThrough,
InMemoryDataOperationEnum::Set,
SliceLengths,
ABlockTransferThreadClusterLengths,
ABlockTransferThreadClusterArrangeOrder,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
InDataType,
OutDataType,
decltype(in_block_desc),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment