Commit b85e8611 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add to_array() conversion tool to eliminate more variables

parent fb05bd38
......@@ -210,8 +210,36 @@ struct is_random_access_range<Range, std::void_t<>>
template <typename Range>
inline constexpr bool is_random_access_range_v = is_random_access_range<Range>::value;
template <typename Range>
class to_array_proxy
{
public:
explicit to_array_proxy(const Range& source) noexcept : source_(source) {}
template <typename T, std::size_t Size>
operator std::array<T, Size>() const
{
std::array<T, Size> destination;
std::copy_n(std::begin(source_),
std::min<std::size_t>(Size, std::size(source_)),
std::begin(destination));
return destination;
}
private:
const Range& source_;
};
} // namespace detail
template <typename Range>
inline auto to_array(Range& range) noexcept
{
return detail::to_array_proxy<ck::remove_cvref_t<Range>>{range};
}
namespace ranges {
template <typename InputRange, typename OutputIterator>
inline auto copy(InputRange&& range, OutputIterator iter)
......
......@@ -28,24 +28,16 @@ bool run_permute_bundle(const Problem& problem)
input_device_buf.ToDevice(data(input_bundle_tensor));
std::array<ck::index_t, Problem::NumDim> input_bundle_lengths, output_bundle_lengths;
std::array<ck::index_t, Problem::NumDim> input_bundle_strides, output_bundle_strides;
const void* input_bundle_data = input_device_buf.GetDeviceBuffer();
void* output_bundle_data = output_device_buf.GetDeviceBuffer();
ranges::copy(input_bundle_shape, begin(input_bundle_lengths));
ranges::copy(input_bundle_tensor.GetStrides(), begin(input_bundle_strides));
ranges::copy(output_bundle_shape, begin(output_bundle_lengths));
ranges::copy(output_bundle_tensor.GetStrides(), begin(output_bundle_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
auto permute = DevicePermuteInstance{};
auto argument = permute.MakeArgument(input_bundle_lengths,
input_bundle_strides,
output_bundle_lengths,
output_bundle_strides,
auto argument = permute.MakeArgument(to_array(input_bundle_shape),
to_array(input_bundle_tensor.GetStrides()),
to_array(output_bundle_shape),
to_array(output_bundle_tensor.GetStrides()),
input_bundle_data,
output_bundle_data,
PassThrough{});
......
......@@ -23,24 +23,16 @@ bool run_permute_element(const Problem& problem)
using std::data;
input_device_buf.ToDevice(data(input_tensor));
std::array<ck::index_t, Problem::NumDim> input_lengths, output_lengths;
std::array<ck::index_t, Problem::NumDim> input_strides, output_strides;
const void* input_data = input_device_buf.GetDeviceBuffer();
void* output_data = output_device_buf.GetDeviceBuffer();
ranges::copy(input_shape, begin(input_lengths));
ranges::copy(input_tensor.GetStrides(), begin(input_strides));
ranges::copy(output_shape, begin(output_lengths));
ranges::copy(output_tensor.GetStrides(), begin(output_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
auto permute = DevicePermuteInstance{};
auto argument = permute.MakeArgument(input_lengths,
input_strides,
output_lengths,
output_strides,
auto argument = permute.MakeArgument(to_array(input_shape),
to_array(input_tensor.GetStrides()),
to_array(output_shape),
to_array(output_tensor.GetStrides()),
input_data,
output_data,
PassThrough{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment