Commit 7eef6cfc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add more test cases for the quantize function

parent c98b919f
......@@ -25,11 +25,11 @@ instruction_ref insert_fp16(program& prog,
instruction_ref ins_fp16{};
if(ins->name() == "@literal" && ins->outputs().size() == 1)
{
std::vector<float> values;
auto l_fp32 = ins->get_literal();
shape s = ins->get_shape();
l_fp32.visit([&](auto val) { values.assign(val.begin(), val.end()); });
ins_fp16 = prog.add_literal(literal({shape::half_type, s.lens()}, values));
auto l = ins->get_literal();
auto s = ins->get_shape();
l.visit([&](auto val) {
ins_fp16 = prog.add_literal(literal({type, s.lens()}, val.begin(), val.end()));
});
}
else
{
......
......@@ -3,6 +3,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
......@@ -1557,4 +1558,40 @@ TEST_CASE(fp16_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(fp32_fp16_test)
{
auto create_program = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data(2 * 3);
std::iota(data.begin(), data.end(), 1.0f);
auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_literal(migraphx::literal(s, data));
p.add_instruction(migraphx::op::add{}, l1, l2);
return p;
};
{
std::vector<float> gold_res = {2.0, 4.0, 6.0, 8.0, 10.0, 12.0};
auto p = create_program();
migraphx::quantize(p, {"all"});
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res;
result.visit([&](auto output) { res.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res, gold_res));
}
{
std::vector<float> gold_res = {2.0, 4.0, 6.0, 8.0, 10.0, 12.0};
auto p = create_program();
migraphx::quantize(p, {"all"});
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res;
result.visit([&](auto output) { res.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res, gold_res));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -10,6 +10,7 @@
#include <migraphx/type_name.hpp>
#include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <miopen/miopen.h>
......@@ -3334,4 +3335,70 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
template struct test_logsoftmax_1<0>;
template struct test_logsoftmax_1<1>;
struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
{
migraphx::program create_program () const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data(2 * 3);
std::iota(data.begin(), data.end(), 1.0f);
auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_literal(migraphx::literal(s, data));
p.add_instruction(migraphx::op::add{}, l1, l2);
migraphx::quantize(p, {"all"});
return p;
};
};
struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd>
{
migraphx::program create_program () const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data(2 * 3);
std::iota(data.begin(), data.end(), 1.0f);
auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_literal(migraphx::literal(s, data));
p.add_instruction(migraphx::op::add{}, l1, l2);
migraphx::quantize(p, {"all"});
return p;
};
};
struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s);
auto p2 = p.add_parameter("y", s);
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1);
migraphx::quantize(p, {"add"});
return p;
};
};
struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s);
auto p2 = p.add_parameter("y", s);
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1);
migraphx::quantize(p, {"sub"});
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -106,6 +106,23 @@ TEST_CASE(param_add_sub)
return p;
};
auto create_program_half_all = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s);
auto hp1 = p.insert_instruction(
std::next(p1), migraphx::op::convert{migraphx::shape::half_type}, p1);
auto p2 = p.add_parameter("y", s);
auto hp2 = p.insert_instruction(
std::next(p2), migraphx::op::convert{migraphx::shape::half_type}, p2);
auto hsum = p.add_instruction(migraphx::op::add{}, hp1, hp2);
auto hdiff = p.add_instruction(migraphx::op::sub{}, hsum, hp2);
auto hres = p.add_instruction(migraphx::op::add{}, hdiff, hp1);
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hres);
return p;
};
{
auto p1 = create_program_float();
auto p2 = create_program_half_add();
......@@ -121,6 +138,16 @@ TEST_CASE(param_add_sub)
migraphx::quantize(p1, {"sub"});
EXPECT(p1 == p2);
}
{
auto p1 = create_program_float();
auto p2 = create_program_half_all();
migraphx::quantize(p1, {"all"});
migraphx::run_passes(p1, {migraphx::dead_code_elimination{}});
EXPECT(p1 == p2);
}
}
TEST_CASE(literal_add)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment