// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #pragma once bool run_permute_bundle(const Problem& problem) { constexpr std::size_t NumElemsInBundle = sizeof(BundleType) / sizeof(DataType); using std::begin, std::end; const auto& input_bundle_shape = problem.shape; const auto& input_bundle_axes = problem.axes; ck::remove_cvref_t output_bundle_shape; transpose_shape(input_bundle_shape, input_bundle_axes, begin(output_bundle_shape)); Tensor input_bundle_tensor(input_bundle_shape); Tensor output_bundle_tensor(output_bundle_shape); // initialize tensor by assigning DataType values using std::data, std::size; ck::utils::FillUniformDistribution{-1.f, 1.f}( ck::span{reinterpret_cast(data(input_bundle_tensor)), input_bundle_tensor.GetElementSpaceSize() * NumElemsInBundle}); DeviceMem input_device_buf(input_bundle_tensor.GetElementSpaceSizeInBytes()); DeviceMem output_device_buf(output_bundle_tensor.GetElementSpaceSizeInBytes()); input_device_buf.ToDevice(data(input_bundle_tensor)); std::array input_bundle_lengths, output_bundle_lengths; std::array input_bundle_strides, output_bundle_strides; const void* input_bundle_data = input_device_buf.GetDeviceBuffer(); void* output_bundle_data = output_device_buf.GetDeviceBuffer(); std::copy(begin(input_bundle_shape), end(input_bundle_shape), begin(input_bundle_lengths)); std::copy(begin(input_bundle_tensor.GetStrides()), end(input_bundle_tensor.GetStrides()), begin(input_bundle_strides)); std::copy(begin(output_bundle_shape), end(output_bundle_shape), begin(output_bundle_lengths)); std::copy(begin(output_bundle_tensor.GetStrides()), end(output_bundle_tensor.GetStrides()), begin(output_bundle_strides)); static_assert(std::is_default_constructible_v); auto permute = DevicePermuteInstance{}; auto argument = permute.MakeArgument(input_bundle_lengths, input_bundle_strides, output_bundle_lengths, output_bundle_strides, input_bundle_data, output_bundle_data, PassThrough{}); if(!permute.IsSupportedArgument(argument)) { std::cerr << "The runtime parameters seems not supported by the device instance, exiting!" << std::endl; return false; }; auto invoker = permute.MakeInvoker(); float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); std::cout << "Perf: " << ave_time << " ms" << std::endl; output_device_buf.FromDevice(data(output_bundle_tensor)); // extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle] const auto input_shape = extend_shape(input_bundle_shape, NumElemsInBundle); const auto input_axes = extend_axes(input_bundle_axes); ck::remove_cvref_t output_shape; transpose_shape(input_shape, input_axes, begin(output_shape)); Tensor input_tensor(input_shape); std::memcpy(data(input_tensor), data(input_bundle_tensor), input_bundle_tensor.GetElementSpaceSizeInBytes()); Tensor output_tensor(output_shape); if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor)) { return false; } return ck::utils::check_err( ck::span{reinterpret_cast(data(output_bundle_tensor)), output_bundle_tensor.GetElementSpaceSize() * NumElemsInBundle}, ck::span{output_tensor}, "Error: incorrect results in output tensor", 1e-6, 1e-6); } bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes) { return run_permute_bundle(Problem{shape, axes}); }