Commit f03d2369 authored by Paul's avatar Paul
Browse files

Compute the correct type

parent 15a7d96a
......@@ -44,7 +44,10 @@ struct ck_gemm
auto b = inputs[1];
for(const auto& input : inputs)
check_gemm_shape(input);
return op.compute_shape({a, b});
auto r = op.compute_shape({a, b});
if (mods.empty())
return r;
return r.with_type(mods.front()->get_output_shapes().front().type());
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment