"...composable_kernel_rocm.git" did not exist on "6368be50c5f07afee4d0a164bf7fdb4210884708"
Commit e5e7adbd authored by Po-Yen, Chen
Browse files

Extract local types as template parameters

parent 5e28dcda
......@@ -82,16 +82,21 @@ template <typename InDataType,
index_t NPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t InBlockLdsExtraW>
struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
OutDataType,
ElementwiseOperation,
NumDim,
BlockSize,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW>>
index_t InBlockLdsExtraW,
typename InBlockTransferThreadClusterLengths,
typename InBlockTransferThreadClusterArrangeOrder>
struct DevicePermute
: detail::DevicePermuteBase<DevicePermute<InDataType,
OutDataType,
ElementwiseOperation,
NumDim,
BlockSize,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder>>
{
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
......@@ -142,7 +147,9 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW>;
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder>;
struct Argument : public BaseArgument
{
......
......@@ -99,7 +99,9 @@ template <typename InGridDesc,
index_t NPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t InBlockLdsExtraW>
index_t InBlockLdsExtraW,
typename InBlockTransferThreadClusterLengths,
typename InBlockTransferThreadClusterArrangeOrder>
struct GridwisePermute
{
static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
......@@ -202,10 +204,8 @@ struct GridwisePermute
auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<InDataType*>(p_shared), in_block_desc.GetElementSpaceSize());
using SliceLengths = Sequence<1, HPerBlock, WPerBlock>;
using ABlockTransferThreadClusterLengths = Sequence<1, 16, BlockSize / 16>;
using ABlockTransferThreadClusterArrangeOrder = Sequence<0, 1, 2>;
using ABlockTransferAccessOrder = Sequence<0, 1, 2>;
using SliceLengths = Sequence<1, HPerBlock, WPerBlock>;
using ABlockTransferAccessOrder = Sequence<0, 1, 2>;
constexpr index_t ABlockTransferSrcVectorDim = 2;
constexpr index_t ABlockTransferDstVectorDim = 1;
......@@ -222,8 +222,8 @@ struct GridwisePermute
PassThrough,
InMemoryDataOperationEnum::Set,
SliceLengths,
ABlockTransferThreadClusterLengths,
ABlockTransferThreadClusterArrangeOrder,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
InDataType,
InDataType,
decltype(in_grid_desc_n_h_w),
......@@ -261,8 +261,8 @@ struct GridwisePermute
PassThrough,
InMemoryDataOperationEnum::Set,
SliceLengths,
ABlockTransferThreadClusterLengths,
ABlockTransferThreadClusterArrangeOrder,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
InDataType,
OutDataType,
decltype(in_block_desc),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment