Commit f84580ac authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add more unit tests for better code change coverage

parent 99a5995c
...@@ -368,6 +368,9 @@ void quantize_int8(program& prog, ...@@ -368,6 +368,9 @@ void quantize_int8(program& prog,
input->inputs().front()->get_shape().type() == quant_type) input->inputs().front()->get_shape().type() == quant_type)
{ {
quant_input = input->inputs().front(); quant_input = input->inputs().front();
// the scale in this case is not used, so tune the scale
// to 1.0f for this parameter
ins_quant_params.back() = std::make_pair<float, float>(1.0f, 0.0f);
} }
else else
{ {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
...@@ -627,6 +628,58 @@ TEST_CASE(dot_int32) ...@@ -627,6 +628,58 @@ TEST_CASE(dot_int32)
EXPECT(p == qp); EXPECT(p == qp);
} }
TEST_CASE(dot_float_convert)
{
auto create_program = [] {
migraphx::program p;
migraphx::shape sa{migraphx::shape::int8_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto fpa = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa);
p.add_instruction(migraphx::op::dot{2.0f, 5.5f}, fpa, pb);
return p;
};
auto create_int8_quantized_prog = [] {
migraphx::program p;
migraphx::shape sa{migraphx::shape::int8_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
auto qdot = p.add_instruction(migraphx::op::quant_dot{1, 0}, pa, qb);
auto fr = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot);
std::vector<float> v_alpha(fr->get_shape().elements(), 10.0f);
auto new_alpha = p.add_literal(migraphx::literal(fr->get_shape(), v_alpha));
p.add_instruction(migraphx::op::mul{}, new_alpha, fr);
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params);
migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
TEST_CASE(conv_float) TEST_CASE(conv_float)
{ {
auto create_program = [] { auto create_program = [] {
...@@ -795,4 +848,117 @@ TEST_CASE(conv_half) ...@@ -795,4 +848,117 @@ TEST_CASE(conv_half)
EXPECT(p == qp); EXPECT(p == qp);
} }
TEST_CASE(target_copy)
{
auto run_prog = [](migraphx::program p,
const migraphx::target& t,
migraphx::program::parameter_map& m_in,
std::vector<float>& res) {
p.compile(t);
migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
if(m_in.count(x.first) > 0)
{
m[x.first] = t.copy_to(m_in[x.first]);
}
else
{
m[x.first] = t.allocate(x.second);
}
}
auto result = t.copy_from(p.eval(m));
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
};
auto create_program = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto p1 = p.add_parameter("x", s);
auto p2 = p.add_parameter("y", s);
p.add_instruction(migraphx::op::add{}, p1, p2);
return p;
};
{
auto p = create_program();
migraphx::program::parameter_map m;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
m["x"] = migraphx::generate_argument(s);
std::vector<float> cpu_result;
migraphx::target cpu_t = migraphx::cpu::target{};
run_prog(p, cpu_t, m, cpu_result);
std::vector<float> orig_result;
run_prog(p, cpu_t, m, orig_result);
EXPECT(migraphx::verify_range(cpu_result, orig_result));
}
}
TEST_CASE(int8_quantization)
{
auto run_prog = [](migraphx::program p,
const migraphx::target& t,
migraphx::program::parameter_map& m_in,
std::vector<float>& res,
bool b_quantize = false) {
if (b_quantize)
{
std::vector<migraphx::program::parameter_map> cali_data;
cali_data.push_back(m_in);
migraphx::quantize_int8(p, t, cali_data);
}
p.compile(t);
migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
if(m_in.count(x.first) > 0)
{
m[x.first] = t.copy_to(m_in[x.first]);
}
else
{
m[x.first] = t.allocate(x.second);
}
}
auto result = t.copy_from(p.eval(m));
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
};
auto create_program = [] {
migraphx::program p;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto pc = p.add_parameter("c", sc);
p.add_instruction(migraphx::op::dot{}, pa, pb, pc);
return p;
};
{
auto p = create_program();
migraphx::program::parameter_map m;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
m["a"] = migraphx::generate_argument(sa);
m["c"] = migraphx::generate_argument(sc);
std::vector<float> quant_result;
migraphx::target cpu_t = migraphx::cpu::target{};
run_prog(p, cpu_t, m, quant_result, true);
std::vector<float> no_quant_result;
run_prog(p, cpu_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment