Commit 98498486 authored by Po-Yen, Chen
Browse files

Add transpose_shape() to generalize shape permute

parent 185f7844
...@@ -50,9 +50,7 @@ template <typename T> ...@@ -50,9 +50,7 @@ template <typename T>
struct is_iterator<T, struct is_iterator<T,
std::void_t<decltype(*std::declval<T>()), std::void_t<decltype(*std::declval<T>()),
decltype(++std::declval<std::add_lvalue_reference_t<T>>()), decltype(++std::declval<std::add_lvalue_reference_t<T>>()),
decltype(--std::declval<std::add_lvalue_reference_t<T>>()), decltype(std::declval<std::add_lvalue_reference_t<T>>()++)>>
decltype(std::declval<std::add_lvalue_reference_t<T>>()++),
decltype(std::declval<std::add_lvalue_reference_t<T>>()--)>>
: std::true_type : std::true_type
{ {
}; };
...@@ -60,6 +58,28 @@ struct is_iterator<T, ...@@ -60,6 +58,28 @@ struct is_iterator<T,
template <typename T> template <typename T>
inline constexpr bool is_iterator_v = is_iterator<T>::value; inline constexpr bool is_iterator_v = is_iterator<T>::value;
/// Stand-in value for use inside unevaluated operands (e.g. `decltype`).
/// It advertises an implicit conversion to any type, so it can play the role
/// of "some assignable value" when probing whether an expression is well-formed.
struct Placeholder final
{
    /// Implicit conversion to an arbitrary target type.
    /// Only declared here; this block provides no definition, so the
    /// conversion must not be odr-used (evaluated at run time).
    template <class Target>
    constexpr operator Target() const noexcept;
};
/// Detects types that can be written through like an output iterator.
/// Primary template: selected when the assignment probe below is ill-formed,
/// in which case the trait is false.
template <class T, class Enable = void>
struct is_output_iterator : std::false_type
{
};

/// Chosen when `*it = <Placeholder>` is a well-formed expression for an rvalue
/// of type `Iterator`; the result is then whatever `is_iterator_v<Iterator>` says.
///
/// NOTE(review): `Placeholder` converts to every type, so when the dereferenced
/// type offers several assignment overloads (e.g. value assignment plus the
/// implicit copy/move assignment of std::back_insert_iterator) the probe may be
/// ambiguous and this specialization silently discarded -- TODO confirm with
/// the iterator types actually passed in.
template <class Iterator>
struct is_output_iterator<
    Iterator,
    std::void_t<decltype(*std::declval<Iterator>() = std::declval<Placeholder>())>>
    : std::integral_constant<bool, is_iterator_v<Iterator>>
{
};

/// Convenience variable template for `is_output_iterator<T>::value`.
template <class T>
inline constexpr bool is_output_iterator_v = is_output_iterator<T>::value;
template <typename Iterator, typename = void> template <typename Iterator, typename = void>
struct is_random_access_iterator : std::false_type struct is_random_access_iterator : std::false_type
{ {
...@@ -92,6 +112,20 @@ struct is_range<T, ...@@ -92,6 +112,20 @@ struct is_range<T,
template <typename T> template <typename T>
inline constexpr bool is_range_v = is_range<T>::value; inline constexpr bool is_range_v = is_range<T>::value;
/// Detects ranges whose element count can be queried with `size(range)`.
/// Primary template: selected when the `size(...)` probe below is ill-formed,
/// in which case the trait is false.
template <class Range, class Enable = void>
struct is_sized_range : std::false_type
{
};

/// Chosen when the unqualified call `size(r)` is well-formed for an rvalue of
/// type `Range`; the result is then whatever `is_range_v<Range>` says.
/// Because the call is unqualified, which `size` is found depends on
/// argument-dependent lookup and on the declarations visible at this point.
template <class Range>
struct is_sized_range<Range, std::void_t<decltype(size(std::declval<Range>()))>>
    : std::integral_constant<bool, is_range_v<Range>>
{
};

/// Convenience variable template for `is_sized_range<Range>::value`.
template <class Range>
inline constexpr bool is_sized_range_v = is_sized_range<Range>::value;
template <typename Range, typename = void> template <typename Range, typename = void>
struct is_random_access_range : std::false_type struct is_random_access_range : std::false_type
{ {
...@@ -111,7 +145,7 @@ inline constexpr bool is_random_access_range_v = is_random_access_range<Range>:: ...@@ -111,7 +145,7 @@ inline constexpr bool is_random_access_range_v = is_random_access_range<Range>::
} // namespace detail } // namespace detail
template <typename Axes> template <typename Axes>
inline std::enable_if_t<detail::is_random_access_range_v<ck::remove_cvref_t<Axes>>, bool> inline std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
is_valid_axes(const Axes& axes) is_valid_axes(const Axes& axes)
{ {
using std::empty; using std::empty;
...@@ -129,6 +163,24 @@ is_valid_axes(const Axes& axes) ...@@ -129,6 +163,24 @@ is_valid_axes(const Axes& axes)
return (last == end(copy)) && (*begin(copy) == 0) && (*std::prev(last) == size(axes) - 1); return (last == end(copy)) && (*begin(copy) == 0) && (*std::prev(last) == size(axes) - 1);
} }
template <typename Shape, typename Axes, typename OutputIterator>
inline std::enable_if_t<detail::is_random_access_range_v<Shape> &&
detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Axes> &&
detail::is_output_iterator_v<OutputIterator>,
OutputIterator>
transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter)
{
using std::size;
assert(size(shape) == size(axes) && is_valid_axes(axes));
for(const auto axis : axes)
{
*iter++ = shape[axis];
}
return iter;
}
inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem) inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
{ {
constexpr int num_execution_config_args = 2; constexpr int num_execution_config_args = 2;
......
...@@ -24,9 +24,10 @@ void host_elementwise4D(HostTensorB& B, ...@@ -24,9 +24,10 @@ void host_elementwise4D(HostTensorB& B,
bool run_elementwise_permute(const ExecutionConfig& config, const Problem& problem) bool run_elementwise_permute(const ExecutionConfig& config, const Problem& problem)
{ {
auto [N, C, H, W] = problem.shape; const auto& nchw = problem.shape;
std::vector<std::size_t> nchw = {N, C, H, W}; std::vector<std::size_t> nhwc;
std::vector<std::size_t> nhwc = {N, H, W, C}; transpose_shape(problem.shape, problem.axes, std::back_inserter(nhwc));
Tensor<ADataType> a(nchw); Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc); Tensor<BDataType> b(nhwc);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment