Commit 8b98d7d2 authored by Po-Yen Chen
Browse files

Add support to permute multiple elements together

parent 5f50ed89
......@@ -2,8 +2,8 @@ add_custom_target(example_permute)
add_example_executable(example_permute_1xHxW_fp32 permute_1xHxW_fp32.cpp)
add_example_executable(example_permute_NxHxW_fp32 permute_NxHxW_fp32.cpp)
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
add_example_executable(example_permute_HxWx2_fp16 permute_HxWx2_fp16.cpp)
add_dependencies(example_permute example_permute_1xHxW_fp32)
add_dependencies(example_permute example_permute_NxHxW_fp32)
add_dependencies(example_permute example_permute_HxWx4_fp16)
add_dependencies(example_permute example_permute_HxWx2_fp16)
......@@ -6,6 +6,7 @@
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <numeric>
......@@ -33,7 +34,9 @@ struct ExecutionConfig final
struct Problem final
{
using Shape = std::array<std::size_t, 3>;
static constexpr std::size_t NumDim = 3;
using Shape = std::array<std::size_t, NumDim>;
using Axes = Shape;
Problem() = delete;
......@@ -249,7 +252,7 @@ is_valid_axes(const Axes& axes)
inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
{
constexpr int num_execution_config_args = 2;
constexpr int num_problem_args = 3 + 3;
constexpr int num_problem_args = 2 * Problem::NumDim;
if(!(num_problem_args == size(problem.shape) + size(problem.axes)))
{
......@@ -412,7 +415,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
}
using std::size;
if(!(is_valid_axes(axes) && size(axes) == 3))
if(!is_valid_axes(axes))
{
return false;
}
......@@ -420,7 +423,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
if(!(size(shape) == 3 && size(transposed_shape) == 3))
if(size(shape) != size(transposed_shape))
{
return false;
}
......@@ -437,18 +440,34 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
}
}
std::array<std::size_t, 3> indices{};
std::vector<std::size_t> indices(size(shape), 0);
if(!is_valid_indices(shape, indices))
{
return false;
}
if(size(shape) == 3)
{
do
{
Dest output = 0;
functor(output, src(indices[0], indices[1], indices[2]));
dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output;
} while(advance_indices(shape, indices));
}
else if(size(shape) == 4)
{
do
{
Dest output = 0;
functor(output, src(indices[0], indices[1], indices[2], indices[3]));
dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output;
} while(advance_indices(shape, indices));
}
else
{
return false;
}
return true;
}
......@@ -3,8 +3,8 @@
#include "common.hpp"
using ADataType = F64;
using BDataType = F64;
using ADataType = F32;
using BDataType = F32;
// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
......@@ -15,10 +15,7 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
< ADataType, BDataType, PassThrough, 3, 256, 128, 128, 0, S<1, 16, 16>, S<0, 1, 2>, 2, 1, 1, 1>;
// clang-format on
#define NUM_ELEMS_IN_BUNDLE 4
#define NUM_ELEMS_IN_BUNDLE 2
#include "run_permute_example.inc"
// Entry point for the permute example: permutes a tensor of shape
// {1, 160, 80} with axes {0, 2, 1}, i.e. swaps the last two (H and W)
// dimensions. run_permute_example returns true on success, so the result
// is negated to yield the conventional 0-on-success process exit code.
int main(int argc, char* argv[])
{
return !run_permute_example(argc, argv, {1, 160, 80}, {0, 2, 1});
}
// Entry point for the permute example: permutes a tensor of shape
// {1, 3, 4} with axes {0, 2, 1}, i.e. swaps the last two dimensions.
// run_permute_example returns true on success, so the result is negated
// to yield the conventional 0-on-success process exit code.
int main(int argc, char* argv[])
{
    return !run_permute_example(argc, argv, {1, 3, 4}, {0, 2, 1});
}
......@@ -9,6 +9,10 @@
bool run_permute(const ExecutionConfig& config, const Problem& problem)
{
#if 1 < NUM_ELEMS_IN_BUNDLE
static_assert(std::is_same_v<ADataType, BDataType>);
#endif
using std::begin, std::end;
const auto& shape = problem.shape;
......@@ -61,13 +65,53 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
if(config.do_verification)
{
Tensor<BDataType> host_b(transposed_shape);
host_permute(a, problem.axes, PassThrough{}, host_b);
b_device_buf.FromDevice(data(b.mData));
#if NUM_ELEMS_IN_BUNDLE == 1
Tensor<BDataType> host_b(transposed_shape);
if(!host_permute(a, problem.axes, PassThrough{}, host_b))
{
return false;
}
return ck::utils::check_err(
b.mData, host_b.mData, "Error: incorrect results in output tensor", 1e-10, 1e-10);
#else
// extend tensor shape from [N, H, W] to [N, H, W, NUM_ELEMS_IN_BUNDLE]
std::array<std::size_t, Problem::NumDim + 1> extended_shape;
std::copy(begin(shape), end(shape), begin(extended_shape));
extended_shape.back() = NUM_ELEMS_IN_BUNDLE;
using DataType = detail::get_bundled_t<ADataType, NUM_ELEMS_IN_BUNDLE>;
Tensor<DataType> extended_a(extended_shape);
std::memcpy(data(extended_a.mData),
data(a.mData),
sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
std::array<std::size_t, Problem::NumDim + 1> extended_axes;
std::copy(begin(problem.axes), end(problem.axes), begin(extended_axes));
extended_axes.back() = Problem::NumDim;
std::array<std::size_t, Problem::NumDim + 1> transposed_extended_shape;
transpose_shape(extended_shape, extended_axes, begin(transposed_extended_shape));
Tensor<DataType> extended_host_b(transposed_extended_shape);
if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b))
{
return false;
}
std::vector<DataType> extended_b(reinterpret_cast<DataType*>(data(b.mData)),
reinterpret_cast<DataType*>(data(b.mData)) +
b.mDesc.GetElementSpaceSize() * NUM_ELEMS_IN_BUNDLE);
return ck::utils::check_err(extended_b,
extended_host_b.mData,
"Error: incorrect results in output tensor",
1e-10,
1e-10);
#endif
}
return true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment