Commit 8d02a27a authored by Paul's avatar Paul
Browse files

Move data type to is_ck_gemm matcher

parent d30df5bb
...@@ -53,6 +53,9 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -53,6 +53,9 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{ {
if(ins->name() != "dot" and ins->name() != "quant_dot") if(ins->name() != "dot" and ins->name() != "quant_dot")
return false; return false;
if(not contains({shape::half_type, shape::int8_type, shape::int32_type},
ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape(); auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
if(a.lens().back() > 2048) if(a.lens().back() > 2048)
...@@ -82,9 +85,6 @@ struct find_ck_gemm_pointwise ...@@ -82,9 +85,6 @@ struct find_ck_gemm_pointwise
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin(); auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if(not contains({shape::half_type, shape::int8_type, shape::int32_type},
ins->get_shape().type()))
return;
if(gemm_idx != 0) if(gemm_idx != 0)
{ {
auto first_param = pm->get_parameter(names[0]); auto first_param = pm->get_parameter(names[0]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment