Commit 33aa4d45 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Calculate new SrcVectorDim/DstVectorDim after merge descriptor dimensions

parent a70d9f63
...@@ -107,6 +107,8 @@ struct DevicePermute ...@@ -107,6 +107,8 @@ struct DevicePermute
DstScalarPerVector>> DstScalarPerVector>>
{ {
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor"); static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
template <index_t N = NumDim> template <index_t N = NumDim>
static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array) static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
...@@ -146,22 +148,23 @@ struct DevicePermute ...@@ -146,22 +148,23 @@ struct DevicePermute
using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1})); using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
using OutGridDesc = InGridDesc; using OutGridDesc = InGridDesc;
using GridwisePermute = GridwisePermute<InGridDesc, using GridwisePermute = GridwisePermute<
OutGridDesc, InGridDesc,
InDataType, OutGridDesc,
OutDataType, InDataType,
ElementwiseOperation, OutDataType,
BlockSize, ElementwiseOperation,
NPerBlock, BlockSize,
HPerBlock, NPerBlock,
WPerBlock, HPerBlock,
InBlockLdsExtraW, WPerBlock,
InBlockTransferThreadClusterLengths, InBlockLdsExtraW,
InBlockTransferThreadClusterArrangeOrder, InBlockTransferThreadClusterLengths,
SrcVectorDim, InBlockTransferThreadClusterArrangeOrder,
DstVectorDim, SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
SrcScalarPerVector, DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
DstScalarPerVector>; SrcScalarPerVector,
DstScalarPerVector>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
......
...@@ -110,6 +110,10 @@ struct GridwisePermute ...@@ -110,6 +110,10 @@ struct GridwisePermute
{ {
static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension()); static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
static_assert(3 <= InGridDesc::GetNumOfDimension()); static_assert(3 <= InGridDesc::GetNumOfDimension());
static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
SrcVectorDim < InGridDesc::GetNumOfDimension());
static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
DstVectorDim < OutGridDesc::GetNumOfDimension());
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -211,10 +215,10 @@ struct GridwisePermute ...@@ -211,10 +215,10 @@ struct GridwisePermute
using BlockSliceLengths = Sequence<1, HPerBlock, WPerBlock>; using BlockSliceLengths = Sequence<1, HPerBlock, WPerBlock>;
using InBlockTransferAccessOrder = Sequence<0, 1, 2>; using InBlockTransferAccessOrder = Sequence<0, 1, 2>;
// constexpr index_t SrcVectorDim = 2; constexpr index_t SrcVectorDimAfterMerge =
// constexpr index_t DstVectorDim = 1; SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
// constexpr index_t SrcScalarPerVector = 1; constexpr index_t DstVectorDimAfterMerge =
// constexpr index_t DstScalarPerVector = 1; DstVectorDim - (OutGridDesc::GetNumOfDimension() - 3);
using ck::tensor_operation::element_wise::PassThrough; using ck::tensor_operation::element_wise::PassThrough;
...@@ -234,8 +238,8 @@ struct GridwisePermute ...@@ -234,8 +238,8 @@ struct GridwisePermute
decltype(in_block_desc), decltype(in_block_desc),
InBlockTransferAccessOrder, InBlockTransferAccessOrder,
InBlockTransferAccessOrder, InBlockTransferAccessOrder,
SrcVectorDim, SrcVectorDimAfterMerge,
SrcVectorDim, SrcVectorDimAfterMerge,
SrcScalarPerVector, SrcScalarPerVector,
SrcScalarPerVector, SrcScalarPerVector,
1, 1,
...@@ -273,8 +277,8 @@ struct GridwisePermute ...@@ -273,8 +277,8 @@ struct GridwisePermute
decltype(out_grid_desc_n_h_w), decltype(out_grid_desc_n_h_w),
InBlockTransferAccessOrder, InBlockTransferAccessOrder,
InBlockTransferAccessOrder, InBlockTransferAccessOrder,
SrcVectorDim, SrcVectorDimAfterMerge,
DstVectorDim, DstVectorDimAfterMerge,
SrcScalarPerVector, SrcScalarPerVector,
DstScalarPerVector, DstScalarPerVector,
1, 1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment