Commit 097506c3 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use rangified copy() to copy elements

parent ff6a04fd
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#pragma once #pragma once
#include <algorithm>
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <cstdlib> #include <cstdlib>
...@@ -211,11 +212,18 @@ inline constexpr bool is_random_access_range_v = is_random_access_range<Range>:: ...@@ -211,11 +212,18 @@ inline constexpr bool is_random_access_range_v = is_random_access_range<Range>::
} // namespace detail } // namespace detail
template <typename Range> namespace ranges {
auto front(Range&& range) -> decltype(std::forward<Range>(range).front()) template <typename InputRange, typename OutputIterator>
inline auto copy(InputRange&& range, OutputIterator iter)
-> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
std::end(std::forward<InputRange>(range)),
iter))
{ {
return std::forward<Range>(range).front(); return std::copy(std::begin(std::forward<InputRange>(range)),
std::end(std::forward<InputRange>(range)),
iter);
} }
} // namespace ranges
template <typename Axes> template <typename Axes>
inline std::enable_if_t<detail::is_random_access_range_v<Axes>, bool> inline std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
......
...@@ -35,14 +35,10 @@ bool run_permute_bundle(const Problem& problem) ...@@ -35,14 +35,10 @@ bool run_permute_bundle(const Problem& problem)
const void* input_bundle_data = input_device_buf.GetDeviceBuffer(); const void* input_bundle_data = input_device_buf.GetDeviceBuffer();
void* output_bundle_data = output_device_buf.GetDeviceBuffer(); void* output_bundle_data = output_device_buf.GetDeviceBuffer();
std::copy(begin(input_bundle_shape), end(input_bundle_shape), begin(input_bundle_lengths)); ranges::copy(input_bundle_shape, begin(input_bundle_lengths));
std::copy(begin(input_bundle_tensor.GetStrides()), ranges::copy(input_bundle_tensor.GetStrides(), begin(input_bundle_strides));
end(input_bundle_tensor.GetStrides()), ranges::copy(output_bundle_shape, begin(output_bundle_lengths));
begin(input_bundle_strides)); ranges::copy(output_bundle_tensor.GetStrides(), begin(output_bundle_strides));
std::copy(begin(output_bundle_shape), end(output_bundle_shape), begin(output_bundle_lengths));
std::copy(begin(output_bundle_tensor.GetStrides()),
end(output_bundle_tensor.GetStrides()),
begin(output_bundle_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>); static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
......
...@@ -30,12 +30,10 @@ bool run_permute_element(const Problem& problem) ...@@ -30,12 +30,10 @@ bool run_permute_element(const Problem& problem)
const void* input_data = input_device_buf.GetDeviceBuffer(); const void* input_data = input_device_buf.GetDeviceBuffer();
void* output_data = output_device_buf.GetDeviceBuffer(); void* output_data = output_device_buf.GetDeviceBuffer();
std::copy(begin(input_shape), end(input_shape), begin(input_lengths)); ranges::copy(input_shape, begin(input_lengths));
std::copy( ranges::copy(input_tensor.GetStrides(), begin(input_strides));
begin(input_tensor.GetStrides()), end(input_tensor.GetStrides()), begin(input_strides)); ranges::copy(output_shape, begin(output_lengths));
std::copy(begin(output_shape), end(output_shape), begin(output_lengths)); ranges::copy(output_tensor.GetStrides(), begin(output_strides));
std::copy(
begin(output_tensor.GetStrides()), end(output_tensor.GetStrides()), begin(output_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>); static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment