Commit 33aa4d45 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Calculate new SrcVectorDim/DstVectorDim after merge descriptor dimensions

parent a70d9f63
......@@ -107,6 +107,8 @@ struct DevicePermute
DstScalarPerVector>>
{
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
template <index_t N = NumDim>
static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
......@@ -146,22 +148,23 @@ struct DevicePermute
using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
using OutGridDesc = InGridDesc;
using GridwisePermute = GridwisePermute<InGridDesc,
OutGridDesc,
InDataType,
OutDataType,
ElementwiseOperation,
BlockSize,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector>;
using GridwisePermute = GridwisePermute<
InGridDesc,
OutGridDesc,
InDataType,
OutDataType,
ElementwiseOperation,
BlockSize,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
SrcScalarPerVector,
DstScalarPerVector>;
struct Argument : public BaseArgument
{
......
......@@ -110,6 +110,10 @@ struct GridwisePermute
{
static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
static_assert(3 <= InGridDesc::GetNumOfDimension());
static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
SrcVectorDim < InGridDesc::GetNumOfDimension());
static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
DstVectorDim < OutGridDesc::GetNumOfDimension());
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -211,10 +215,10 @@ struct GridwisePermute
using BlockSliceLengths = Sequence<1, HPerBlock, WPerBlock>;
using InBlockTransferAccessOrder = Sequence<0, 1, 2>;
// constexpr index_t SrcVectorDim = 2;
// constexpr index_t DstVectorDim = 1;
// constexpr index_t SrcScalarPerVector = 1;
// constexpr index_t DstScalarPerVector = 1;
constexpr index_t SrcVectorDimAfterMerge =
SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
constexpr index_t DstVectorDimAfterMerge =
DstVectorDim - (OutGridDesc::GetNumOfDimension() - 3);
using ck::tensor_operation::element_wise::PassThrough;
......@@ -234,8 +238,8 @@ struct GridwisePermute
decltype(in_block_desc),
InBlockTransferAccessOrder,
InBlockTransferAccessOrder,
SrcVectorDim,
SrcVectorDim,
SrcVectorDimAfterMerge,
SrcVectorDimAfterMerge,
SrcScalarPerVector,
SrcScalarPerVector,
1,
......@@ -273,8 +277,8 @@ struct GridwisePermute
decltype(out_grid_desc_n_h_w),
InBlockTransferAccessOrder,
InBlockTransferAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcVectorDimAfterMerge,
DstVectorDimAfterMerge,
SrcScalarPerVector,
DstScalarPerVector,
1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment