Commit b1d86d7c authored by Paul
Browse files

Merge branch 'dot-add' into bert-opt2

parents 3b8ae098 9cb9bc09
...@@ -50,6 +50,7 @@ ...@@ -50,6 +50,7 @@
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp> #include <migraphx/op/clip.hpp>
#include <migraphx/op/contiguous.hpp>
#include <cmath> #include <cmath>
#include <set> #include <set>
...@@ -975,9 +976,8 @@ struct find_gemm_pointwise ...@@ -975,9 +976,8 @@ struct find_gemm_pointwise
{ {
return precompile_name("pointwise")( return precompile_name("pointwise")(
match::nargs(3), match::nargs(3),
match::all_of[match::inputs()](match::standard_shape()),
match::either_arg(0, 1)( match::either_arg(0, 1)(
match::any().bind("c"), match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm"))); match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
} }
...@@ -1053,6 +1053,14 @@ struct find_gemm_pointwise ...@@ -1053,6 +1053,14 @@ struct find_gemm_pointwise
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1)) gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return; return;
// const-fold input if not standard shape since rocblas can't handle it
if(not c_ins->get_shape().standard())
{
auto c = op::contiguous{};
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
c_ins = m.add_literal(l.get_shape(), l.data());
}
auto inputs = gemm_ins->inputs(); auto inputs = gemm_ins->inputs();
inputs.pop_back(); inputs.pop_back();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment