Commit a70d9f63 authored by Po-Yen, Chen
Browse files

Add more template parameters (vector width related)

parent 4eaa502b
......@@ -84,7 +84,11 @@ template <typename InDataType,
index_t WPerBlock,
index_t InBlockLdsExtraW,
typename InBlockTransferThreadClusterLengths,
typename InBlockTransferThreadClusterArrangeOrder>
typename InBlockTransferThreadClusterArrangeOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector>
struct DevicePermute
: detail::DevicePermuteBase<DevicePermute<InDataType,
OutDataType,
......@@ -96,7 +100,11 @@ struct DevicePermute
WPerBlock,
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder>>
InBlockTransferThreadClusterArrangeOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector>>
{
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
......@@ -149,7 +157,11 @@ struct DevicePermute
WPerBlock,
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder>;
InBlockTransferThreadClusterArrangeOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector>;
struct Argument : public BaseArgument
{
......
......@@ -101,7 +101,11 @@ template <typename InGridDesc,
index_t WPerBlock,
index_t InBlockLdsExtraW,
typename InBlockTransferThreadClusterLengths,
typename InBlockTransferThreadClusterArrangeOrder>
typename InBlockTransferThreadClusterArrangeOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector>
struct GridwisePermute
{
static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
......@@ -207,10 +211,10 @@ struct GridwisePermute
using BlockSliceLengths = Sequence<1, HPerBlock, WPerBlock>;
using InBlockTransferAccessOrder = Sequence<0, 1, 2>;
constexpr index_t SrcVectorDim = 2;
constexpr index_t DstVectorDim = 1;
constexpr index_t SrcScalarPerVector = 1;
constexpr index_t DstScalarPerVector = 1;
// constexpr index_t SrcVectorDim = 2;
// constexpr index_t DstVectorDim = 1;
// constexpr index_t SrcScalarPerVector = 1;
// constexpr index_t DstScalarPerVector = 1;
using ck::tensor_operation::element_wise::PassThrough;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment