Commit c27a6376 authored by Paul
Browse files

Const fold adds for gemms

parent cf9cec1c
......@@ -50,6 +50,7 @@
#include <migraphx/array.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/contiguous.hpp>
#include <cmath>
#include <set>
......@@ -974,9 +975,8 @@ struct find_gemm_pointwise
{
return precompile_name("pointwise")(
match::nargs(3),
match::all_of[match::inputs()](match::standard_shape()),
match::either_arg(0, 1)(
match::any().bind("c"),
match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
}
......@@ -1052,6 +1052,14 @@ struct find_gemm_pointwise
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return;
// const-fold input if not standard shape since rocblas can't handle it
if (not c_ins->get_shape().standard())
{
auto c = op::contiguous{};
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
c_ins = m.add_literal(l.get_shape(), l.data());
}
auto inputs = gemm_ins->inputs();
inputs.pop_back();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment