Commit db32635c authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Generalize transpose utility functions

parent 98498486
......@@ -3,6 +3,7 @@
#pragma once
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
......@@ -64,7 +65,7 @@ struct Placeholder final
constexpr inline operator T() const noexcept;
};
template <typename T, typename = void>
template <typename Iterator, typename = void>
struct is_output_iterator : std::false_type
{
};
......@@ -80,6 +81,23 @@ struct is_output_iterator<
template <typename T>
inline constexpr bool is_output_iterator_v = is_output_iterator<T>::value;
template <typename Iterator, typename = void>
struct is_bidirectional_iterator : std::false_type
{
};
template <typename Iterator>
struct is_bidirectional_iterator<
Iterator,
std::void_t<decltype(--std::declval<std::add_lvalue_reference_t<Iterator>>()),
decltype(std::declval<std::add_lvalue_reference_t<Iterator>>()--)>>
: std::bool_constant<is_iterator_v<Iterator>>
{
};
template <typename Iterator>
inline constexpr bool is_bidirectional_iterator_v = is_bidirectional_iterator<Iterator>::value;
template <typename Iterator, typename = void>
struct is_random_access_iterator : std::false_type
{
......@@ -126,6 +144,22 @@ struct is_sized_range<Range, std::void_t<decltype(size(std::declval<Range>()))>>
template <typename Range>
inline constexpr bool is_sized_range_v = is_sized_range<Range>::value;
template <typename Range, typename = void>
struct is_bidirectional_range : std::false_type
{
};
template <typename Range>
struct is_bidirectional_range<Range, std::void_t<>>
: std::bool_constant<
is_range_v<Range> &&
is_bidirectional_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
{
};
template <typename Range>
inline constexpr bool is_bidirectional_range_v = is_bidirectional_range<Range>::value;
template <typename Range, typename = void>
struct is_random_access_range : std::false_type
{
......@@ -155,30 +189,13 @@ is_valid_axes(const Axes& axes)
}
using std::begin, std::end;
std::vector<std::size_t> copy(begin(axes), end(axes));
std::vector<std::size_t> sorted_axes(begin(axes), end(axes));
std::sort(begin(copy), end(copy));
const auto last = std::unique(begin(copy), end(copy));
std::sort(begin(sorted_axes), end(sorted_axes));
const auto last = std::unique(begin(sorted_axes), end(sorted_axes));
return (last == end(copy)) && (*begin(copy) == 0) && (*std::prev(last) == size(axes) - 1);
}
template <typename Shape, typename Axes, typename OutputIterator>
inline std::enable_if_t<detail::is_random_access_range_v<Shape> &&
detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Axes> &&
detail::is_output_iterator_v<OutputIterator>,
OutputIterator>
transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter)
{
using std::size;
assert(size(shape) == size(axes) && is_valid_axes(axes));
for(const auto axis : axes)
{
*iter++ = shape[axis];
}
return iter;
return (last == end(sorted_axes)) && (*begin(sorted_axes) == 0) &&
(*std::prev(last) == size(axes) - 1);
}
inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
......@@ -235,3 +252,114 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
return true;
}
template <typename Shape>
inline std::enable_if_t<detail::is_range_v<Shape>, bool> is_valid_shape(const Shape& shape)
{
using std::begin, std::end;
using std::empty;
return !empty(shape) && std::all_of(begin(shape), end(shape), [](auto dim) { return 0 < dim; });
}
template <typename Shape, typename Indices>
inline std::enable_if_t<detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Indices>, bool>
is_valid_indices(const Shape& shape, const Indices& indices)
{
assert(is_valid_shape(shape));
using std::empty;
if(empty(indices))
{
return false;
}
using std::size;
if(size(shape) != size(indices))
{
return false;
}
using std::begin, std::end;
auto dim = begin(shape);
auto idx = begin(indices);
for(; dim != end(shape) && idx != end(indices); ++dim, ++idx)
{
if(*dim <= *idx)
{
return false;
}
}
return true;
}
template <typename Shape, typename Axes, typename OutputIterator>
inline std::enable_if_t<detail::is_random_access_range_v<Shape> &&
detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Axes> &&
detail::is_output_iterator_v<OutputIterator>,
OutputIterator>
transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter)
{
using std::size;
assert(size(shape) == size(axes) &&);
assert(is_valid_shape(shape) && is_valid_axes(axes));
for(const auto axis : axes)
{
*iter++ = shape[axis];
}
return iter;
}
template <typename Shape, typename Indices>
std::enable_if_t<detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
bool>
advance_indices(const Shape& shape, Indices& indices)
{
assert(is_valid_shape(shape));
assert(is_valid_indices(indices));
assert(size(shape) == size(indices));
bool carry = true;
using std::rbegin, std::rend;
auto dim = rbegin(shape);
auto idx = rbegin(indices);
for(; carry && dim != rend(shape) && idx != rend(indices); ++dim, ++idx)
{
assert(*idx < *dim);
*idx = (*idx + carry);
carry = ((*idx == *dim) ? (*idx = 0, true) : false);
}
return !carry;
}
template <typename Src, typename Functor, typename Dest>
std::enable_if_t<std::is_invocable_v<Functor,
std::add_lvalue_reference_t<Dest>,
std::add_lvalue_reference_t<Src>>>
host_elementwise_permute(const Tensor<Src>& src, Functor functor, Tensor<Dest>& dest)
{
const auto& shape = src.mDesc.GetLengths();
const auto& transposed_shape = dest.mDesc.GetLengths();
assert(is_valid_shape(shape) && is_valid_shape(transposed_shape));
static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
using std::size;
assert(size(shape) == 4 && size(transposed_shape) == 4);
std::array<std::size_t, 4> dims{};
do
{
Dest b_val = 0;
functor(b_val, src(dims[0], dims[1], dims[2], dims[3]));
dest(dims[0], dims[2], dims[3], dims[1]) = b_val;
} while(advance_indices(shape, dims));
}
......@@ -3,25 +3,6 @@
#pragma once
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B,
const HostTensorA& A,
const std::vector<std::size_t>& shape,
Functor functor)
{
using btype = ck::remove_reference_t<decltype(B(0, 0, 0, 0))>;
for(std::size_t n = 0; n < shape[0]; ++n)
for(std::size_t c = 0; c < shape[1]; ++c)
for(std::size_t h = 0; h < shape[2]; ++h)
for(std::size_t w = 0; w < shape[3]; ++w)
{
auto a_val = A(n, c, h, w);
btype b_val = 0;
functor(b_val, a_val);
B(n, h, w, c) = b_val;
}
}
bool run_elementwise_permute(const ExecutionConfig& config, const Problem& problem)
{
const auto& nchw = problem.shape;
......@@ -67,10 +48,10 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl
if(config.do_verification)
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
host_b, a, nhwc, PassThrough{});
host_elementwise_permute(a, PassThrough{}, host_b);
b_device_buf.FromDevice(b.mData.data());
return ck::utils::check_err(
b.mData, host_b.mData, "Error: incorrect results in tensor B", 1e-10, 1e-10);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment