Commit 5e28dcda authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add GridwisePermute::CheckValidity()

parent 0c23d6fa
...@@ -212,17 +212,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -212,17 +212,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check if only swap last 2 dimensions return GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
if(!(std::equal(begin(arg.inLengths_),
std::prev(end(arg.inLengths_), 2),
begin(arg.outLengths_)) &&
std::tie(*rbegin(arg.inLengths_), *std::next(rbegin(arg.inLengths_))) ==
std::tie(*std::next(rbegin(arg.outLengths_)), *rbegin(arg.outLengths_))))
{
return false;
}
return true;
}; };
}; };
......
...@@ -152,6 +152,26 @@ struct GridwisePermute ...@@ -152,6 +152,26 @@ struct GridwisePermute
return DefaultBlock2TileMap{desc}; return DefaultBlock2TileMap{desc};
} }
__host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc,
const OutGridDesc& out_grid_desc)
{
constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
bool valid = true;
static_for<0, NumDim - 2, 1>{}([&](auto I) {
if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
{
valid = false;
}
});
return valid &&
(in_grid_desc.GetLength(Number<NumDim - 1>{}) ==
out_grid_desc.GetLength(Number<NumDim - 2>{})) &&
(in_grid_desc.GetLength(Number<NumDim - 2>{}) ==
out_grid_desc.GetLength(Number<NumDim - 1>{}));
}
template <typename Block2TileMap> template <typename Block2TileMap>
__device__ static void Run(const InGridDesc in_grid_desc, __device__ static void Run(const InGridDesc in_grid_desc,
const OutGridDesc out_grid_desc, const OutGridDesc out_grid_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment