Unverified Commit 08360e83 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix compile failure in reduction fusion of instance norm (#1702)

This fixes #1700
parent 4339af75
...@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module ...@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return *this; return *this;
} }
cpp_generator::function& cpp_generator::function::unused_param(const std::string& pname)
{
body.insert(0, "(void)" + pname + ";\n");
return *this;
}
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname) cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
{ {
params.push_back({pname, "T" + pname}); params.push_back({pname, "T" + pname});
......
...@@ -78,6 +78,7 @@ struct cpp_generator ...@@ -78,6 +78,7 @@ struct cpp_generator
function& set_types(const module& m, const std::function<std::string(shape)>& parse); function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m); function& set_generic_types(const module& m);
function& add_generic_param(const std::string& pname); function& add_generic_param(const std::string& pname);
function& unused_param(const std::string& pname);
}; };
cpp_generator(); cpp_generator();
......
...@@ -280,6 +280,14 @@ std::string generate_reduce(const module& m, const std::string& name) ...@@ -280,6 +280,14 @@ std::string generate_reduce(const module& m, const std::string& name)
not input->get_shape().broadcasted(); not input->get_shape().broadcasted();
}); });
auto inner_names = names; auto inner_names = names;
for(auto input : ins->inputs())
{
if(input->name() != "@param")
continue;
if(contains(tensors, input))
continue;
inner_names[input] += "[out_idx]";
}
for(auto input : tensors) for(auto input : tensors)
inner_names[input] += "_lambda_param"; inner_names[input] += "_lambda_param";
auto call_function = auto call_function =
...@@ -308,6 +316,8 @@ std::string generate_reduce(const module& m, const std::string& name) ...@@ -308,6 +316,8 @@ std::string generate_reduce(const module& m, const std::string& name)
}); });
f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name); f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
f.add_generic_param("r"); f.add_generic_param("r");
f.add_generic_param("out_idx");
f.unused_param("out_idx");
g.create_function(f); g.create_function(f);
return g.str(); return g.str();
} }
......
...@@ -570,7 +570,7 @@ template <class Algo, class Reduced, class Output, class F> ...@@ -570,7 +570,7 @@ template <class Algo, class Reduced, class Output, class F>
__device__ void fused_reduce(Output output, F f) __device__ void fused_reduce(Output output, F f)
{ {
Algo::template run<Reduced>([&](auto out_idx, auto r) { Algo::template run<Reduced>([&](auto out_idx, auto r) {
auto result = f(r); auto result = f(r, out_idx);
if constexpr(reduce::is_inner_storage<decltype(result)>{}) if constexpr(reduce::is_inner_storage<decltype(result)>{})
{ {
r.inner([&](auto& y, auto x) { y = x; })(output, result); r.inner([&](auto& y, auto x) { y = x; })(output, result);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct test_reduce_mean_bias_half : verify_program<test_reduce_mean_bias_half>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {1, 32, 128}};
migraphx::shape bs{migraphx::shape::half_type, {1, 32, 128}};
auto x = mm->add_parameter("x", s);
auto reduce = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x);
auto bias = mm->add_parameter("bias", reduce->get_shape());
auto add = mm->add_instruction(migraphx::make_op("add"), reduce, bias);
auto abs = mm->add_instruction(migraphx::make_op("abs"), add);
auto sqrt = mm->add_instruction(migraphx::make_op("sqrt"), abs);
mm->add_return({sqrt});
return p;
};
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment