Commit 3e605990 authored by Po-Yen, Chen
Browse files

Use helper functions to simplify example code

parent a41132b5
......@@ -87,6 +87,29 @@ struct get_bundled<F32, 2>
/// Shorthand for the nested `type` member of `get_bundled<BundleT, N>`,
/// so call sites do not have to spell out `typename ...::type`.
template <typename BundleT, std::size_t N>
using get_bundled_t = typename get_bundled<BundleT, N>::type;
/// Maps an array type to the same array type with `Extra` additional elements.
/// Only the std::array specialization below is defined here; the primary
/// template is declared without a definition.
template <typename ArrayT, std::size_t Extra>
struct enlarge_array_size;

/// std::array<Elem, N>  ->  std::array<Elem, N + Extra>
template <typename Elem, std::size_t N, std::size_t Extra>
struct enlarge_array_size<std::array<Elem, N>, Extra>
{
    using type = std::array<Elem, N + Extra>;
};

/// Shorthand for `enlarge_array_size<ArrayT, Extra>::type`.
template <typename ArrayT, std::size_t Extra>
using enlarge_array_size_t = typename enlarge_array_size<ArrayT, Extra>::type;
/// Exposes the element count of an array type as a compile-time constant.
/// Only the std::array specialization below is defined here; the primary
/// template is declared without a definition.
template <typename ArrayT>
struct get_array_size;

/// For std::array<Elem, N>, `value` is N (inherited from std::integral_constant).
template <typename Elem, std::size_t N>
struct get_array_size<std::array<Elem, N>> : std::integral_constant<std::size_t, N>
{
};

/// Shorthand for `get_array_size<ArrayT>::value`.
template <typename ArrayT>
inline constexpr std::size_t get_array_size_v = get_array_size<ArrayT>::value;
template <typename T, typename = void>
struct is_iterator : std::false_type
{
......@@ -371,6 +394,30 @@ transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter)
return iter;
}
/// Returns a copy of `shape` with one extra trailing dimension appended.
///
/// @param shape    the original shape; its elements are copied, in order, to
///                 the front of the result.
/// @param new_dim  value stored in the last (newly added) element.
/// @return an array of type `detail::enlarge_array_size_t<Problem::Shape, 1>`,
///         i.e. one element longer than `Problem::Shape`.
///
/// Declared `inline` because this is a non-template function definition: if
/// the enclosing file is included from more than one translation unit of the
/// same program, a non-inline definition would violate the one-definition rule.
inline auto extend_shape(const Problem::Shape& shape, std::size_t new_dim)
{
    // Every element is written below (copy + back()), so no initializer is needed.
    detail::enlarge_array_size_t<Problem::Shape, 1> extended_shape;

    // Unqualified begin/end with std:: fallback.
    using std::begin, std::end;

    std::copy(begin(shape), end(shape), begin(extended_shape));
    extended_shape.back() = new_dim;

    return extended_shape;
}
/// Returns a copy of `axes` with one extra trailing entry appended.
///
/// @param axes  the original axes; its elements are copied, in order, to the
///              front of the result.
/// @return an array of type `detail::enlarge_array_size_t<Problem::Axes, 1>`
///         whose last element is the element count of `Problem::Axes`
///         (the index one past the original last position).
///
/// Declared `inline` because this is a non-template function definition: if
/// the enclosing file is included from more than one translation unit of the
/// same program, a non-inline definition would violate the one-definition rule.
inline auto extend_axes(const Problem::Axes& axes)
{
    // Every element is written below (copy + back()), so no initializer is needed.
    detail::enlarge_array_size_t<Problem::Axes, 1> extended_axes;

    // Unqualified begin/end with std:: fallback.
    using std::begin, std::end;

    std::copy(begin(axes), end(axes), begin(extended_axes));
    extended_axes.back() = detail::get_array_size_v<Problem::Axes>;

    return extended_axes;
}
template <typename Shape, typename Indices>
std::enable_if_t<detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
......
......@@ -78,24 +78,19 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
b.mData, host_b.mData, "Error: incorrect results in output tensor", 1e-10, 1e-10);
#else
// extend tensor shape from [N, H, W] to [N, H, W, NUM_ELEMS_IN_BUNDLE]
std::array<std::size_t, Problem::NumDim + 1> extended_shape;
std::copy(begin(shape), end(shape), begin(extended_shape));
extended_shape.back() = NUM_ELEMS_IN_BUNDLE;
using DataType = detail::get_bundled_t<ADataType, NUM_ELEMS_IN_BUNDLE>;
const auto extended_shape = extend_shape(shape, NUM_ELEMS_IN_BUNDLE);
const auto extended_axes = extend_axes(problem.axes);
ck::remove_cvref_t<decltype(extended_shape)> transposed_extended_shape;
transpose_shape(extended_shape, extended_axes, begin(transposed_extended_shape));
Tensor<DataType> extended_a(extended_shape);
std::memcpy(data(extended_a.mData),
data(a.mData),
sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
std::array<std::size_t, Problem::NumDim + 1> extended_axes;
std::copy(begin(problem.axes), end(problem.axes), begin(extended_axes));
extended_axes.back() = Problem::NumDim;
std::array<std::size_t, Problem::NumDim + 1> transposed_extended_shape;
transpose_shape(extended_shape, extended_axes, begin(transposed_extended_shape));
Tensor<DataType> extended_host_b(transposed_extended_shape);
if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b))
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment