Commit 2e5d4f91 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Check if input/output shape meet the requirement

parent b41e6019
...@@ -31,8 +31,8 @@ struct ExecutionConfig final ...@@ -31,8 +31,8 @@ struct ExecutionConfig final
struct Problem final struct Problem final
{ {
std::array<std::size_t, 4> shape = {4, 16, 32, 32}; std::array<std::size_t, 4> shape = {4, 8, 16, 32};
std::array<std::size_t, 4> axes = {0, 2, 3, 1}; std::array<std::size_t, 4> axes = {0, 1, 3, 2};
}; };
template <ck::index_t... Is> template <ck::index_t... Is>
......
...@@ -21,23 +21,22 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem) ...@@ -21,23 +21,22 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf.ToDevice(a.mData.data()); a_device_buf.ToDevice(a.mData.data());
std::array<ck::index_t, 4> a_lengths; std::array<ck::index_t, 4> a_lengths, b_lengths;
std::array<ck::index_t, 4> axes;
std::array<ck::index_t, 4> a_strides, b_strides; std::array<ck::index_t, 4> a_strides, b_strides;
const void* input = a_device_buf.GetDeviceBuffer(); const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer(); void* output = b_device_buf.GetDeviceBuffer();
std::copy(begin(shape), end(shape), begin(a_lengths)); std::copy(begin(shape), end(shape), begin(a_lengths));
std::copy(begin(problem.axes), end(problem.axes), begin(axes));
std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides)); std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides));
std::copy(begin(transposed_shape), end(transposed_shape), begin(b_lengths));
std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides)); std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>); static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
auto permute = DevicePermuteInstance{}; auto permute = DevicePermuteInstance{};
auto argument = auto argument = permute.MakeArgument(
permute.MakeArgument(a_lengths, axes, a_strides, b_strides, input, output, PassThrough{}); a_lengths, a_strides, b_lengths, b_strides, input, output, PassThrough{});
if(!permute.IsSupportedArgument(argument)) if(!permute.IsSupportedArgument(argument))
{ {
......
...@@ -68,6 +68,9 @@ struct InvokerBase : BaseInvoker ...@@ -68,6 +68,9 @@ struct InvokerBase : BaseInvoker
}; };
} // namespace detail } // namespace detail
// Swap last 2 dimensions
// input: [d0, d1, d2, ..., d, dn-2, dn-1]
// output: [d0, d1, d2, ..., d, dn-1, dn-2]
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename ElementwiseOperation, typename ElementwiseOperation,
...@@ -83,6 +86,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -83,6 +86,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
InScalarPerVector, InScalarPerVector,
OutScalarPerVector>> OutScalarPerVector>>
{ {
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
using InDataTypePointer = const InDataType*; using InDataTypePointer = const InDataType*;
using OutDataTypePointer = OutDataType*; using OutDataTypePointer = OutDataType*;
...@@ -155,8 +160,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -155,8 +160,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::array<index_t, NumDim> inLengths, Argument(const std::array<index_t, NumDim> inLengths,
const std::array<index_t, NumDim> axes,
const std::array<index_t, NumDim> inStrides, const std::array<index_t, NumDim> inStrides,
const std::array<index_t, NumDim> outLengths,
const std::array<index_t, NumDim> outStrides, const std::array<index_t, NumDim> outStrides,
const void* in_dev_buffer, const void* in_dev_buffer,
void* out_dev_buffer, void* out_dev_buffer,
...@@ -168,8 +173,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -168,8 +173,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
in_grid_1d_desc_(MakeDescriptor_M(inLengths, inStrides, gridSize_, blockSize_)), in_grid_1d_desc_(MakeDescriptor_M(inLengths, inStrides, gridSize_, blockSize_)),
out_grid_1d_desc_(MakeDescriptor_M(inLengths, inStrides, gridSize_, blockSize_)), out_grid_1d_desc_(MakeDescriptor_M(inLengths, inStrides, gridSize_, blockSize_)),
inLengths_(inLengths), inLengths_(inLengths),
axes_(axes),
inStrides_(inStrides), inStrides_(inStrides),
outLengths_(outLengths),
outStrides_(outStrides), outStrides_(outStrides),
elementwise_op_(elementwise_op) elementwise_op_(elementwise_op)
{ {
...@@ -184,8 +189,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -184,8 +189,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
OutGrid1dDesc out_grid_1d_desc_; OutGrid1dDesc out_grid_1d_desc_;
std::array<index_t, NumDim> inLengths_; std::array<index_t, NumDim> inLengths_;
std::array<index_t, NumDim> axes_;
std::array<index_t, NumDim> inStrides_; std::array<index_t, NumDim> inStrides_;
std::array<index_t, NumDim> outLengths_;
std::array<index_t, NumDim> outStrides_; std::array<index_t, NumDim> outStrides_;
ElementwiseOperation elementwise_op_; ElementwiseOperation elementwise_op_;
...@@ -223,6 +228,16 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -223,6 +228,16 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
return false; return false;
} }
// check if only swap last 2 dimensions
if(!(std::equal(begin(arg.inLengths_),
std::prev(end(arg.inLengths_), 2),
begin(arg.outLengths_)) &&
std::tie(*rbegin(arg.inLengths_), *std::next(rbegin(arg.inLengths_))) ==
std::tie(*std::next(rbegin(arg.outLengths_)), *rbegin(arg.outLengths_))))
{
return false;
}
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths, auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides, const std::array<index_t, NumDim>& strides,
index_t scalarPerVector) { index_t scalarPerVector) {
...@@ -241,7 +256,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -241,7 +256,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
valid = false; valid = false;
} }
if(!IsScalarPerVectorValid(arg.inLengths_, arg.outStrides_, OutScalarPerVector)) if(!IsScalarPerVectorValid(arg.outLengths_, arg.outStrides_, OutScalarPerVector))
{ {
valid = false; valid = false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment