Commit 51bd00b3 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into so-version

parents a79beae7 b4f11615
Pipeline #673 failed with stages
in 0 seconds
...@@ -264,8 +264,7 @@ static void ins_quantize_int8(program& prog, ...@@ -264,8 +264,7 @@ static void ins_quantize_int8(program& prog,
auto dilation = conv_op.dilation; auto dilation = conv_op.dilation;
auto padding_mode = conv_op.padding_mode; auto padding_mode = conv_op.padding_mode;
auto group = conv_op.group; auto group = conv_op.group;
auto adjust_factor = auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);
std::round(1.0f / (ins_quant_params[0].first * ins_quant_params[1].first));
auto quant_conv = prog.insert_instruction( auto quant_conv = prog.insert_instruction(
ins, ins,
......
...@@ -896,7 +896,7 @@ TEST_CASE(target_copy) ...@@ -896,7 +896,7 @@ TEST_CASE(target_copy)
} }
} }
TEST_CASE(int8_quantization) TEST_CASE(int8_quantization_dot)
{ {
auto run_prog = [](migraphx::program p, auto run_prog = [](migraphx::program p,
const migraphx::target& t, const migraphx::target& t,
...@@ -958,4 +958,47 @@ TEST_CASE(int8_quantization) ...@@ -958,4 +958,47 @@ TEST_CASE(int8_quantization)
} }
} }
TEST_CASE(int8_quantization_conv)
{
auto run_prog = [](migraphx::program p,
const migraphx::target& t,
std::vector<float>& res,
bool b_quantize = false) {
if(b_quantize)
{
std::vector<migraphx::program::parameter_map> cali_data;
migraphx::quantize_int8(p, t, cali_data);
}
p.compile(t);
migraphx::program::parameter_map m;
auto result = t.copy_from(p.eval(m));
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
};
auto create_program = [] {
migraphx::program p;
migraphx::shape sx{migraphx::shape::float_type, {4, 2, 2, 2}};
migraphx::shape sw{migraphx::shape::float_type, {4, 2, 2, 2}};
std::vector<float> v(sx.elements(), 0.5f);
auto input = p.add_literal(migraphx::literal(sx, v));
auto weights = p.add_literal(migraphx::literal(sw, v));
p.add_instruction(migraphx::op::convolution{}, input, weights);
return p;
};
{
auto p = create_program();
std::vector<float> quant_result;
migraphx::target cpu_t = migraphx::cpu::target{};
run_prog(p, cpu_t, quant_result, true);
std::vector<float> no_quant_result;
run_prog(p, cpu_t, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment