Commit 4e55c401 authored by Paul
Browse files

Some fixes

parent 40762b08
......@@ -31,6 +31,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <iterator>
namespace migraphx {
......@@ -49,24 +50,27 @@ struct fused_reduce
/// Computes the output shape of the fused reduction from its inputs and submodule.
///
/// @param inputs Shapes of the operator inputs. They are validated through
///               check_shapes: `has(...)` with the submodule's parameter count,
///               followed by `same_dims()`.
/// @param mods   Submodules attached to the instruction; exactly one is accepted.
/// @return A shape carrying the element type of the submodule's single output.
///         Its lens start as the lens of `inputs.at(0)`; when they differ from
///         the submodule output lens, every entry indexed by `axes` is set to 1.
///         The shape is built with `shape::from_permutation` using
///         `find_permutation(inputs)`.
/// @throws Whatever MIGRAPHX_THROW raises when there is not exactly one
///         submodule, or when the submodule does not have exactly one output.
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
    if(mods.size() != 1)
    {
        MIGRAPHX_THROW("should have one submodule.");
    }
    auto* sm = mods.front();

    // Query the output shapes once; the result is reused for the count check,
    // the lens comparison and the element type. Binding to const& is valid
    // whether the call returns by value or by reference.
    const auto& out_shapes = sm->get_output_shapes();

    // Checked once, and before front() is read below.
    if(out_shapes.size() != 1)
        MIGRAPHX_THROW("Only one output supported");
    const auto& out_shape = out_shapes.front();

    check_shapes{inputs, *this}.has(sm->get_parameter_shapes().size()).same_dims();

    auto s    = inputs.at(0);
    auto lens = s.lens();
    // Only collapse the reduced axes when the submodule output does not already
    // have the same lens as the first input.
    if(lens != out_shape.lens())
    {
        for(const auto& axis : axes)
        {
            lens[axis] = 1;
        }
    }

    // Single return path; the earlier `inputs[0].with_lens(...)` return made this
    // statement unreachable.
    return shape::from_permutation(out_shape.type(), lens, find_permutation(inputs));
}
/// Returns the operator's name string, "fused_reduce".
std::string name() const
{
    return "fused_reduce";
}
};
MIGRAPHX_REGISTER_OP(fused_reduce);
static void create_reduce_modules(module_pass_manager& mpm)
{
......@@ -87,8 +91,8 @@ static void create_reduce_modules(module_pass_manager& mpm)
auto r = rm->add_instruction(ins->get_operator(), x0);
rm->add_return({r});
// TODO: Set axes
mpm.get_module().replace_instruction(ins, make_op("fused_reduce"), ins->inputs(), {rm});
auto v = ins->get_operator().to_value();
mpm.get_module().replace_instruction(ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
}
}
......@@ -130,10 +134,11 @@ struct find_reduce_pointwise
auto ins = r.result;
auto reduce = r.instructions["reduce"];
auto* old_rm = reduce->module_inputs().front();
const auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(old_rm->name() + ":pointwise");
// Copy module
*rm = *old_rm;
rm->set_bypass();
// Copy module instructions
rm->add_instructions(old_rm);
auto map_ins = get_ins_param_map(reduce->inputs(), rm);
auto new_inputs = reduce->inputs();
for(auto input : ins->inputs())
......@@ -152,8 +157,8 @@ struct find_reduce_pointwise
}
}
auto out = rm->insert_instructions(std::prev(rm->end()), {ins}, map_ins);
rm->replace_return(out);
auto out = rm->add_instructions({ins}, map_ins);
rm->add_return(out);
mpm.get_module().replace_instruction(ins, reduce->get_operator(), new_inputs, {rm});
}
};
......
......@@ -116,7 +116,7 @@ struct find_add_layernorm
// Invokes match::find_matches on module `m` with the find_add_layernorm and
// find_layernorm matchers.
void prefuse_ops::apply(module& m) const
{
match::find_matches(m, find_add_layernorm{}, find_layernorm{});
// NOTE(review): the line below is a commented-out copy of the call above.
// It looks like the matchers were meant to be disabled here - confirm which
// state is intended and drop the other line.
// match::find_matches(m, find_add_layernorm{}, find_layernorm{});
}
} // namespace gpu
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment