Commit 910a26b4 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Check scalar-per-vector with padded length

parent f3b3a61c
......@@ -232,13 +232,17 @@ struct DevicePermute
static bool IsSupportedArgument(const Argument& arg)
{
constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
return math::integer_divide_ceil(length, tile_length) * tile_length;
};
constexpr auto IsScalarPerVectorValid =
[](index_t length, index_t stride, index_t scalarPerVector) {
if(stride == 1 && length % scalarPerVector == 0)
[](index_t length, index_t stride, index_t scalar_per_vector) {
if(stride == 1 && length % scalar_per_vector == 0)
{
return true;
}
else if(stride != 1 && scalarPerVector == 1)
else if(stride != 1 && scalar_per_vector == 1)
{
return true;
}
......@@ -249,9 +253,19 @@ struct DevicePermute
return IsScalarPerVectorValid(arg.inLengths_[SrcVectorDim],
arg.inStrides_[SrcVectorDim],
SrcScalarPerVector) &&
IsScalarPerVectorValid(
GetPaddedLength(arg.inLengths_[SrcVectorDim],
(SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
arg.inStrides_[SrcVectorDim],
SrcScalarPerVector) &&
IsScalarPerVectorValid(arg.outLengths_[DstVectorDim],
arg.outStrides_[DstVectorDim],
DstScalarPerVector) &&
IsScalarPerVectorValid(
GetPaddedLength(arg.outLengths_[DstVectorDim],
(DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
arg.inStrides_[DstVectorDim],
DstScalarPerVector) &&
GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
};
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment