Commit 4e55c401 authored by Paul
Browse files

Some fixes

parent 40762b08
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <iterator> #include <iterator>
namespace migraphx { namespace migraphx {
...@@ -49,24 +50,27 @@ struct fused_reduce ...@@ -49,24 +50,27 @@ struct fused_reduce
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{ {
if(mods.size() != 1) if(mods.size() != 1)
{
MIGRAPHX_THROW("should have one submodule."); MIGRAPHX_THROW("should have one submodule.");
}
auto* sm = mods.front(); auto* sm = mods.front();
if(sm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("Only one output supported");
check_shapes{inputs, *this}.has(sm->get_parameter_shapes().size()).same_dims(); check_shapes{inputs, *this}.has(sm->get_parameter_shapes().size()).same_dims();
auto s = inputs.at(0); auto s = inputs.at(0);
auto lens = s.lens(); auto lens = s.lens();
if (lens != sm->get_output_shapes().front().lens())
{
for(const auto& axis : axes) for(const auto& axis : axes)
{ {
lens[axis] = 1; lens[axis] = 1;
} }
if(sm->get_output_shapes().size() != 1) }
MIGRAPHX_THROW("Only one output supported");
return inputs[0].with_lens(sm->get_output_shapes().front().type(), lens); return shape::from_permutation(sm->get_output_shapes().front().type(), lens, find_permutation(inputs));
} }
std::string name() const { return "fused_reduce"; } std::string name() const { return "fused_reduce"; }
}; };
MIGRAPHX_REGISTER_OP(fused_reduce);
static void create_reduce_modules(module_pass_manager& mpm) static void create_reduce_modules(module_pass_manager& mpm)
{ {
...@@ -87,8 +91,8 @@ static void create_reduce_modules(module_pass_manager& mpm) ...@@ -87,8 +91,8 @@ static void create_reduce_modules(module_pass_manager& mpm)
auto r = rm->add_instruction(ins->get_operator(), x0); auto r = rm->add_instruction(ins->get_operator(), x0);
rm->add_return({r}); rm->add_return({r});
// TODO: Set axes auto v = ins->get_operator().to_value();
mpm.get_module().replace_instruction(ins, make_op("fused_reduce"), ins->inputs(), {rm}); mpm.get_module().replace_instruction(ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
} }
} }
...@@ -130,10 +134,11 @@ struct find_reduce_pointwise ...@@ -130,10 +134,11 @@ struct find_reduce_pointwise
auto ins = r.result; auto ins = r.result;
auto reduce = r.instructions["reduce"]; auto reduce = r.instructions["reduce"];
auto* old_rm = reduce->module_inputs().front(); const auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(old_rm->name() + ":pointwise"); auto* rm = mpm.create_module(old_rm->name() + ":pointwise");
// Copy module rm->set_bypass();
*rm = *old_rm; // Copy module instructions
rm->add_instructions(old_rm);
auto map_ins = get_ins_param_map(reduce->inputs(), rm); auto map_ins = get_ins_param_map(reduce->inputs(), rm);
auto new_inputs = reduce->inputs(); auto new_inputs = reduce->inputs();
for(auto input : ins->inputs()) for(auto input : ins->inputs())
...@@ -152,8 +157,8 @@ struct find_reduce_pointwise ...@@ -152,8 +157,8 @@ struct find_reduce_pointwise
} }
} }
auto out = rm->insert_instructions(std::prev(rm->end()), {ins}, map_ins); auto out = rm->add_instructions({ins}, map_ins);
rm->replace_return(out); rm->add_return(out);
mpm.get_module().replace_instruction(ins, reduce->get_operator(), new_inputs, {rm}); mpm.get_module().replace_instruction(ins, reduce->get_operator(), new_inputs, {rm});
} }
}; };
......
...@@ -116,7 +116,7 @@ struct find_add_layernorm ...@@ -116,7 +116,7 @@ struct find_add_layernorm
void prefuse_ops::apply(module& m) const void prefuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_add_layernorm{}, find_layernorm{}); // match::find_matches(m, find_add_layernorm{}, find_layernorm{});
} }
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment