Commit 43010679 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed review comments.

parent aaf8ed08
...@@ -23,18 +23,7 @@ instruction_ref insert_fp16(program& prog, ...@@ -23,18 +23,7 @@ instruction_ref insert_fp16(program& prog,
assert(ins->get_shape().type() == shape::float_type || assert(ins->get_shape().type() == shape::float_type ||
ins->get_shape().type() == shape::double_type); ins->get_shape().type() == shape::double_type);
instruction_ref ins_fp16{}; instruction_ref ins_fp16{};
if(ins->name() == "@literal" && ins->outputs().size() == 1) ins_fp16 = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
{
auto l = ins->get_literal();
auto s = ins->get_shape();
l.visit([&](auto val) {
ins_fp16 = prog.add_literal(literal({type, s.lens()}, val.begin(), val.end()));
});
}
else
{
ins_fp16 = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
}
map_fp16[ins] = ins_fp16; map_fp16[ins] = ins_fp16;
return ins_fp16; return ins_fp16;
......
...@@ -16,8 +16,7 @@ struct hip_convert : unary_device<hip_convert, device::convert> ...@@ -16,8 +16,7 @@ struct hip_convert : unary_device<hip_convert, device::convert>
{ {
op::convert op; op::convert op;
hip_convert(const op::convert& oper) : op(oper) {} hip_convert(const op::convert oper) : op(std::move(oper)) {}
hip_convert(const op::convert&& oper) : op(oper) {}
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -3333,37 +3333,37 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>> ...@@ -3333,37 +3333,37 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
template struct test_logsoftmax_1<0>; template struct test_logsoftmax_1<0>;
template struct test_logsoftmax_1<1>; template struct test_logsoftmax_1<1>;
// struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall> struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
// { {
// migraphx::program create_program() const migraphx::program create_program() const
// { {
// migraphx::program p; migraphx::program p;
// migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
// std::vector<float> data(2 * 3); std::vector<float> data(2 * 3);
// std::iota(data.begin(), data.end(), 1.0f); std::iota(data.begin(), data.end(), 1.0f);
// auto l1 = p.add_literal(migraphx::literal(s, data)); auto l1 = p.add_literal(migraphx::literal(s, data));
// auto l2 = p.add_literal(migraphx::literal(s, data)); auto l2 = p.add_parameter("p2", s);
// p.add_instruction(migraphx::op::add{}, l1, l2); p.add_instruction(migraphx::op::add{}, l1, l2);
// //migraphx::quantize(p, {"all"}); migraphx::quantize(p, {"all"});
// return p; return p;
// }; };
// }; };
// struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd> struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd>
// { {
// migraphx::program create_program() const migraphx::program create_program() const
// { {
// migraphx::program p; migraphx::program p;
// migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
// std::vector<float> data(2 * 3); std::vector<float> data(2 * 3);
// std::iota(data.begin(), data.end(), 1.0f); std::iota(data.begin(), data.end(), 1.0f);
// auto l1 = p.add_literal(migraphx::literal(s, data)); auto l1 = p.add_literal(migraphx::literal(s, data));
// auto l2 = p.add_literal(migraphx::literal(s, data)); auto l2 = p.add_parameter("p2", s);
// p.add_instruction(migraphx::op::add{}, l1, l2); p.add_instruction(migraphx::op::add{}, l1, l2);
// migraphx::quantize(p, {"all"}); migraphx::quantize(p, {"add"});
// return p; return p;
// }; };
// }; };
struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add> struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add>
{ {
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -180,7 +181,8 @@ TEST_CASE(literal_add) ...@@ -180,7 +181,8 @@ TEST_CASE(literal_add)
auto p2 = create_program_half(); auto p2 = create_program_half();
migraphx::quantize(p1, {"all"}); migraphx::quantize(p1, {"all"});
migraphx::run_passes(p1, {migraphx::dead_code_elimination{}}); migraphx::run_passes(p1, {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(p2, {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -190,7 +192,8 @@ TEST_CASE(literal_add) ...@@ -190,7 +192,8 @@ TEST_CASE(literal_add)
auto p2 = create_program_half(); auto p2 = create_program_half();
migraphx::quantize(p1, {"add"}); migraphx::quantize(p1, {"add"});
migraphx::run_passes(p1, {migraphx::dead_code_elimination{}}); migraphx::run_passes(p1, {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(p2, {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment