Commit 8b98d7d2 authored by Po-Yen, Chen

Add support to permute multiple elements together

parent 5f50ed89
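
The change lets the permute examples move a small, fixed group of adjacent elements (a "bundle") as one wider element, e.g. two fp16 values handled as a single 32-bit value. A minimal standalone sketch of that idea, with plain integer types standing in for the F16/F32 aliases used in the example code (everything in this sketch is illustrative, not code from the commit):

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>

int main()
{
    constexpr std::size_t H = 2, W = 3;

    // H x W x 2 source; the trailing dimension is the 2-element bundle.
    std::array<std::uint16_t, H * W * 2> src{};
    for(std::size_t i = 0; i < src.size(); ++i)
        src[i] = static_cast<std::uint16_t>(i);

    // Reinterpret the same bytes as H x W 32-bit bundles...
    std::array<std::uint32_t, H * W> bundles{};
    std::memcpy(bundles.data(), src.data(), sizeof(src));

    // ...and transpose H and W on the bundles, moving each pair as one unit.
    std::array<std::uint32_t, W * H> transposed{};
    for(std::size_t h = 0; h < H; ++h)
        for(std::size_t w = 0; w < W; ++w)
            transposed[w * H + h] = bundles[h * W + w];

    // The first transposed bundle still holds its original pair of values.
    std::uint16_t pair[2];
    std::memcpy(pair, &transposed[0], sizeof(pair));
    std::cout << pair[0] << ' ' << pair[1] << '\n'; // prints "0 1"
}
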
@@ -2,8 +2,8 @@ add_custom_target(example_permute)
 add_example_executable(example_permute_1xHxW_fp32 permute_1xHxW_fp32.cpp)
 add_example_executable(example_permute_NxHxW_fp32 permute_NxHxW_fp32.cpp)
-add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
+add_example_executable(example_permute_HxWx2_fp16 permute_HxWx2_fp16.cpp)

 add_dependencies(example_permute example_permute_1xHxW_fp32)
 add_dependencies(example_permute example_permute_NxHxW_fp32)
-add_dependencies(example_permute example_permute_HxWx4_fp16)
+add_dependencies(example_permute example_permute_HxWx2_fp16)
@@ -6,6 +6,7 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdlib>
+#include <cstring>
 #include <iostream>
 #include <iterator>
 #include <numeric>
@@ -33,7 +34,9 @@ struct ExecutionConfig final
 struct Problem final
 {
-    using Shape = std::array<std::size_t, 3>;
+    static constexpr std::size_t NumDim = 3;
+
+    using Shape = std::array<std::size_t, NumDim>;
     using Axes = Shape;

     Problem() = delete;
@@ -249,7 +252,7 @@ is_valid_axes(const Axes& axes)
 inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
 {
     constexpr int num_execution_config_args = 2;
-    constexpr int num_problem_args = 3 + 3;
+    constexpr int num_problem_args = 2 * Problem::NumDim;

     if(!(num_problem_args == size(problem.shape) + size(problem.axes)))
     {
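
With num_problem_args now derived from Problem::NumDim, the command line is expected to supply NumDim shape extents followed by NumDim axis indices after the execution-config arguments. The real parse_cmd_args is only partially visible in this hunk; a hypothetical sketch of that part of the parsing, with the argument order assumed, could look like:

#include <array>
#include <cstddef>
#include <cstdlib>

constexpr std::size_t NumDim = 3; // stands in for Problem::NumDim

// Read NumDim shape extents followed by NumDim axis indices, starting at
// argv[first]; mirrors num_problem_args == 2 * Problem::NumDim above.
bool parse_problem_args(char* argv[],
                        int first,
                        std::array<std::size_t, NumDim>& shape,
                        std::array<std::size_t, NumDim>& axes)
{
    for(std::size_t d = 0; d < NumDim; ++d)
        shape[d] = static_cast<std::size_t>(std::atoll(argv[first + d]));

    for(std::size_t d = 0; d < NumDim; ++d)
        axes[d] = static_cast<std::size_t>(std::atoll(argv[first + NumDim + d]));

    // each axis index must name a valid dimension
    for(std::size_t d = 0; d < NumDim; ++d)
        if(axes[d] >= NumDim)
            return false;

    return true;
}

int main(int argc, char* argv[])
{
    std::array<std::size_t, NumDim> shape{}, axes{};
    // e.g. "./example 1 3 4 0 2 1" with the problem arguments starting at argv[1]
    return (argc == static_cast<int>(1 + 2 * NumDim) && parse_problem_args(argv, 1, shape, axes))
               ? 0
               : 1;
}
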
@@ -412,7 +415,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
     }

     using std::size;
-    if(!(is_valid_axes(axes) && size(axes) == 3))
+    if(!is_valid_axes(axes))
     {
         return false;
     }
@@ -420,7 +423,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
     static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
                   detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);

-    if(!(size(shape) == 3 && size(transposed_shape) == 3))
+    if(size(shape) != size(transposed_shape))
     {
         return false;
     }
@@ -437,18 +440,34 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
         }
     }

-    std::array<std::size_t, 3> indices{};
+    std::vector<std::size_t> indices(size(shape), 0);
     if(!is_valid_indices(shape, indices))
     {
         return false;
     }

-    do
-    {
-        Dest output = 0;
-        functor(output, src(indices[0], indices[1], indices[2]));
-        dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output;
-    } while(advance_indices(shape, indices));
+    if(size(shape) == 3)
+    {
+        do
+        {
+            Dest output = 0;
+            functor(output, src(indices[0], indices[1], indices[2]));
+            dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output;
+        } while(advance_indices(shape, indices));
+    }
+    else if(size(shape) == 4)
+    {
+        do
+        {
+            Dest output = 0;
+            functor(output, src(indices[0], indices[1], indices[2], indices[3]));
+            dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output;
+        } while(advance_indices(shape, indices));
+    }
+    else
+    {
+        return false;
+    }

     return true;
 }
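
Both branches walk the full index space with advance_indices, whose definition is outside this hunk. Assuming it behaves like an odometer over the shape (an assumption, not code from the commit), a self-contained sketch would be:

#include <cstddef>
#include <vector>

// Assumed behaviour of advance_indices: increment the right-most index,
// carry into the dimension to its left, and return false once every index
// has wrapped back to zero.
bool advance_indices_sketch(const std::vector<std::size_t>& shape,
                            std::vector<std::size_t>& indices)
{
    for(std::size_t dim = shape.size(); dim-- > 0;)
    {
        if(++indices[dim] < shape[dim])
        {
            return true; // still inside this dimension
        }
        indices[dim] = 0; // wrap and carry into the next dimension
    }
    return false; // every index wrapped: the whole index space was visited
}

int main()
{
    const std::vector<std::size_t> shape{2, 3};
    std::vector<std::size_t> indices(shape.size(), 0);

    std::size_t visited = 0;
    do
    {
        ++visited;
    } while(advance_indices_sketch(shape, indices));

    return visited == 2 * 3 ? 0 : 1; // visits each of the 6 index tuples exactly once
}
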
@@ -3,8 +3,8 @@
 #include "common.hpp"

-using ADataType = F64;
-using BDataType = F64;
+using ADataType = F32;
+using BDataType = F32;

 // clang-format off
 using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
@@ -15,10 +15,7 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
     < ADataType, BDataType, PassThrough, 3, 256, 128, 128, 0, S<1, 16, 16>, S<0, 1, 2>, 2, 1, 1, 1>;
 // clang-format on

-#define NUM_ELEMS_IN_BUNDLE 4
+#define NUM_ELEMS_IN_BUNDLE 2

 #include "run_permute_example.inc"

-int main(int argc, char* argv[])
-{
-    return !run_permute_example(argc, argv, {1, 160, 80}, {0, 2, 1});
-}
+int main(int argc, char* argv[]) { return !run_permute_example(argc, argv, {1, 3, 4}, {0, 2, 1}); }
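
The new default problem, shape {1, 3, 4} with axes {0, 2, 1}, swaps the last two dimensions. A small host-side sketch of what that permutation computes, using plain loops instead of the example's Tensor class (illustrative only):

#include <array>
#include <cstddef>
#include <iostream>

int main()
{
    constexpr std::size_t N = 1, H = 3, W = 4;
    std::array<float, N * H * W> in{}, out{};

    for(std::size_t i = 0; i < in.size(); ++i)
        in[i] = static_cast<float>(i);

    // axes {0, 2, 1}: input element (n, h, w) lands at (n, w, h) in the
    // N x W x H output.
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t h = 0; h < H; ++h)
            for(std::size_t w = 0; w < W; ++w)
                out[(n * W + w) * H + h] = in[(n * H + h) * W + w];

    std::cout << out[1] << '\n'; // output (0, 0, 1) == input (0, 1, 0) == 4
}
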
@@ -9,6 +9,10 @@
 bool run_permute(const ExecutionConfig& config, const Problem& problem)
 {
+#if 1 < NUM_ELEMS_IN_BUNDLE
+    static_assert(std::is_same_v<ADataType, BDataType>);
+#endif
+
     using std::begin, std::end;

     const auto& shape = problem.shape;
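
The guard only fires when a translation unit opts into bundling via NUM_ELEMS_IN_BUNDLE. In isolation, and with float standing in for the example's type aliases, the pattern is (a sketch, not the example's code):

#include <type_traits>

#define NUM_ELEMS_IN_BUNDLE 2

using ADataType = float; // stand-ins for the example's F16/F32 aliases
using BDataType = float;

#if 1 < NUM_ELEMS_IN_BUNDLE
// The bundled verification reinterprets both raw buffers with one element
// type, which is only meaningful when A and B store the same type.
static_assert(std::is_same_v<ADataType, BDataType>);
#endif

int main() {}
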
@@ -61,13 +65,53 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
     if(config.do_verification)
     {
-        Tensor<BDataType> host_b(transposed_shape);
-        host_permute(a, problem.axes, PassThrough{}, host_b);
-
         b_device_buf.FromDevice(data(b.mData));

+#if NUM_ELEMS_IN_BUNDLE == 1
+        Tensor<BDataType> host_b(transposed_shape);
+        if(!host_permute(a, problem.axes, PassThrough{}, host_b))
+        {
+            return false;
+        }
+
         return ck::utils::check_err(
             b.mData, host_b.mData, "Error: incorrect results in output tensor", 1e-10, 1e-10);
+#else
+        // extend tensor shape from [N, H, W] to [N, H, W, NUM_ELEMS_IN_BUNDLE]
+        std::array<std::size_t, Problem::NumDim + 1> extended_shape;
+        std::copy(begin(shape), end(shape), begin(extended_shape));
+        extended_shape.back() = NUM_ELEMS_IN_BUNDLE;
+
+        using DataType = detail::get_bundled_t<ADataType, NUM_ELEMS_IN_BUNDLE>;
+
+        Tensor<DataType> extended_a(extended_shape);
+        std::memcpy(data(extended_a.mData),
+                    data(a.mData),
+                    sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
+
+        std::array<std::size_t, Problem::NumDim + 1> extended_axes;
+        std::copy(begin(problem.axes), end(problem.axes), begin(extended_axes));
+        extended_axes.back() = Problem::NumDim;
+
+        std::array<std::size_t, Problem::NumDim + 1> transposed_extended_shape;
+        transpose_shape(extended_shape, extended_axes, begin(transposed_extended_shape));
+
+        Tensor<DataType> extended_host_b(transposed_extended_shape);
+        if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b))
+        {
+            return false;
+        }
+
+        std::vector<DataType> extended_b(reinterpret_cast<DataType*>(data(b.mData)),
+                                         reinterpret_cast<DataType*>(data(b.mData)) +
+                                             b.mDesc.GetElementSpaceSize() * NUM_ELEMS_IN_BUNDLE);
+
+        return ck::utils::check_err(extended_b,
+                                    extended_host_b.mData,
+                                    "Error: incorrect results in output tensor",
+                                    1e-10,
+                                    1e-10);
+#endif
     }

     return true;
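
The bundled verification path reshapes the problem rather than the data: the [N, H, W] tensor of bundle-wide elements is viewed as [N, H, W, NUM_ELEMS_IN_BUNDLE] narrow elements (detail::get_bundled_t is not shown in this diff; presumably it yields that narrow per-element type, which is why the memcpy byte counts line up), and the permutation is extended so the trailing bundle dimension stays in place. A standalone sketch of just the shape/axes extension, reusing the values from the fp16 example:

#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>

int main()
{
    constexpr std::size_t NumDim   = 3;
    constexpr std::size_t NumElems = 2; // stands in for NUM_ELEMS_IN_BUNDLE

    const std::array<std::size_t, NumDim> shape{1, 3, 4};
    const std::array<std::size_t, NumDim> axes{0, 2, 1};

    // [N, H, W] -> [N, H, W, NumElems]: the bundle becomes an explicit
    // trailing dimension of narrow elements.
    std::array<std::size_t, NumDim + 1> extended_shape{};
    std::copy(shape.begin(), shape.end(), extended_shape.begin());
    extended_shape.back() = NumElems;

    // {0, 2, 1} -> {0, 2, 1, 3}: the bundle dimension maps onto itself, so
    // the permutation never splits a bundle.
    std::array<std::size_t, NumDim + 1> extended_axes{};
    std::copy(axes.begin(), axes.end(), extended_axes.begin());
    extended_axes.back() = NumDim;

    for(auto e : extended_shape) std::cout << e << ' '; // 1 3 4 2
    std::cout << '\n';
    for(auto a : extended_axes) std::cout << a << ' ';  // 0 2 1 3
    std::cout << '\n';
}
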