"docs/EN/source/vscode:/vscode.git/clone" did not exist on "0a900eca0249bdede70d951d3ffa7fb72b87bd04"
Commit 3de847f3 authored by Paul's avatar Paul
Browse files

Format

parent 2e138c91
......@@ -38,7 +38,7 @@ struct ck_gemm
MIGRAPHX_THROW("should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input:inputs)
for(const auto& input : inputs)
check_gemm_shape(input);
return op.compute_shape({a, b});
}
......@@ -55,7 +55,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto b = ins->inputs().back()->get_shape();
if(a.lens().size() > 2 or b.lens().size() > 2)
return false;
if (a.lens()[1] >= 2048)
if(a.lens()[1] >= 2048)
return false;
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[1] % 8 == 0);
......@@ -64,8 +64,10 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
struct find_ck_gemm
{
// Find a gemm followed by a pointwise operation.
auto matcher() const {
auto gemm = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
auto matcher() const
{
auto gemm =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
}
......@@ -77,16 +79,17 @@ struct find_ck_gemm
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end());
if (gemm_idx != 0)
if(gemm_idx != 0)
{
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape());
auto new_first_param = pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
auto new_first_param =
pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment