Unverified Commit 5592b921 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Support multi program outputs (#436)



* Add initial api

* Formatting

* Add more api

* Formatting

* Add auto api generation

* Formatting

* Fix some compilation errors

* Change handle struct

* Formatting

* Fix remaining compilation errors

* Formatting

* fixed a bug related to number of outputs

* Simplify using ctype

* Formatting

* Initial c++ generation

* Formatting

* Add C++header

* Formatting

* Add test

* Formatting

* Add initial tests

* Formatting

* Try to fix formatting

* Cleanup formatting

* Formatting

* Fix constructors on the same line

* Fix tests

* Formatting

* Fix tidy issues

* Fix tidy issues

* Fix naming issue

* Add onnx API to parse buffer

* Formatting

* Add arguments api

* Formatting

* Fix verify parameters

* Fix cppcheck issues

* Formatting

* Add method to get output shapes and bytes

* Formatting

* Try formatting

* Formatting

* Improve the test coverage

* Formatting

* Add print method

* Formatting

* Fix cppcheck issue

* Fix package dependency

* code backup for support multiple outputs

* clang format

* change migraphx api to support multiple program outputs

* clang format

* change api implementation

* clang format

* clang format

* fix a build error

* additional changes

* clang format

* change api for correct automatic generation

* clang format

* fix unit test error

* fix unit test error

* fix unit tests error

* support multiple program outputs

* clang format

* remove @ from the add_return name

* Add nolint

* Try fix formatting

* Formatting

* formatting

* formatting

* Fix formatting

* code cleanup

* clang format

* fix cppcheck error

* fix a cppcheck error

* clang format

* fix review comments

* clang format

* fix cppcheck error

* clang format

* record graph output name

* clang format

* refine print the add_return instruction

* clang format

* fix cppcheck error

* clang format

* refine the name of the add_return instruction

* fixed a bug related to workspace

* fixed two small bugs

* clang format

* add more unit tests for multiple program outputs

* clang format

* change an error info

* clang format

* fix cppcheck error

* add unit test for better code coverage

* change to reduce code change

* clang format

* remove storing program output

* fix cppcheck error

* fix review comments

* clang format

* clang format

* remove unnecessary change

* resolve an assert error

* clang format

* change the output name with prefix '#'

* changes in quantization function to support the return instruction

* clang format

* refine unit tests

* clang format

* refine profiling print out report
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
Co-authored-by: default avatarKhalique <15948690+kahmed10@users.noreply.github.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 7461322a
......@@ -69,6 +69,10 @@ void eliminate_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// return instruction should have inputs with standard shape
if(ins->name() == "@return")
continue;
// Make a copy so we can modify it while we iterate
auto args = ins->inputs();
for(auto arg : ins->inputs())
......
......@@ -63,6 +63,16 @@ struct param
}
};
struct returns
{
std::string name() const { return "@return"; }
shape compute_shape(const std::vector<shape>&) const { return {}; }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPHX_THROW("builtin");
}
};
} // namespace builtin
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -93,7 +93,7 @@ struct instruction
void replace(const shape& r);
operation op;
shape result;
shape result{};
std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments;
literal lit;
......
......@@ -87,6 +87,8 @@ struct program
instruction_ref add_parameter(std::string name, shape s);
instruction_ref add_return(std::vector<instruction_ref> args);
shape get_parameter_shape(std::string name) const;
instruction_ref get_parameter(std::string name) const;
......
......@@ -22,6 +22,9 @@ void instruction::replace(const shape& r)
result = r;
for(auto&& ins : output)
{
if(ins->name() == "@return")
continue;
assert(ins->name().front() != '@');
ins->recompute_shape();
}
......@@ -70,6 +73,10 @@ bool instruction::valid() const
{
computed = result;
}
else if(op.name() == "@return")
{
computed = {};
}
else
{
try
......@@ -81,6 +88,7 @@ bool instruction::valid() const
return false;
}
}
return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
});
......
......@@ -1759,18 +1759,28 @@ struct onnx_parser
this->parse_node(output.name());
}
// For now, the last output with a valid name is considered
// as the program output, and add an identity instruction at
// the program end
// Find instructions corresponding to the output
auto prog_output = graph.output();
auto oit = std::find_if(prog_output.rbegin(), prog_output.rend(), [](auto& node) {
return !node.name().empty();
});
if(instructions.count(oit->name()) > 0)
{
prog.add_instruction(op::identity{}, instructions[oit->name()]);
}
std::vector<std::string> all_output_names;
std::vector<std::string> prog_output_names;
std::transform(prog_output.begin(),
prog_output.end(),
std::back_inserter(all_output_names),
[](auto& node) { return node.name(); });
std::copy_if(
all_output_names.begin(),
all_output_names.end(),
std::back_inserter(prog_output_names),
[&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });
std::vector<instruction_ref> output_ins;
std::transform(prog_output_names.begin(),
prog_output_names.end(),
std::back_inserter(output_ins),
[&](const auto& name) { return instructions[name]; });
// add the return instruction
prog.add_return(output_ins);
}
void parse_undefined(const std::string& name)
......@@ -1816,9 +1826,9 @@ struct onnx_parser
}
else
{
assert(node.output().size() <= result.size());
auto output_num = std::min<std::size_t>(node.output().size(), result.size());
std::transform(node.output().begin(),
node.output().end(),
node.output().begin() + output_num,
result.begin(),
std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(x, y); });
......
......@@ -52,7 +52,9 @@ static void print_instruction(std::ostream& os,
os << ")";
}
os << " -> " << ins->get_shape();
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
}
template <class F>
......@@ -147,7 +149,14 @@ void program::assign(const program& p)
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return ins_map[i];
});
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
if(ins->name() == "@return")
{
copy_ins = add_return(copy_inputs);
}
else
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
}
ins_map[ins] = copy_ins;
......@@ -270,6 +279,18 @@ instruction_ref program::add_parameter(std::string name, shape s)
return impl->instructions.begin();
}
instruction_ref program::add_return(std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
impl->instructions.push_back({builtin::returns{}, {}, args});
auto result = std::prev(impl->instructions.end());
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
shape program::get_parameter_shape(std::string name) const
{
auto ins = std::find_if(
......@@ -336,7 +357,23 @@ instruction_ref program::end() const { return impl->instructions.end(); }
std::vector<shape> program::get_output_shapes() const
{
return {impl->instructions.back().get_shape()};
auto last_ins = impl->instructions.back();
if(last_ins.name() == "@return")
{
auto& output_ins = last_ins.inputs();
std::vector<shape> output_shapes;
std::transform(output_ins.begin(),
output_ins.end(),
std::back_inserter(output_shapes),
[](auto& ins) { return ins->get_shape(); });
return output_shapes;
}
// The else branch is to provide backward compatibility
else
{
return {last_ins.get_shape()};
}
}
context& program::get_context() const { return impl->ctx; }
......@@ -410,6 +447,19 @@ std::vector<argument> generic_eval(const program& p,
{
results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; }));
}
else if(name == "@return")
{
std::vector<argument> prog_outputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(prog_outputs),
[&](instruction_ref i) {
assert(results.find(i) != results.end());
return results[i];
});
return prog_outputs;
}
else
{
values.resize(ins->inputs().size());
......@@ -424,10 +474,11 @@ std::vector<argument> generic_eval(const program& p,
}
assert(results.find(ins) != results.end());
}
return {results.at(std::prev(p.end()))};
}
std::vector<argument> program::eval(std::unordered_map<std::string, argument> params) const
std::vector<argument> program::eval(parameter_map params) const
{
auto& ctx = this->impl->ctx;
#ifndef NDEBUG
......@@ -534,6 +585,11 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
print_program(*this, [&](auto ins, const auto& names) {
print_instruction(std::cout, ins, names);
// skip return instruction
if(ins->name() == "@return")
return;
double avg = common_average(ins_vec[ins]);
double percent = std::ceil(100.0 * avg / total_instruction_time);
os << ": " << avg << "ms, " << percent << "%";
......
......@@ -105,6 +105,9 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog))
{
if(ins->name() == "@return")
break;
// all indicates every instruction is converted
if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name())))
{
......@@ -335,6 +338,9 @@ void quantize_int8_impl(program& prog,
std::unordered_map<instruction_ref, std::size_t> map_ins_index;
for(auto ins : iterator_for(prog))
{
if(ins->name() == "@return")
break;
if(not contains(ins_names, ins->name()))
{
continue;
......
......@@ -84,6 +84,7 @@ struct miopen_apply
const lowering* pass = nullptr;
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
instruction_ref last{};
std::unordered_map<instruction_ref, std::string> prog_output_names{};
context& get_context()
{
......@@ -99,11 +100,33 @@ struct miopen_apply
(void)i;
}
void create_output_names()
{
this->last = instruction::get_output_alias(std::prev(prog->end()));
if(this->last->name() == "@return")
{
auto& prog_outputs = last->inputs();
std::vector<instruction_ref> outputs_alias(prog_outputs.size());
std::transform(prog_outputs.begin(),
prog_outputs.end(),
outputs_alias.begin(),
[](const auto& i) { return instruction::get_output_alias(i); });
std::size_t index = 0;
for(auto ins : outputs_alias)
{
prog_output_names[ins] = "#output_" + std::to_string(index++);
}
}
}
void init()
{
assert(prog != nullptr);
assert(pass != nullptr);
this->last = instruction::get_output_alias(std::prev(prog->end()));
create_output_names();
add_miopen_simple_op<miopen_abs>("abs", make_abs);
......@@ -172,17 +195,37 @@ struct miopen_apply
{
if(not pass->offload_copy)
return;
for(auto ins : iterator_for(*prog))
{
if(ins->name() != "@param")
continue;
auto pos = std::next(ins);
auto a = insert_allocation(pos, ins->get_shape());
auto c = prog->insert_instruction(pos, hip_copy_to_gpu{}, ins, a);
prog->replace_instruction(ins, c);
}
auto end = std::prev(prog->end());
prog->add_instruction(hip_copy_from_gpu{}, end);
// return instruction
auto ret = std::prev(prog->end());
if(ret->name() == "@return")
{
auto& inputs = ret->inputs();
// each input of ret needs to be copied from gpu to host, and the
// original input replaced with the copy's output
for(auto& in : inputs)
{
auto p_output = prog->insert_instruction(ret, hip_copy_from_gpu{}, in);
instruction::replace_argument(ret, in, p_output);
}
}
// else branch to handle legacy program without the return instruction
else
{
prog->add_instruction(hip_copy_from_gpu{}, ret);
}
}
void apply()
......@@ -196,20 +239,30 @@ struct miopen_apply
check_shape(s, apply_map.at(it->name())(it));
}
}
copy_params();
}
instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "")
{
if(not pass->offload_copy and ins == last and tag.empty())
{
return prog->add_parameter("output", s);
}
else
// Instruction's output is an input of the ret instruction
if(pass->offload_copy)
{
auto result = prog->insert_instruction(ins, hip_allocate{s, std::move(tag)});
return result;
}
auto ins_alias = instruction::get_output_alias(ins);
if(last->name() == "@return" and tag.empty() and prog_output_names.count(ins_alias) > 0)
{
return prog->add_parameter(prog_output_names[ins_alias], s);
}
else if(ins == last and tag.empty())
{
return prog->add_parameter("output", s);
}
return prog->insert_instruction(ins, hip_allocate{s, std::move(tag)});
}
void add_convolution_op()
......
......@@ -1030,6 +1030,9 @@ struct tf_parser
{
this->parse_node(p.first);
}
// TODO: add a @return instruction at the end of the program
}
void parse_node(const std::string& name)
......
......@@ -112,4 +112,16 @@ TEST_CASE(no_packed_unary_op)
EXPECT(std::distance(p.begin(), p.end()) == count - 1);
}
TEST_CASE(non_standard_return_input)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto tl = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, tl);
p.add_return({c});
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -3,6 +3,7 @@
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/miopen.hpp>
......@@ -87,17 +88,29 @@ auto get_hash(const T& x)
void compile_check(migraphx::program& p, const migraphx::target& t, bool show_trace = false)
{
auto name = t.name();
auto s = p.get_output_shapes().back();
auto name = t.name();
auto shapes = p.get_output_shapes();
std::stringstream ss;
migraphx::compile_options options;
options.trace = migraphx::tracer{ss};
p.compile(t, options);
if(p.get_output_shapes().back() != s)
if(shapes.size() != p.get_output_shapes().size())
{
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name + " alters its shape");
throw std::runtime_error("Compiling program with " + name +
" alters its number of outputs");
}
auto num = shapes.size();
for(std::size_t i = 0; i < num; ++i)
{
if(p.get_output_shapes()[i].lens() != shapes[i].lens())
{
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name + " alters its shape");
}
}
if(show_trace)
{
std::cout << ss.str() << std::endl;
......@@ -133,7 +146,9 @@ std::vector<migraphx::argument> run_gpu(migraphx::program& p)
migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first)));
}
// Program should have an output parameter
EXPECT(bool{m.find("output") != m.end()});
EXPECT(std::any_of(
m.begin(), m.end(), [](auto& x) { return migraphx::contains(x.first, "output"); }));
// Ensure the program doesn't modify the context in a dry run
auto ctx = p.get_context();
assert(&ctx != &p.get_context());
......@@ -417,6 +432,21 @@ struct test_trans_tanh : verify_program<test_trans_tanh>
}
};
struct test_trans_tanh1 : verify_program<test_trans_tanh1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto tanhx = p.add_instruction(migraphx::op::tanh{}, tx);
auto r = p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
p.add_return({tx, r});
return p;
}
};
struct test_slice_sin : verify_program<test_slice_sin>
{
migraphx::program create_program() const
......@@ -1819,6 +1849,19 @@ struct test_transpose : verify_program<test_transpose>
}
};
struct test_trans_ret : verify_program<test_trans_ret>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
p.add_return({tx});
return p;
}
};
struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2>
{
const size_t width = 14;
......@@ -2439,6 +2482,48 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10>
}
};
struct test_rnn_two_outputs : verify_program<test_rnn_two_outputs>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
};
struct test_rnn_reverse : verify_program<test_rnn_reverse>
{
migraphx::program create_program() const
......@@ -2959,6 +3044,38 @@ struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_a
}
};
struct test_gru_two_outputs : verify_program<test_gru_two_outputs>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto hs = p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
};
struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_actv1>
{
migraphx::program create_program() const
......@@ -3509,6 +3626,79 @@ struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args>
}
};
struct test_lstm_two_outputs : verify_program<test_lstm_two_outputs>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
};
struct test_lstm_three_outputs : verify_program<test_lstm_three_outputs>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
auto last_cell = p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.add_return({hs, last_hs, last_cell});
return p;
}
};
struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1>
{
migraphx::program create_program() const
......
......@@ -18,7 +18,10 @@ migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode
// remove the trailing @return instruction if present
auto last_ins = std::prev(prog.end());
prog.remove_instruction(last_ins);
if(last_ins->name() == "@return")
{
prog.remove_instruction(last_ins);
}
return prog;
}
......@@ -851,6 +854,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f5args.onnx");
......@@ -883,6 +887,7 @@ TEST_CASE(lstm_forward)
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f6args.onnx");
......@@ -916,6 +921,7 @@ TEST_CASE(lstm_forward)
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f7args.onnx");
......@@ -1022,6 +1028,7 @@ TEST_CASE(lstm_forward_actv_func)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f2af.onnx");
......@@ -1102,6 +1109,7 @@ TEST_CASE(lstm_reverse)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_r5args.onnx");
......
......@@ -19,7 +19,7 @@ migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode
// remove the trailing @return instruction if present
auto last_ins = std::prev(prog.end());
if(last_ins->name() == "identity")
if(last_ins->name() == "@return")
{
prog.remove_instruction(last_ins);
}
......
......@@ -16,17 +16,21 @@
TEST_CASE(param_add)
{
auto create_program_float = [] {
auto create_program_float = [](bool add_return = false) {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s);
auto p2 = p.add_parameter("y", s);
p.add_instruction(migraphx::op::add{}, p1, p2);
auto p1 = p.add_parameter("x", s);
auto p2 = p.add_parameter("y", s);
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
if(add_return)
{
p.add_return({sum});
}
return p;
};
auto create_program_half = [] {
auto create_program_half = [](bool add_return = false) {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s);
......@@ -34,7 +38,11 @@ TEST_CASE(param_add)
auto p2 = p.add_parameter("y", s);
auto hp2 = p.insert_instruction(std::next(p2), migraphx::op::convert{}, p2);
auto hs = p.add_instruction(migraphx::op::add{}, hp1, hp2);
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hs);
auto res = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hs);
if(add_return)
{
p.add_return({res});
}
return p;
};
......@@ -54,6 +62,22 @@ TEST_CASE(param_add)
migraphx::quantize_fp16(p1, {"add"});
EXPECT(p1 == p2);
}
{
auto p1 = create_program_float(true);
auto p2 = create_program_half(true);
migraphx::quantize_fp16(p1);
EXPECT(p1 == p2);
}
{
auto p1 = create_program_float(true);
auto p2 = create_program_half(true);
migraphx::quantize_fp16(p1, {"add"});
EXPECT(p1 == p2);
}
}
TEST_CASE(param_add_sub)
......@@ -556,7 +580,7 @@ TEST_CASE(dot_int32_one_arg)
TEST_CASE(dot_int32)
{
auto create_program = [] {
auto create_program = [](bool add_return = false) {
migraphx::program p;
migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
......@@ -565,12 +589,16 @@ TEST_CASE(dot_int32)
auto pb = p.add_parameter("b", sb);
auto pc = p.add_parameter("c", sc);
p.add_instruction(migraphx::op::dot{2.0f, 5.5f}, pa, pb, pc);
auto res = p.add_instruction(migraphx::op::dot{2.0f, 5.5f}, pa, pb, pc);
if(add_return)
{
p.add_return({res});
}
return p;
};
auto create_int8_quantized_prog = [] {
auto create_int8_quantized_prog = [](bool add_return = false) {
migraphx::program p;
migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
......@@ -614,7 +642,11 @@ TEST_CASE(dot_int32)
auto beta = p.add_literal(migraphx::literal(fc->get_shape(), v_beta));
auto beta_c = p.add_instruction(migraphx::op::mul{}, beta, fc);
auto f_res = p.add_instruction(migraphx::op::add{}, alpha_ab, beta_c);
p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, f_res);
auto res = p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, f_res);
if(add_return)
{
p.add_return({res});
}
return p;
};
......@@ -624,8 +656,12 @@ TEST_CASE(dot_int32)
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
auto p_ret = create_program(true);
migraphx::quantize_int8_impl(p_ret, quant_params, {"dot"});
auto qp_ret = create_int8_quantized_prog(true);
EXPECT(p_ret == qp_ret);
}
TEST_CASE(dot_float_convert)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment