Commit 5e28dcda authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add GridwisePermute::CheckValidity()

parent 0c23d6fa
......@@ -212,17 +212,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
static bool IsSupportedArgument(const Argument& arg)
{
// check if only swap last 2 dimensions
if(!(std::equal(begin(arg.inLengths_),
std::prev(end(arg.inLengths_), 2),
begin(arg.outLengths_)) &&
std::tie(*rbegin(arg.inLengths_), *std::next(rbegin(arg.inLengths_))) ==
std::tie(*std::next(rbegin(arg.outLengths_)), *rbegin(arg.outLengths_))))
{
return false;
}
return true;
return GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
};
};
......
......@@ -152,6 +152,26 @@ struct GridwisePermute
return DefaultBlock2TileMap{desc};
}
__host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc,
const OutGridDesc& out_grid_desc)
{
constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
bool valid = true;
static_for<0, NumDim - 2, 1>{}([&](auto I) {
if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
{
valid = false;
}
});
return valid &&
(in_grid_desc.GetLength(Number<NumDim - 1>{}) ==
out_grid_desc.GetLength(Number<NumDim - 2>{})) &&
(in_grid_desc.GetLength(Number<NumDim - 2>{}) ==
out_grid_desc.GetLength(Number<NumDim - 1>{}));
}
template <typename Block2TileMap>
__device__ static void Run(const InGridDesc in_grid_desc,
const OutGridDesc out_grid_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment