Commit fb05bd38 authored by Po-Yen, Chen
Browse files

Use function return value directly to eliminate variables

parent 097506c3
......@@ -289,23 +289,20 @@ is_valid_indices(const Shape& shape, const Indices& indices)
return true;
}
/// Writes the permuted extents of `shape` through an output iterator.
///
/// For every `axis` in `axes` (in iteration order) the value `shape[axis]`
/// is assigned through `iter`, which is then advanced.
///
/// @param shape  Random-access, sized range of extents.
/// @param axes   Sized range of indices into `shape`.
/// @param iter   Output iterator that receives one value per entry of `axes`.
/// @return       The iterator one past the last element written.
///
/// Preconditions (checked with `assert` only): `shape` and `axes` have the
/// same size, and `is_valid_shape(shape)` / `is_valid_axes(axes)` hold.
template <typename Shape, typename Axes, typename OutputIterator>
inline std::enable_if_t<detail::is_random_access_range_v<Shape> &&
                            detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Axes> &&
                            detail::is_output_iterator_v<OutputIterator>,
                        OutputIterator>
transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter)
{
    using std::size;

    assert(size(shape) == size(axes));
    assert(is_valid_shape(shape) && is_valid_axes(axes));

    for(const auto axis : axes)
    {
        *iter++ = shape[axis];
    }

    return iter;
}

/// Returns the permuted extents of `shape` by value.
///
/// Element `i` of the result is `shape[axes[i]]`. Both arguments share the
/// same compile-time `Size`, so no run-time size comparison is needed.
///
/// @param shape  Extents to permute.
/// @param axes   Indices into `shape`, one per output position.
/// @return       A new array holding the permuted extents.
///
/// Preconditions (checked with `assert` only): `is_valid_shape(shape)` and
/// `is_valid_axes(axes)` hold.
template <std::size_t Size>
std::array<std::size_t, Size> transpose(const std::array<std::size_t, Size>& shape,
                                        const std::array<std::size_t, Size>& axes)
{
    assert(is_valid_shape(shape) && is_valid_axes(axes));

    // Left uninitialized on purpose: the loop below writes all `Size` slots,
    // because `axes` has exactly `Size` entries.
    std::array<std::size_t, Size> transposed;

    auto iter = std::begin(transposed);
    for(const auto axis : axes)
    {
        *iter++ = shape[axis];
    }

    return transposed;
}
auto extend_shape(const Problem::Shape& shape, std::size_t new_dim)
......
......@@ -12,8 +12,7 @@ bool run_permute_bundle(const Problem& problem)
const auto& input_bundle_shape = problem.shape;
const auto& input_bundle_axes = problem.axes;
ck::remove_cvref_t<decltype(input_bundle_shape)> output_bundle_shape;
transpose_shape(input_bundle_shape, input_bundle_axes, begin(output_bundle_shape));
const auto output_bundle_shape = transpose(input_bundle_shape, input_bundle_axes);
Tensor<BundleType> input_bundle_tensor(input_bundle_shape);
Tensor<BundleType> output_bundle_tensor(output_bundle_shape);
......@@ -66,18 +65,16 @@ bool run_permute_bundle(const Problem& problem)
output_device_buf.FromDevice(data(output_bundle_tensor));
// extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
// axes from [0, 2, 1] to [0, 2, 1, 3]
const auto input_shape = extend_shape(input_bundle_shape, NumElemsInBundle);
const auto input_axes = extend_axes(input_bundle_axes);
ck::remove_cvref_t<decltype(input_shape)> output_shape;
transpose_shape(input_shape, input_axes, begin(output_shape));
Tensor<DataType> input_tensor(input_shape);
std::memcpy(data(input_tensor),
data(input_bundle_tensor),
input_bundle_tensor.GetElementSpaceSizeInBytes());
Tensor<DataType> output_tensor(output_shape);
Tensor<DataType> output_tensor(transpose(input_shape, input_axes));
if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor))
{
return false;
......
......@@ -10,8 +10,7 @@ bool run_permute_element(const Problem& problem)
const auto& input_shape = problem.shape;
const auto& input_axes = problem.axes;
ck::remove_cvref_t<decltype(input_shape)> output_shape;
transpose_shape(input_shape, input_axes, begin(output_shape));
const auto output_shape = transpose(input_shape, input_axes);
Tensor<InDataType> input_tensor(input_shape);
Tensor<OutDataType> output_tensor(output_shape);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment