"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "d0e8bb1a36d9bbf8fc7b26efa9cdd210dac488aa"
Commit 092ec713 authored by turneram's avatar turneram
Browse files

Merge branch 'ck-gemm-int8' into gemm-perf

parents 11da0f48 b2f12dae
...@@ -33,6 +33,8 @@ ...@@ -33,6 +33,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
void apply_quantizelinear(module& m, instruction_ref ins) void apply_quantizelinear(module& m, instruction_ref ins)
{ {
assert(ins->name() == "quantizelinear"); assert(ins->name() == "quantizelinear");
...@@ -63,8 +65,21 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -63,8 +65,21 @@ void apply_quantizelinear(module& m, instruction_ref ins)
min_quant = qt.min(); min_quant = qt.min();
}); });
auto s = add_zero_point->get_shape(); auto s = add_zero_point->get_shape();
auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}}); instruction_ref min_arg;
auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}}); instruction_ref max_arg;
if (enabled(MIGRAPHX_ENABLE_CK{}))
{
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data));
}
else
{
min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
}
auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg}); auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
m.replace_instruction( m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate); ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
......
...@@ -141,6 +141,10 @@ struct find_ck_gemm_pointwise ...@@ -141,6 +141,10 @@ struct find_ck_gemm_pointwise
return not input->inputs().empty() and input->inputs().front()->name() == "capture"; return not input->inputs().empty() and input->inputs().front()->name() == "capture";
})) }))
return; return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if(gemm_idx != 0) if(gemm_idx != 0)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment