Commit 097506c3 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use rangified copy() to copy elements

parent ff6a04fd
......@@ -3,6 +3,7 @@
#pragma once
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
......@@ -211,11 +212,18 @@ inline constexpr bool is_random_access_range_v = is_random_access_range<Range>::
} // namespace detail
template <typename Range>
auto front(Range&& range) -> decltype(std::forward<Range>(range).front())
namespace ranges {
template <typename InputRange, typename OutputIterator>
inline auto copy(InputRange&& range, OutputIterator iter)
-> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
std::end(std::forward<InputRange>(range)),
iter))
{
return std::forward<Range>(range).front();
return std::copy(std::begin(std::forward<InputRange>(range)),
std::end(std::forward<InputRange>(range)),
iter);
}
} // namespace ranges
template <typename Axes>
inline std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
......
......@@ -35,14 +35,10 @@ bool run_permute_bundle(const Problem& problem)
const void* input_bundle_data = input_device_buf.GetDeviceBuffer();
void* output_bundle_data = output_device_buf.GetDeviceBuffer();
std::copy(begin(input_bundle_shape), end(input_bundle_shape), begin(input_bundle_lengths));
std::copy(begin(input_bundle_tensor.GetStrides()),
end(input_bundle_tensor.GetStrides()),
begin(input_bundle_strides));
std::copy(begin(output_bundle_shape), end(output_bundle_shape), begin(output_bundle_lengths));
std::copy(begin(output_bundle_tensor.GetStrides()),
end(output_bundle_tensor.GetStrides()),
begin(output_bundle_strides));
ranges::copy(input_bundle_shape, begin(input_bundle_lengths));
ranges::copy(input_bundle_tensor.GetStrides(), begin(input_bundle_strides));
ranges::copy(output_bundle_shape, begin(output_bundle_lengths));
ranges::copy(output_bundle_tensor.GetStrides(), begin(output_bundle_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
......
......@@ -30,12 +30,10 @@ bool run_permute_element(const Problem& problem)
const void* input_data = input_device_buf.GetDeviceBuffer();
void* output_data = output_device_buf.GetDeviceBuffer();
std::copy(begin(input_shape), end(input_shape), begin(input_lengths));
std::copy(
begin(input_tensor.GetStrides()), end(input_tensor.GetStrides()), begin(input_strides));
std::copy(begin(output_shape), end(output_shape), begin(output_lengths));
std::copy(
begin(output_tensor.GetStrides()), end(output_tensor.GetStrides()), begin(output_strides));
ranges::copy(input_shape, begin(input_lengths));
ranges::copy(input_tensor.GetStrides(), begin(input_strides));
ranges::copy(output_shape, begin(output_lengths));
ranges::copy(output_tensor.GetStrides(), begin(output_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment