"examples/vscode:/vscode.git/clone" did not exist on "e607a582cfaa7dfaf7913fc3bb54c35eceee583c"
Commit fb05bd38 authored by Po-Yen, Chen
Browse files

Use function return value directly to eliminate variables

parent 097506c3
...@@ -289,23 +289,20 @@ is_valid_indices(const Shape& shape, const Indices& indices) ...@@ -289,23 +289,20 @@ is_valid_indices(const Shape& shape, const Indices& indices)
return true; return true;
} }
template <typename Shape, typename Axes, typename OutputIterator> template <std::size_t Size>
inline std::enable_if_t<detail::is_random_access_range_v<Shape> && std::array<std::size_t, Size> transpose(const std::array<std::size_t, Size>& shape,
detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Axes> && const std::array<std::size_t, Size>& axes)
detail::is_output_iterator_v<OutputIterator>,
OutputIterator>
transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter)
{ {
using std::size;
assert(size(shape) == size(axes));
assert(is_valid_shape(shape) && is_valid_axes(axes)); assert(is_valid_shape(shape) && is_valid_axes(axes));
std::array<std::size_t, Size> transposed;
auto iter = std::begin(transposed);
for(const auto axis : axes) for(const auto axis : axes)
{ {
*iter++ = shape[axis]; *iter++ = shape[axis];
} }
return iter; return transposed;
} }
auto extend_shape(const Problem::Shape& shape, std::size_t new_dim) auto extend_shape(const Problem::Shape& shape, std::size_t new_dim)
......
...@@ -12,8 +12,7 @@ bool run_permute_bundle(const Problem& problem) ...@@ -12,8 +12,7 @@ bool run_permute_bundle(const Problem& problem)
const auto& input_bundle_shape = problem.shape; const auto& input_bundle_shape = problem.shape;
const auto& input_bundle_axes = problem.axes; const auto& input_bundle_axes = problem.axes;
ck::remove_cvref_t<decltype(input_bundle_shape)> output_bundle_shape; const auto output_bundle_shape = transpose(input_bundle_shape, input_bundle_axes);
transpose_shape(input_bundle_shape, input_bundle_axes, begin(output_bundle_shape));
Tensor<BundleType> input_bundle_tensor(input_bundle_shape); Tensor<BundleType> input_bundle_tensor(input_bundle_shape);
Tensor<BundleType> output_bundle_tensor(output_bundle_shape); Tensor<BundleType> output_bundle_tensor(output_bundle_shape);
...@@ -66,18 +65,16 @@ bool run_permute_bundle(const Problem& problem) ...@@ -66,18 +65,16 @@ bool run_permute_bundle(const Problem& problem)
output_device_buf.FromDevice(data(output_bundle_tensor)); output_device_buf.FromDevice(data(output_bundle_tensor));
// extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle] // extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
// axes from [0, 2, 1] to [0, 2, 1, 3]
const auto input_shape = extend_shape(input_bundle_shape, NumElemsInBundle); const auto input_shape = extend_shape(input_bundle_shape, NumElemsInBundle);
const auto input_axes = extend_axes(input_bundle_axes); const auto input_axes = extend_axes(input_bundle_axes);
ck::remove_cvref_t<decltype(input_shape)> output_shape;
transpose_shape(input_shape, input_axes, begin(output_shape));
Tensor<DataType> input_tensor(input_shape); Tensor<DataType> input_tensor(input_shape);
std::memcpy(data(input_tensor), std::memcpy(data(input_tensor),
data(input_bundle_tensor), data(input_bundle_tensor),
input_bundle_tensor.GetElementSpaceSizeInBytes()); input_bundle_tensor.GetElementSpaceSizeInBytes());
Tensor<DataType> output_tensor(output_shape); Tensor<DataType> output_tensor(transpose(input_shape, input_axes));
if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor)) if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor))
{ {
return false; return false;
......
...@@ -10,8 +10,7 @@ bool run_permute_element(const Problem& problem) ...@@ -10,8 +10,7 @@ bool run_permute_element(const Problem& problem)
const auto& input_shape = problem.shape; const auto& input_shape = problem.shape;
const auto& input_axes = problem.axes; const auto& input_axes = problem.axes;
ck::remove_cvref_t<decltype(input_shape)> output_shape; const auto output_shape = transpose(input_shape, input_axes);
transpose_shape(input_shape, input_axes, begin(output_shape));
Tensor<InDataType> input_tensor(input_shape); Tensor<InDataType> input_tensor(input_shape);
Tensor<OutDataType> output_tensor(output_shape); Tensor<OutDataType> output_tensor(output_shape);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment