Commit c3943745 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add check for the 'VectorDim' & 'ScalarPerVector' template params

parent a399b408
...@@ -235,7 +235,27 @@ struct DevicePermute ...@@ -235,7 +235,27 @@ struct DevicePermute
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
return GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_); constexpr auto IsScalarPerVectorValid = [](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t vectorDim,
index_t scalarPerVector) {
if(strides[vectorDim] == 1 && lengths[vectorDim] % scalarPerVector == 0)
{
return true;
}
else if(strides[vectorDim] != 1 && scalarPerVector == 1)
{
return true;
}
return false;
};
return IsScalarPerVectorValid(
arg.inLengths_, arg.inStrides_, SrcVectorDim, SrcScalarPerVector) &&
IsScalarPerVectorValid(
arg.outLengths_, arg.outStrides_, DstVectorDim, DstScalarPerVector) &&
GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
}; };
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment