Unverified Commit f33f2298 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Improve performance of pointwise/reduction kernels when using NHWC layouts (#1955)

* Improve performance of pointwise/reduction kernels when using NHWC layouts

* Format

* Add nhwc test

* Format

* Remove inline namespace

* Add reduce test
parent b164ceef
...@@ -66,6 +66,10 @@ MIGRAPHX_EXPORT std::vector<int64_t> invert_permutation(const std::vector<int64_ ...@@ -66,6 +66,10 @@ MIGRAPHX_EXPORT std::vector<int64_t> invert_permutation(const std::vector<int64_
MIGRAPHX_EXPORT std::vector<int64_t> find_permutation(const shape& s); MIGRAPHX_EXPORT std::vector<int64_t> find_permutation(const shape& s);
MIGRAPHX_EXPORT std::vector<int64_t> find_permutation(const std::vector<shape>& shapes); MIGRAPHX_EXPORT std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
/// Normalize the shapes so the order of dimensions will be in the order it is
/// in memory as much as possible.
MIGRAPHX_EXPORT std::vector<shape> normalize_permutation(const std::vector<shape>& shapes);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes) ...@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
return it->first; return it->first;
} }
std::vector<shape> normalize_permutation(const std::vector<shape>& shapes)
{
auto result = shapes;
auto perm = find_permutation(shapes);
std::transform(result.begin(), result.end(), result.begin(), [&](auto s) {
return reorder_shape(s, perm);
});
return result;
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -72,7 +72,7 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -72,7 +72,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
hip_compile_options options; hip_compile_options options;
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(normalize_permutation(inputs));
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs); auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(ctx, axis, options.virtual_inputs); auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
......
...@@ -84,7 +84,7 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes) ...@@ -84,7 +84,7 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
std::fill(lens.begin(), lens.end(), 1); std::fill(lens.begin(), lens.end(), 1);
for(const auto& axis : axes) for(const auto& axis : axes)
lens[axis] = s.lens()[axis]; lens[axis] = s.lens()[axis];
return shape{s.type(), lens}; return s.with_lens(lens);
} }
template <class T> template <class T>
...@@ -93,7 +93,7 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes) ...@@ -93,7 +93,7 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes)
auto lens = s.lens(); auto lens = s.lens();
for(const auto& axis : axes) for(const auto& axis : axes)
lens[axis] = 1; lens[axis] = 1;
return shape{s.type(), lens}; return s.with_lens(lens);
} }
template <class ReduceLens> template <class ReduceLens>
...@@ -228,7 +228,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -228,7 +228,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto virtual_inputs = inputs; auto virtual_inputs = inputs;
virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes)); virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
virtual_inputs.push_back(get_output_shape(inputs.front(), axes)); virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
virtual_inputs = reduce_dims(virtual_inputs); virtual_inputs = reduce_dims(normalize_permutation(virtual_inputs));
auto reduce_output_shape = virtual_inputs.back(); auto reduce_output_shape = virtual_inputs.back();
virtual_inputs.pop_back(); virtual_inputs.pop_back();
auto reduction_shape = virtual_inputs.back(); auto reduction_shape = virtual_inputs.back();
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_add_nhwc : verify_program<test_add_nhwc>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape::from_permutation(
migraphx::shape::float_type, {4, 3, 8, 8}, {0, 2, 3, 1});
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add});
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape::from_permutation(
migraphx::shape::float_type, {4, 256, 2, 2}, {0, 2, 3, 1});
auto x = mm->add_parameter("x", s);
auto reduce = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x);
auto abs = mm->add_instruction(migraphx::make_op("abs"), reduce);
auto sqrt = mm->add_instruction(migraphx::make_op("sqrt"), abs);
mm->add_return({sqrt});
return p;
};
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment