"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "e76bd7293eb27828cab07c35395d898d7cec8eeb"
Commit 3de847f3 authored by Paul's avatar Paul
Browse files

Format

parent 2e138c91
...@@ -38,7 +38,7 @@ struct ck_gemm ...@@ -38,7 +38,7 @@ struct ck_gemm
MIGRAPHX_THROW("should have at least two inputs."); MIGRAPHX_THROW("should have at least two inputs.");
auto a = inputs[0]; auto a = inputs[0];
auto b = inputs[1]; auto b = inputs[1];
for(const auto& input:inputs) for(const auto& input : inputs)
check_gemm_shape(input); check_gemm_shape(input);
return op.compute_shape({a, b}); return op.compute_shape({a, b});
} }
...@@ -55,7 +55,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -55,7 +55,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
if(a.lens().size() > 2 or b.lens().size() > 2) if(a.lens().size() > 2 or b.lens().size() > 2)
return false; return false;
if (a.lens()[1] >= 2048) if(a.lens()[1] >= 2048)
return false; return false;
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[1] % 8 == 0); b.lens()[1] % 8 == 0);
...@@ -64,8 +64,10 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -64,8 +64,10 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
struct find_ck_gemm struct find_ck_gemm
{ {
// Find a gemm followed by a pointwise operation. // Find a gemm followed by a pointwise operation.
auto matcher() const { auto matcher() const
auto gemm = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm"))); {
auto gemm =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
} }
...@@ -77,16 +79,17 @@ struct find_ck_gemm ...@@ -77,16 +79,17 @@ struct find_ck_gemm
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin(); auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if (gemm_idx != 0) if(gemm_idx != 0)
{ {
auto first_param = pm->get_parameter(names[0]); auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]); auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape()); auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape());
auto new_first_param = pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape()); auto new_first_param =
pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param); pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param); pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param); pm->remove_instruction(first_param);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment