Unverified Commit 1820198e authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add nhwc layout to gpu backend (#1391)

Can be enabled via environment variable MIGRAPHX_ENABLE_NHWC
parent 2f48b11a
......@@ -35,6 +35,7 @@
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/preallocate_param.hpp>
......@@ -50,6 +51,7 @@
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
......@@ -70,6 +72,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
struct id_pass
{
......@@ -120,6 +123,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
simplify_algebra{},
simplify_reshapes{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
dead_code_elimination{},
simplify_reshapes{},
simplify_algebra{},
prefuse_ops{},
dead_code_elimination{},
......@@ -136,8 +142,12 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{},
compile_miopen{&gctx},
dead_code_elimination{},
pack_int8_args{},
dead_code_elimination{},
adjust_allocation{gpu_allocation_model{}},
dead_code_elimination{},
fuse_ops{&ctx, options.fast_math},
dead_code_elimination{},
replace_allocate{gpu_allocation_model{}, options.offload_copy},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
{
migraphx::run_passes(m, {migraphx::layout_nhwc{}, migraphx::dead_code_elimination{}});
}
migraphx::operation layout(std::vector<int64_t> permutation = {0, 1, 2, 3})
{
return migraphx::make_op("layout", {{"permutation", permutation}});
}
migraphx::instruction_ref add_layout_nhwc(migraphx::module& m, migraphx::instruction_ref ins)
{
return m.add_instruction(layout({0, 2, 3, 1}), ins);
}
TEST_CASE(conv_relu)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
auto w = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}}));
auto conv = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
m1.add_instruction(migraphx::make_op("relu"), conv);
}
run_pass(m1);
migraphx::module m2;
{
auto x = add_layout_nhwc(
m2, m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}));
auto w = add_layout_nhwc(m2,
m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {16, 8, 3, 3}})));
auto conv = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto conv_layout = m2.add_instruction(layout(), conv);
m2.add_instruction(migraphx::make_op("relu"), conv_layout);
}
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(conv_add)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
auto w = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}}));
auto y = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}));
auto conv = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
y);
m1.add_instruction(migraphx::make_op("add"), conv, b);
}
run_pass(m1);
migraphx::module m2;
{
auto x = add_layout_nhwc(
m2, m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}));
auto w = add_layout_nhwc(m2,
m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {16, 8, 3, 3}})));
auto y = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}));
auto conv = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto conv_layout = m2.add_instruction(layout(), conv);
auto b = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
y);
m2.add_instruction(migraphx::make_op("add"), conv_layout, b);
}
EXPECT(m1.sort() == m2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment