Unverified commit bc52a8a8, authored by Shucai Xiao and committed by GitHub

Inline subgraph (#802)



* Add definitions for all pointwise operators

* Formatting

* Add cpp generator class

* Formatting

* Move compilation to core

* Formatting

* Add clock to tmp name

* Add dynamic loader

* Formatting

* Add tests for code gen

* Formatting

* Add test for literals

* Formatting

* Use with_char

* Add missing header

* Fix mismerge

* Ignore tidy warning

* Fix gcc 5 errors

* Apply fixits

* Skip signed bitwise of status

* Remove unused parameters

* Explicitly add c++14 flag

* Fix tidy warning

* unify the compute function signature

* clang format

* make another change

* unify the compute function

* clang format

* remove unnecessary code

* more refinement of the operator compute function

* clang format

* add an overload function

* clang format

* add support for axes inputs for squeeze/unsqueeze/reduce_sum

* clang format

* fix build problems

* backup code changes

* clang format

* Add tuple type to shape class

* Formatting

* fix a bug in parsing quantizelinear operator

* clang format

* fix a cppcheck error

* disable different versions of unit tests for different onnx versions

* clang format

* upgrade onnx to 1.8

* update onnx to 1.8.1

* disable two more real models

* clang format

* Make data member private

* Formatting

* Add sub arguments

* Formatting

* Turn clang format off

* Disable clang-format

* fix review comments

* fix the axes-assignment function in parsing the squeeze operator

* add unit tests and fix a bug

* clang format

* fix review comments

* clang format

* fix a build error

* backup code changes

* clang format

* add more unit tests and add parsing opset version

* clang format

* Improve visiting tuples

* Formatting

* fix cppcheck error

* add installation of the onnx package

* resolve no protobuf compiler

* add an inline subgraph pass

* clang format

* Add more argument tests

* Formatting

* Handle tuple in load

* Formatting

* code backup

* clang format

* Remove .o files

* Add tuple type to api

* Formatting

* fix build errors

* clang format

* code backup

* code backup

* add unit tests for the inline subgraph

* clang format

* refine the inline subgraph and parse if operator

* clang format

* fix cppcheck issue

* clang format

* add unit test for inline subgraph pass

* clang format

* fix format issue

* remove the context from the if operator

* clang format

* simplify the compute functions

* Fix tidy warnings

* fix cppcheck error

* clang format

* fix cppcheck error

* Fix tidy warnings

* fix a cppcheck error

* clang format

* Add a test for share method

* Formatting

* Add a test cpp_type

* add unit tests for more code coverage

* clang format

* add unit tests to have more code coverage

* clang format

* try a comment in jenkins build

* include the install onnx line

* code backup

* reorder the installed dependencies

* refine dockerfile

* fix review comments

* clang format

* remove unnecessary overload function

* fix cppcheck error

* change back the argument test

* Suppress tidy warning

* add the operator get_tuple_elem

* clang format

* add get_tuple_elem to operator include file

* change if to support multiple operation outputs

* clang format

* optimize inline subgraph

* clang format

* code backup

* clang format

* fix bug

* refine unit tests for tuple output of the if operator

* clang format

* refine the instruction replacement code

* add a unit test and sort all the unit tests alphabetically

* fix cppcheck error

* add more unit tests for multiple op outputs

* clang format

* fix cppcheck error

* Update pass manager to get modules after every pass

* more unit test to cover more scenarios

* clang format

* fixed a bug in a unit test

* add more tests

* clang format

* add more unit tests to have more code coverage

* fix a bug in a unit test

* Add program overload for module

* Formatting

* Hash modules for quicker lookup of modules

* Bump file version

* Add methods to remove modules

* Formatting

* add the tuple type to the support list

* Eliminate unused modules

* Formatting

* Fix test errors

* Formatting

* Fix tidy issues

* fix problem related to inline subgraph

* clang format

* fix review comments

* fix review comments

* fix review comments

* fix review comments

* clang format

* fix a unit test

* one more code change

* remove an optimization related to the if operator

* clang format

* fix review comments
Co-authored-by: Paul <pfultz2@yahoo.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e00479af
@@ -30,6 +30,7 @@ add_library(migraphx
     rewrite_pooling.cpp
     env.cpp
     generate.cpp
+    inline_module.cpp
     instruction.cpp
     load_save.cpp
     make_op.cpp
@@ -100,6 +101,7 @@ register_migraphx_ops(
     flatten
     floor
     gather
+    get_tuple_elem
     greater
     gru
     identity
......
@@ -158,6 +158,13 @@ struct check_shapes
         return *this;
     }

+    const check_shapes& tuple_type() const
+    {
+        if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
+            MIGRAPHX_THROW(prefix() + "Shapes are not tuple!");
+        return *this;
+    }
+
     const check_shapes& not_transposed() const
     {
         if(!this->all_of([](const shape& s) { return not s.transposed(); }))
......
#ifndef MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP

#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

struct inline_module
{
    std::string name() const { return "inline_module"; }
    void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#define MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP

#include "migraphx/errors.hpp"
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct get_tuple_elem
{
    std::size_t index = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.index, "index"));
    }

    std::string name() const { return "get_tuple_elem"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).tuple_type();
        const auto& sub_shapes = inputs.at(0).sub_shapes();
        if(index >= sub_shapes.size())
        {
            MIGRAPHX_THROW("GET_TUPLE_ELEM: index " + std::to_string(index) +
                           " is out of range " + std::to_string(sub_shapes.size()));
        }

        return sub_shapes.at(index);
    }

    argument compute(const shape&, std::vector<argument> args) const
    {
        assert(args.size() == 1);
        auto vec_args = args.at(0).get_sub_objects();
        assert(index < vec_args.size());

        return vec_args.at(index);
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
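
With tuple support added to shape and check_shapes, the new operator can be exercised at the shape level alone. A minimal sketch, assuming only the public headers named in this diff (the tuple shape is constructed by hand here rather than produced by an if instruction):

#include <migraphx/shape.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>
#include <iostream>
#include <vector>

int main()
{
    using migraphx::shape;
    // A tuple shape with two sub-shapes, as if_op::compute_shape now produces.
    shape s0{shape::float_type, {1, 4}};
    shape s1{shape::float_type, {3, 4}};
    shape tup{std::vector<shape>{s0, s1}};

    // Selecting index 1 yields the second sub-shape: float_type, {3, 4}.
    auto op  = migraphx::make_op("get_tuple_elem", {{"index", 1}});
    auto out = op.compute_shape({tup});
    std::cout << out << std::endl;
}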
@@ -35,14 +35,14 @@ struct if_op
             MIGRAPHX_THROW("IF: output shapes of submodules must be the same.");
         }

-        return out_shapes0.front();
+        return shape(out_shapes0);
     }

-    argument compute(
-        const std::vector<argument>& args,
-        const std::vector<module_ref>& mods,
-        const std::function<std::vector<argument>(
-            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)>& run) const
+    argument compute(const shape&,
+                     const std::vector<argument>& args,
+                     const std::vector<module_ref>& mods,
+                     const std::function<std::vector<argument>(
+                         module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
     {
         auto cond      = args.front().at<bool>();
         module_ref mod = cond ? mods[0] : mods[1];
@@ -63,7 +63,7 @@ struct if_op
             [](auto&& name, auto&& arg) { return std::make_pair(name, arg); });

         auto results = run(mod, params);
-        return results[0];
+        return argument{results};
     }
 };
......
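
The key behavioral change above: if now packages all submodule outputs into one tuple-typed result instead of returning only the first output. A small runtime-side sketch of that wrapping and unwrapping, assuming the argument API shown in this diff (the shapes and data here are made up):

#include <migraphx/argument.hpp>
#include <migraphx/shape.hpp>
#include <cassert>
#include <vector>

int main()
{
    using migraphx::shape;
    // Two sub-results, as a submodule's @return would produce.
    shape s0{shape::float_type, {2}};
    shape s1{shape::float_type, {3}};
    std::vector<float> d0(2, 1.0f);
    std::vector<float> d1(3, 2.0f);
    std::vector<migraphx::argument> results = {migraphx::argument(s0, d0.data()),
                                               migraphx::argument(s1, d1.data())};

    // Mirrors `return argument{results};` in if_op::compute above.
    migraphx::argument tup{results};

    // get_tuple_elem::compute reads the elements back via get_sub_objects().
    auto parts = tup.get_sub_objects();
    assert(parts.size() == 2);
}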
@@ -178,7 +178,7 @@ shape normalize_compute_shape_op(const T& x,
 }

 template <class T>
-auto compute_op(rank<2>,
+auto compute_op(rank<1>,
                 const T& x,
                 context& ctx,
                 const shape& output_shape,
@@ -188,14 +188,6 @@ auto compute_op(rank<2>,
     return x.compute(auto_any_cast(ctx), output_shape, input);
 }

-template <class T>
-auto compute_op(
-    rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-    -> decltype(x.compute(output_shape, input))
-{
-    return x.compute(output_shape, input);
-}
-
 template <class T>
 argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
 {
@@ -207,50 +199,106 @@ template <class T>
 argument
 compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
 {
-    return compute_op(rank<2>{}, x, ctx, output_shape, input);
+    return compute_op(rank<1>{}, x, ctx, output_shape, input);
 }

 template <class T>
-auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
+auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
     -> decltype(x.compute(output_shape, input))
 {
     return x.compute(output_shape, input);
 }

 template <class T>
-auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-    -> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
-{
-    std::string name = x.name();
-    MIGRAPHX_THROW("Not computable without a context: " + name);
-}
-
-template <class T>
 argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
 {
     std::string name = x.name();
     MIGRAPHX_THROW("Not computable: " + name);
 }

 template <class T>
 argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
 {
-    return compute_op(rank<2>{}, x, output_shape, input);
+    return compute_op(rank<1>{}, x, output_shape, input);
 }

 template <class T, class F>
 auto compute_op(rank<1>,
                 const T& x,
+                const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>& module_args,
-                F f) -> decltype(x.compute(inputs, module_args, f))
+                F f) -> decltype(x.compute(output, inputs, module_args, f))
 {
-    return x.compute(inputs, module_args, f);
+    return x.compute(output, inputs, module_args, f);
 }

 template <class T, class F>
-argument
-compute_op(rank<0>, const T& x, const std::vector<argument>&, const std::vector<module_ref>&, F)
+argument compute_op(rank<0>,
+                    const T& x,
+                    const shape&,
+                    const std::vector<argument>&,
+                    const std::vector<module_ref>&,
+                    F)
 {
     std::string name = x.name();
     MIGRAPHX_THROW("Not computable: " + name);
 }
+
+template <class T, class F>
+argument compute_op(const T& x,
+                    const shape& output,
+                    const std::vector<argument>& inputs,
+                    const std::vector<module_ref>& module_args,
+                    F f)
+{
+    return compute_op(rank<1>{}, x, output, inputs, module_args, f);
+}
+
+template <class T, class F>
+auto compute_op(rank<3>,
+                const T& x,
+                context&,
+                const shape& output,
+                const std::vector<argument>& inputs,
+                const std::vector<module_ref>& module_args,
+                F f) -> decltype(x.compute(output, inputs, module_args, f))
+{
+    return x.compute(output, inputs, module_args, f);
+}
+
+template <class T, class F>
+auto compute_op(rank<2>,
+                const T& x,
+                context&,
+                const shape& output,
+                const std::vector<argument>& inputs,
+                const std::vector<module_ref>&,
+                F) -> decltype(x.compute(output, inputs))
+{
+    return x.compute(output, inputs);
+}
+
+template <class T, class F>
+auto compute_op(rank<1>,
+                const T& x,
+                context& ctx,
+                const shape& output,
+                const std::vector<argument>& inputs,
+                const std::vector<module_ref>&,
+                F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
+{
+    return x.compute(auto_any_cast(ctx), output, inputs);
+}
+
+template <class T, class F>
+argument compute_op(rank<0>,
+                    const T& x,
+                    context&,
+                    const shape&,
+                    const std::vector<argument>&,
+                    const std::vector<module_ref>&,
+                    F)
+{
     std::string name = x.name();
     MIGRAPHX_THROW("Not computable: " + name);
@@ -258,11 +306,13 @@ argument

 template <class T, class F>
 argument compute_op(const T& x,
+                    context& ctx,
+                    const shape& output,
                     const std::vector<argument>& inputs,
                     const std::vector<module_ref>& module_args,
                     F f)
 {
-    return compute_op(rank<1>{}, x, inputs, module_args, f);
+    return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
 }

 template <class T>
@@ -409,9 +459,12 @@ bool is_borrowed_op(const T&)
 * shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
 * mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
 * input) const; argument compute(const shape& output,const std::vector<argument>& input)
-* const; argument compute(const std::vector<argument>& input,const std::vector<module_ref>&
-* module_args,std::function<std::vector<argument>(module_ref& mdl, const
-* std::unordered_map<std::string, argument>& inputs)> run) const; value to_value() const; void
+* const; argument compute(const shape& output,const std::vector<argument>& input,const
+* std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const
+* std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const
+* shape& output,const std::vector<argument>& input,const std::vector<module_ref>&
+* module_args,std::function<std::vector<argument>(module_ref&, const
+* std::unordered_map<std::string, argument>&)> run) const; value to_value() const; void
 * from_value(const value& v) ; value attributes() const; friend std::ostream &
 * operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation &
 * x,const operation & y) ;
@@ -555,14 +608,27 @@ struct operation
         return (*this).private_detail_te_get_handle().compute(output, input);
     }

-    argument compute(
-        const std::vector<argument>& input,
-        const std::vector<module_ref>& module_args,
-        std::function<std::vector<argument>(
-            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run) const
+    argument compute(const shape& output,
+                     const std::vector<argument>& input,
+                     const std::vector<module_ref>& module_args,
+                     std::function<std::vector<argument>(
+                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const
     {
         assert((*this).private_detail_te_handle_mem_var);
-        return (*this).private_detail_te_get_handle().compute(input, module_args, std::move(run));
+        return (*this).private_detail_te_get_handle().compute(
+            output, input, module_args, std::move(run));
+    }
+
+    argument compute(context& ctx,
+                     const shape& output,
+                     const std::vector<argument>& input,
+                     const std::vector<module_ref>& module_args,
+                     std::function<std::vector<argument>(
+                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        return (*this).private_detail_te_get_handle().compute(
+            ctx, output, input, module_args, std::move(run));
     }

     value to_value() const
@@ -625,16 +691,23 @@ struct operation
     compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
     virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
     virtual argument
-    compute(const std::vector<argument>& input,
+    compute(const shape& output,
+            const std::vector<argument>& input,
+            const std::vector<module_ref>& module_args,
+            std::function<std::vector<argument>(
+                module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
+    virtual argument
+    compute(context& ctx,
+            const shape& output,
+            const std::vector<argument>& input,
             const std::vector<module_ref>& module_args,
             std::function<std::vector<argument>(
-                module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-        const = 0;
+                module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
     virtual value to_value() const = 0;
     virtual void from_value(const value& v) = 0;
     virtual value attributes() const = 0;
     virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
     virtual bool operator==(const operation& y) const = 0;
 };

 template <class T>
@@ -828,25 +901,58 @@ struct operation
     static auto private_detail_te_default_compute(
         char,
         T&& private_detail_te_self,
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
+        -> decltype(private_detail_te_self.compute(output, input, module_args, std::move(run)))
+    {
+        return private_detail_te_self.compute(output, input, module_args, std::move(run));
+    }
+
+    template <class T>
+    static argument private_detail_te_default_compute(
+        float,
+        T&& private_detail_te_self,
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
+    {
+        return detail::compute_op(
+            private_detail_te_self, output, input, module_args, std::move(run));
+    }
+
+    template <class T>
+    static auto private_detail_te_default_compute(
+        char,
+        T&& private_detail_te_self,
+        context& ctx,
+        const shape& output,
         const std::vector<argument>& input,
         const std::vector<module_ref>& module_args,
-        std::function<std::vector<argument>(
-            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-        -> decltype(private_detail_te_self.compute(input, module_args, std::move(run)))
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
+        -> decltype(private_detail_te_self.compute(ctx, output, input, module_args, std::move(run)))
     {
-        return private_detail_te_self.compute(input, module_args, std::move(run));
+        return private_detail_te_self.compute(ctx, output, input, module_args, std::move(run));
     }

     template <class T>
     static argument private_detail_te_default_compute(
         float,
         T&& private_detail_te_self,
+        context& ctx,
+        const shape& output,
         const std::vector<argument>& input,
         const std::vector<module_ref>& module_args,
-        std::function<std::vector<argument>(
-            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
     {
-        return detail::compute_op(private_detail_te_self, input, module_args, std::move(run));
+        return detail::compute_op(
+            private_detail_te_self, ctx, output, input, module_args, std::move(run));
     }

     template <class T>
@@ -994,16 +1100,29 @@ struct operation
             char(0), private_detail_te_value, output, input);
     }

-    argument
-    compute(const std::vector<argument>& input,
-            const std::vector<module_ref>& module_args,
-            std::function<std::vector<argument>(
-                module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-        const override
+    argument compute(
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(
+            module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
+    {
+        return private_detail_te_default_compute(
+            char(0), private_detail_te_value, output, input, module_args, std::move(run));
+    }
+
+    argument compute(
+        context& ctx,
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(
+            module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
     {
         return private_detail_te_default_compute(
-            char(0), private_detail_te_value, input, module_args, std::move(run));
+            char(0), private_detail_te_value, ctx, output, input, module_args, std::move(run));
     }

     value to_value() const override
......
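
All of these compute_op overloads are selected through migraphx's rank-based tag dispatch: the top-level wrapper passes rank<3>{} (or rank<1>{}), and overload resolution walks down the rank hierarchy until the trailing decltype (SFINAE) check finds an x.compute(...) signature the operator actually provides, with rank<0> as the throwing fallback. A self-contained sketch of the idiom with simplified types (this illustrates only the dispatch pattern, not the migraphx API):

#include <cstddef>
#include <iostream>

// rank<N> inherits from rank<N - 1>, so a rank<1> argument also converts to
// rank<0>; overload resolution prefers the most-derived (highest) rank whose
// trailing decltype is well-formed.
template <std::size_t N>
struct rank : rank<N - 1>
{
};

template <>
struct rank<0>
{
};

struct with_ctx
{
    int compute(int ctx, int v) const { return ctx + v; }
};

struct without_ctx
{
    int compute(int v) const { return v * 2; }
};

// Preferred: the operator accepts a context.
template <class T>
auto compute_op(rank<1>, const T& x, int ctx, int v) -> decltype(x.compute(ctx, v))
{
    return x.compute(ctx, v);
}

// Fallback: the operator computes without a context.
template <class T>
auto compute_op(rank<0>, const T& x, int, int v) -> decltype(x.compute(v))
{
    return x.compute(v);
}

template <class T>
int compute_op(const T& x, int ctx, int v)
{
    return compute_op(rank<1>{}, x, ctx, v);
}

int main()
{
    std::cout << compute_op(with_ctx{}, 10, 1) << "\n";    // 11: rank<1> overload
    std::cout << compute_op(without_ctx{}, 10, 1) << "\n"; // 2: falls back to rank<0>
}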
@@ -35,6 +35,7 @@
 #include <migraphx/op/flatten.hpp>
 #include <migraphx/op/floor.hpp>
 #include <migraphx/op/gather.hpp>
+#include <migraphx/op/get_tuple_elem.hpp>
 #include <migraphx/op/greater.hpp>
 #include <migraphx/op/gru.hpp>
 #include <migraphx/op/identity.hpp>
......
#include <migraphx/inline_module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
    const auto& mod_inputs = ins->module_inputs();
    const auto* smod       = cond ? mod_inputs.at(0) : mod_inputs.at(1);

    std::unordered_map<instruction_ref, instruction_ref> map_ins;
    std::vector<instruction_ref> mod_outputs;
    for(auto sins : iterator_for(*smod))
    {
        instruction_ref copy_ins{};
        if(sins->name() == "@literal")
        {
            auto l   = sins->get_literal();
            copy_ins = m.add_literal(l);
        }
        else if(sins->name() == "@param")
        {
            auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
            auto s      = sins->get_shape();
            copy_ins    = m.add_parameter(name, s);
        }
        else if(sins->name() == "@outline")
        {
            auto s   = sins->get_shape();
            copy_ins = m.add_outline(s);
        }
        else
        {
            auto mod_args = sins->module_inputs();
            auto inputs   = sins->inputs();
            std::vector<instruction_ref> copy_inputs(inputs.size());
            std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
                return contains(map_ins, i) ? map_ins[i] : i;
            });

            if(sins->name() == "@return")
            {
                mod_outputs = copy_inputs;
                break;
            }

            copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
        }
        map_ins[sins] = copy_ins;
        mod_outputs   = {copy_ins};
    }

    auto ins_outputs = ins->outputs();
    assert(mod_outputs.size() >= ins_outputs.size());
    for(const auto& out : ins_outputs)
    {
        auto val = out->get_operator().to_value();
        assert(val.contains("index"));
        auto index = val.at("index").to<std::size_t>();
        m.replace_instruction(out, mod_outputs.at(index));
    }
}

void inline_module::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "if")
            continue;

        auto arg_cond = ins->inputs().front()->eval();
        if(not arg_cond.empty())
        {
            bool cond = arg_cond.at<bool>();
            inline_submodule(m, ins, cond);
        }
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
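
For orientation, here is a sketch of the pass end to end, modeled on the unit tests later in this diff (the module names and data are illustrative): when the condition of an if is a literal, inline_module splices the taken branch into the main module and rewires the get_tuple_elem outputs, leaving the unused submodules to later cleanup passes.

#include <migraphx/inline_module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape cond_s{migraphx::shape::bool_type, {1}};
    auto cond = mm->add_literal(migraphx::literal(cond_s, {1})); // constant condition

    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    auto x = mm->add_parameter("x", s);

    auto* then_mod = p.create_module("then");
    auto t         = then_mod->add_instruction(migraphx::make_op("add"), x, x);
    then_mod->add_return({t});

    auto* else_mod = p.create_module("else");
    auto e         = else_mod->add_instruction(migraphx::make_op("mul"), x, x);
    else_mod->add_return({e});

    auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
    auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
    mm->add_return({r});

    // The condition evaluates to a constant, so the pass inlines the then
    // branch and replaces r with the inlined add instruction.
    migraphx::inline_module{}.apply(*mm);
}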
@@ -150,14 +150,7 @@ void module::assign(const module& m)
         }
         else
         {
-            if(module_args.empty())
-            {
-                copy_ins = add_instruction(ins->get_operator(), copy_inputs);
-            }
-            else
-            {
-                copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
-            }
+            copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
         }
     }
......
+#include <migraphx/instruction_ref.hpp>
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/onnx/onnx_parser.hpp>
 #include <migraphx/onnx/checks.hpp>
@@ -26,59 +27,41 @@ struct parse_if : op_parser<parse_if>
             MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
         }

-        migraphx::argument cond_arg = args.front()->eval();
-        // cond is not constant, need to create sub_modules
-        if(cond_arg.empty())
-        {
-            std::string then_name = info.name + "_if";
-            module_ref then_mdl   = parser.prog.create_module(then_name);
-
-            std::string else_name = info.name + "_else";
-            module_ref else_mdl   = parser.prog.create_module(else_name);
-
-            // parse the then sub_graph
-            parser.parse_graph(then_mdl, then_graph);
-
-            // parse the else sub_graph
-            parser.parse_graph(else_mdl, else_graph);
-
-            auto then_out_shapes = then_mdl->get_output_shapes();
-            auto else_out_shapes = else_mdl->get_output_shapes();
-            if(not std::equal(then_out_shapes.begin(),
-                              then_out_shapes.end(),
-                              else_out_shapes.begin(),
-                              else_out_shapes.end()))
-            {
-                MIGRAPHX_THROW("PARSE_IF: then and else subgraphs must have same output shapes!");
-            }
-
-            auto ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
-
-            return {ret};
-        }
-        else
-        {
-            auto* mod = info.mod;
-            // then branch
-            if(cond_arg.at<bool>())
-            {
-                parser.parse_graph(mod, then_graph);
-            }
-            // else branch
-            else
-            {
-                parser.parse_graph(mod, else_graph);
-            }
-
-            // inputs of the return instruction are that of the output of the
-            // if instruction
-            instruction_ref ret_ins = std::prev(mod->end());
-            auto outputs            = ret_ins->inputs();
-            assert(ret_ins->name() == "@return");
-            mod->remove_instruction(ret_ins);
-
-            return outputs;
-        }
+        std::string then_name = info.name + "_if";
+        module_ref then_mdl   = parser.prog.create_module(then_name);
+
+        std::string else_name = info.name + "_else";
+        module_ref else_mdl   = parser.prog.create_module(else_name);
+
+        // parse the then sub_graph
+        parser.parse_graph(then_mdl, then_graph);
+
+        // parse the else sub_graph
+        parser.parse_graph(else_mdl, else_graph);
+
+        auto then_out_shapes = then_mdl->get_output_shapes();
+        auto else_out_shapes = else_mdl->get_output_shapes();
+        if(not std::equal(then_out_shapes.begin(),
+                          then_out_shapes.end(),
+                          else_out_shapes.begin(),
+                          else_out_shapes.end()))
+        {
+            MIGRAPHX_THROW("PARSE_IF: then and else subgraphs must have same output shapes!");
+        }

+        auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
+        auto out_s  = if_ret->get_shape();
+        assert(out_s.type() == shape::tuple_type);
+
+        const auto& vec_shapes = out_s.sub_shapes();
+        std::vector<instruction_ref> out_inss;
+        for(std::size_t i = 0; i < vec_shapes.size(); ++i)
+        {
+            auto ret = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), if_ret);
+            out_inss.push_back(ret);
+        }
+
+        return out_inss;
     }
 };
......
@@ -243,20 +243,10 @@ std::vector<argument> generic_eval(const module* mod,
                     return generic_eval(smod, ctx, inputs, results, trace);
                 };

-                if(not mod_args.empty())
-                {
-                    results.emplace(ins, trace(ins, [&] {
-                                        return ins->normalized_operator().compute(
-                                            values, mod_args, module_eval);
-                                    }));
-                }
-                else
-                {
-                    results.emplace(ins, trace(ins, [&] {
-                                        return ins->normalized_operator().compute(
-                                            ctx, ins->get_shape(), values);
-                                    }));
-                }
+                results.emplace(ins, trace(ins, [&] {
+                                    return ins->normalized_operator().compute(
+                                        ctx, ins->get_shape(), values, mod_args, module_eval);
+                                }));
             }
             assert(results.find(ins) != results.end());
         }
......
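
In effect, generic_eval no longer branches on whether an instruction carries module arguments: every instruction goes through the single compute(ctx, output_shape, values, mod_args, module_eval) entry point, and the rank-based dispatch in operation.hpp routes operators that ignore the context or the module arguments to the matching narrower overload.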
@@ -10,6 +10,7 @@
 #include <migraphx/eliminate_data_type.hpp>
 #include <migraphx/eliminate_identity.hpp>
 #include <migraphx/eliminate_pad.hpp>
+#include <migraphx/inline_module.hpp>
 #include <migraphx/insert_pad.hpp>
 #include <migraphx/memory_coloring.hpp>
 #include <migraphx/normalize_ops.hpp>
@@ -51,6 +52,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
     unsupported_types.erase(shape::type_t::bool_type);
     unsupported_types.erase(shape::type_t::int8_type);
     unsupported_types.erase(shape::type_t::uint8_type);
+    unsupported_types.erase(shape::type_t::tuple_type);
     // clang-format off
     return
     {
@@ -68,6 +70,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         rewrite_rnn{},
         dead_code_elimination{},
+        inline_module{},
         rewrite_pooling{},
         dead_code_elimination{},
         eliminate_common_subexpression{},
......
(collapsed diff not shown)
@@ -24,33 +24,102 @@ migraphx::program create_program()
     return p;
 }

-TEST_CASE(module_ins_clear)
+TEST_CASE(calc_implict_deps)
 {
-    migraphx::program p1 = create_program();
-    migraphx::program p2;
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape cond_s{migraphx::shape::bool_type};
+    migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
+    migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
+    std::vector<float> datax = {1, 2, 3, 4, 5, 6};
+    std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};

-    p2 = p1;
+    auto lx   = mm->add_literal(migraphx::literal(xs, datax));
+    auto ly   = mm->add_literal(migraphx::literal(ys, datay));
+    auto cond = mm->add_parameter("cond", cond_s);
+    auto x1   = mm->add_parameter("x1", xs);
+    auto x2   = mm->add_parameter("x2", xs);
+    auto y2   = mm->add_parameter("y2", ys);

-    EXPECT(p1 == p2);
+    auto* then_mod = p.create_module("If_5_if");
+    auto l1        = then_mod->add_literal(migraphx::literal(ys, datay));
+    auto a1        = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
+    then_mod->add_return({a1, l1});
+
+    auto* then_mod1 = p.create_module("If_6_if");
+    auto l11        = then_mod1->add_literal(migraphx::literal(ys, datay));
+    auto a11        = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
+    then_mod1->add_return({a11, l11});
+
+    auto* else_mod1 = p.create_module("If_6_else");
+    auto l21        = else_mod1->add_literal(migraphx::literal(xs, datax));
+    auto a21        = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
+    else_mod1->add_return({l21, a21});
+
+    auto* else_mod = p.create_module("If_5_else");
+    auto l2        = else_mod->add_literal(migraphx::literal(ys, datay));
+    auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
+    auto a3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), a2);
+    else_mod->add_return({a3, l2});
+
+    auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
+    auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
+    mm->add_return({r});
+
+    auto implicit_deps = mm->calc_implicit_deps();
+    EXPECT(migraphx::contains(implicit_deps, ret));
+    EXPECT(migraphx::contains(implicit_deps.at(ret), x1));
+    EXPECT(migraphx::contains(implicit_deps.at(ret), x2));
+    EXPECT(migraphx::contains(implicit_deps.at(ret), y2));
 }

-TEST_CASE(module_print_graph)
+TEST_CASE(module_annotate)
 {
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();

     auto* mm1 = p1.get_main_module();
     auto* mm2 = p2.get_main_module();
+    EXPECT(*mm1 == *mm2);

     std::stringstream ss1;
-    mm1->print_graph(ss1, true);
+    mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });

     std::stringstream ss2;
-    mm2->print_graph(ss2, true);
+    mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });

     EXPECT(ss1.str() == ss2.str());
 }

+TEST_CASE(module_ins_clear)
+{
+    migraphx::program p1 = create_program();
+    migraphx::program p2;
+
+    p2 = p1;
+
+    EXPECT(p1 == p2);
+}
+
+TEST_CASE(module_name)
+{
+    migraphx::module m1("name");
+    EXPECT(m1.name() == "name");
+
+    auto m2 = m1; // NOLINT
+    EXPECT(m2.name() == "name");
+
+    migraphx::module m3;
+    m3 = m1;
+    EXPECT(m3.name() == "name");
+}
+
+TEST_CASE(module_name_main)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    EXPECT(mm->name() == "main");
+}
+
 TEST_CASE(module_print_cpp)
 {
     migraphx::program p1 = create_program();
@@ -68,43 +137,23 @@ TEST_CASE(module_print_cpp)
     EXPECT(ss1.str() == ss2.str());
 }

-TEST_CASE(module_annotate)
+TEST_CASE(module_print_graph)
 {
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();

     auto* mm1 = p1.get_main_module();
     auto* mm2 = p2.get_main_module();
-    EXPECT(*mm1 == *mm2);

     std::stringstream ss1;
-    mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
+    mm1->print_graph(ss1, true);

     std::stringstream ss2;
-    mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
+    mm2->print_graph(ss2, true);

     EXPECT(ss1.str() == ss2.str());
 }

-TEST_CASE(module_name)
-{
-    migraphx::module m1("name");
-    EXPECT(m1.name() == "name");
-
-    auto m2 = m1; // NOLINT
-    EXPECT(m2.name() == "name");
-
-    migraphx::module m3;
-    m3 = m1;
-    EXPECT(m3.name() == "name");
-}
-
-TEST_CASE(module_name_main)
-{
-    migraphx::program p;
-    auto* mm = p.get_main_module();
-    EXPECT(mm->name() == "main");
-}
-
 TEST_CASE(program_module_assign)
 {
     migraphx::program p;
@@ -204,51 +253,4 @@ TEST_CASE(submodule_copy)
     EXPECT(mm.get_sub_modules() == mm2.get_sub_modules());
 }

-TEST_CASE(calc_implict_deps)
-{
-    migraphx::program p;
-    auto* mm = p.get_main_module();
-    migraphx::shape cond_s{migraphx::shape::bool_type};
-    migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
-    migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
-    std::vector<float> datax = {1, 2, 3, 4, 5, 6};
-    std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
-
-    auto lx   = mm->add_literal(migraphx::literal(xs, datax));
-    auto ly   = mm->add_literal(migraphx::literal(ys, datay));
-    auto cond = mm->add_parameter("cond", cond_s);
-    auto x1   = mm->add_parameter("x1", xs);
-    auto x2   = mm->add_parameter("x2", xs);
-    auto y2   = mm->add_parameter("y2", ys);
-
-    auto* then_mod = p.create_module("If_5_if");
-    auto l1        = then_mod->add_literal(migraphx::literal(ys, datay));
-    auto a1        = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
-    then_mod->add_return({a1, l1});
-
-    auto* then_mod1 = p.create_module("If_6_if");
-    auto l11        = then_mod1->add_literal(migraphx::literal(ys, datay));
-    auto a11        = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
-    then_mod1->add_return({a11, l11});
-
-    auto* else_mod1 = p.create_module("If_6_else");
-    auto l21        = else_mod1->add_literal(migraphx::literal(xs, datax));
-    auto a21        = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
-    else_mod1->add_return({l21, a21});
-
-    auto* else_mod = p.create_module("If_5_else");
-    auto l2        = else_mod->add_literal(migraphx::literal(ys, datay));
-    auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
-    else_mod->add_return({a2, l2});
-
-    auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
-    mm->add_return({ret});
-
-    auto implicit_deps = mm->calc_implicit_deps();
-    EXPECT(migraphx::contains(implicit_deps, ret));
-    EXPECT(migraphx::contains(implicit_deps.at(ret), x1));
-    EXPECT(migraphx::contains(implicit_deps.at(ret), x2));
-    EXPECT(migraphx::contains(implicit_deps.at(ret), y2));
-}
-
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -1900,6 +1900,77 @@ def if_then_test():
     return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])


+@onnx_test
+def if_tuple_test():
+    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [1, 4])
+    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [3, 4])
+    cond_input = onnx.helper.make_tensor_value_info('cond',
+                                                    onnx.TensorProto.BOOL, [])
+    then_out0 = onnx.helper.make_tensor_value_info('then_out0',
+                                                   onnx.TensorProto.FLOAT,
+                                                   [1, 4])
+    then_out1 = onnx.helper.make_tensor_value_info('then_out1',
+                                                   onnx.TensorProto.FLOAT,
+                                                   [3, 4])
+    else_out0 = onnx.helper.make_tensor_value_info('else_out0',
+                                                   onnx.TensorProto.FLOAT,
+                                                   [1, 4])
+    else_out1 = onnx.helper.make_tensor_value_info('else_out1',
+                                                   onnx.TensorProto.FLOAT,
+                                                   [3, 4])
+
+    one = np.ones([1]).astype(np.float)
+    one_tensor = helper.make_tensor(name='one',
+                                    data_type=TensorProto.FLOAT,
+                                    dims=one.shape,
+                                    vals=one.flatten().astype(np.float32))
+
+    two = np.array([2]).astype(np.float)
+    two_tensor = helper.make_tensor(name='two',
+                                    data_type=TensorProto.FLOAT,
+                                    dims=two.shape,
+                                    vals=two.flatten().astype(np.float32))
+
+    three = np.array([3]).astype(np.float)
+    three_tensor = helper.make_tensor(name='three',
+                                      data_type=TensorProto.FLOAT,
+                                      dims=three.shape,
+                                      vals=three.flatten().astype(np.float32))
+
+    then_add_node = onnx.helper.make_node('Add',
+                                          inputs=['x', 'one'],
+                                          outputs=['then_out0'])
+    then_mul_node = onnx.helper.make_node('Mul',
+                                          inputs=['y', 'two'],
+                                          outputs=['then_out1'])
+
+    else_mul_node = onnx.helper.make_node('Mul',
+                                          inputs=['x', 'three'],
+                                          outputs=['else_out0'])
+    else_add_node = onnx.helper.make_node('Add',
+                                          inputs=['y', 'three'],
+                                          outputs=['else_out1'])
+
+    then_body = onnx.helper.make_graph([then_add_node, then_mul_node],
+                                       'then_body', [], [then_out0, then_out1])
+    else_body = onnx.helper.make_graph([else_mul_node, else_add_node],
+                                       'else_body', [], [else_out0, else_out1])
+
+    res0 = onnx.helper.make_tensor_value_info('res0', TensorProto.FLOAT, [])
+    res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
+    node = onnx.helper.make_node('If',
+                                 inputs=['cond'],
+                                 outputs=['res0', 'res1'],
+                                 then_branch=then_body,
+                                 else_branch=else_body)
+
+    return ([node], [cond_input, x,
+                     y], [res0, res1], [one_tensor, two_tensor, three_tensor])
+
+
 @onnx_test
 def imagescaler_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16])
......
@@ -1396,17 +1396,25 @@ TEST_CASE(if_else_test)
     migraphx::program p;
     auto* mm = p.get_main_module();
     migraphx::shape sc{migraphx::shape::bool_type, {1}};
-    mm->add_literal(migraphx::literal(sc, {0}));
+    auto cond = mm->add_literal(migraphx::literal(sc, {0}));

     migraphx::shape s{migraphx::shape::float_type, {2, 3}};
     std::vector<float> ones(s.elements(), 1.0f);
-    mm->add_literal(s, ones);
+    auto l1 = mm->add_literal(s, ones);
     std::vector<float> rand = {-0.583375, 0.633757, 0.0668345, -0.479422, -0.604634, 0.0388589};
     auto l2 = mm->add_literal(s, rand);
+    auto x  = mm->add_parameter("x", s);
+    auto y  = mm->add_parameter("y", s);
+
+    auto* then_mod = p.create_module("If_5_if");
+    auto rt        = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
+    then_mod->add_return({rt});

-    mm->add_parameter("x", s);
-    auto y = mm->add_parameter("y", s);
+    auto* else_mod = p.create_module("If_5_else");
+    auto re        = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
+    else_mod->add_return({re});

-    auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2);
+    auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
+    auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
     mm->add_return({r});

     std::ifstream ifs("if_else_test.onnx", std::ios::binary);
@@ -1418,7 +1426,6 @@ TEST_CASE(if_else_test)
     ifs.close();

     auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});
-
     EXPECT(p == prog);
 }

@@ -1444,7 +1451,8 @@ TEST_CASE(if_literal_test)
     else_mod->add_return({l2});

     auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
-    mm->add_return({ret});
+    auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
+    mm->add_return({r});

     auto prog = migraphx::parse_onnx("if_literal_test.onnx");
     EXPECT(p == prog);
@@ -1483,7 +1491,8 @@ TEST_CASE(if_param_test)
     else_mod->add_return({a2});

     auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
-    mm->add_return({ret});
+    auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
+    mm->add_return({r});

     auto prog = migraphx::parse_onnx("if_param_test.onnx");
     EXPECT(p == prog);
@@ -1516,7 +1525,9 @@ TEST_CASE(if_pl_test)
     else_mod->add_return({l2, a2});

     auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
-    mm->add_return({ret});
+    auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
+    mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
+    mm->add_return({r});

     auto prog = migraphx::parse_onnx("if_pl_test.onnx");
     EXPECT(p == prog);
@@ -1527,21 +1538,70 @@ TEST_CASE(if_then_test)
     migraphx::program p;
     auto* mm = p.get_main_module();
     migraphx::shape sc{migraphx::shape::bool_type, {1}};
-    mm->add_literal(migraphx::literal(sc, {1}));
+    auto cond = mm->add_literal(migraphx::literal(sc, {1}));

     migraphx::shape s{migraphx::shape::float_type, {2, 3}};
     std::vector<float> ones(s.elements(), 1.0f);
     auto l1 = mm->add_literal(s, ones);
     std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
-    mm->add_literal(s, rand);
+    auto l2 = mm->add_literal(s, rand);
+    auto x  = mm->add_parameter("x", s);
+    auto y  = mm->add_parameter("y", s);

-    auto x = mm->add_parameter("x", s);
-    mm->add_parameter("y", s);
+    auto* then_mod = p.create_module("If_5_if");
+    auto rt        = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
+    then_mod->add_return({rt});
+
+    auto* else_mod = p.create_module("If_5_else");
+    auto re        = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
+    else_mod->add_return({re});

-    auto r = mm->add_instruction(migraphx::make_op("add"), x, l1);
+    auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
+    auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
     mm->add_return({r});

     auto prog = migraphx::parse_onnx("if_then_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(if_tuple_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape sd{migraphx::shape::float_type, {1}};
+    auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
+    auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
+    auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
+
+    migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
+    migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
+    migraphx::shape sc{migraphx::shape::bool_type};
+    auto cond = mm->add_parameter("cond", sc);
+    auto x    = mm->add_parameter("x", sx);
+    auto y    = mm->add_parameter("y", sy);
+
+    auto* then_mod = p.create_module("If_6_if");
+    auto m1        = then_mod->add_instruction(
+        migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
+    auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
+    auto m2   = then_mod->add_instruction(
+        migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
+    auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
+    then_mod->add_return({add0, mul0});
+
+    auto* else_mod = p.create_module("If_6_else");
+    auto me1       = else_mod->add_instruction(
+        migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
+    auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
+    auto me2  = else_mod->add_instruction(
+        migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
+    auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
+    else_mod->add_return({mul1, add1});
+
+    auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
+    auto r0  = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
+    auto r1  = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
+    mm->add_return({r0, r1});
+
+    auto prog = migraphx::parse_onnx("if_tuple_test.onnx");
     EXPECT(p == prog);
 }
......
@@ -76,6 +76,7 @@ TEST_CASE(if_else_test)
     std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
     migraphx::parameter_map pp;
+    pp["x"] = migraphx::argument(s_data, data.data());
     pp["y"] = migraphx::argument(s_data, data.data());

     auto result = p.eval(pp).back();
@@ -160,6 +161,55 @@ TEST_CASE(if_pl_test)
     }
 }

+TEST_CASE(if_tuple_test)
+{
+    auto run_prog = [](bool cond) {
+        migraphx::program p = migraphx::parse_onnx("if_tuple_test.onnx");
+        p.compile(migraphx::ref::target{});
+        migraphx::shape xs{migraphx::shape::float_type, {1, 4}};
+        migraphx::shape ys{migraphx::shape::float_type, {3, 4}};
+        migraphx::shape cond_s{migraphx::shape::bool_type};
+        std::vector<float> x_data(xs.elements(), 1.0f);
+        std::vector<float> y_data(ys.elements(), 2.0f);
+        std::vector<char> cond_data{static_cast<char>(cond)};
+
+        migraphx::parameter_map pp;
+        pp["x"]    = migraphx::argument(xs, x_data.data());
+        pp["y"]    = migraphx::argument(ys, y_data.data());
+        pp["cond"] = migraphx::argument(cond_s, cond_data.data());
+
+        auto results = p.eval(pp);
+        std::vector<std::vector<float>> rets;
+        for(const auto& arg : results)
+        {
+            std::vector<float> vec;
+            arg.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
+            rets.push_back(vec);
+        }
+
+        return rets;
+    };
+
+    // then branch
+    {
+        auto results = run_prog(true);
+        std::vector<float> gold0(4, 2.0f);
+        std::vector<float> gold1(12, 4.0f);
+        EXPECT(migraphx::verify_range(results.at(0), gold0));
+        EXPECT(migraphx::verify_range(results.at(1), gold1));
+    }
+
+    // else branch
+    {
+        auto results = run_prog(false);
+        std::vector<float> gold0(4, 3.0f);
+        std::vector<float> gold1(12, 5.0f);
+        EXPECT(migraphx::verify_range(results.at(0), gold0));
+        EXPECT(migraphx::verify_range(results.at(1), gold1));
+    }
+}
+
 TEST_CASE(instance_norm_test)
 {
     migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx");
......
(collapsed diff not shown)
@@ -3,6 +3,7 @@
 #include <migraphx/make_op.hpp>
 #include <migraphx/op/convolution.hpp>
 #include <migraphx/op/rnn_variable_seq_lens.hpp>
+#include <migraphx/module.hpp>
 #include <sstream>
 #include <string>
 #include <migraphx/make_op.hpp>
@@ -115,4 +116,36 @@ TEST_CASE(ops)
     EXPECT(names.size() > 1);
 }

+TEST_CASE(rnn)
+{
+    migraphx::shape s{migraphx::shape::float_type, {2, 1}};
+    std::vector<float> data1(2, 2.0f);
+    std::vector<float> data2(2, 3.0f);
+    migraphx::argument a1(s, data1.data());
+    migraphx::argument a2(s, data2.data());
+
+    auto op = migraphx::make_op("rnn");
+    EXPECT(test::throws([&] { op.compute(s, {a1, a2}); }));
+}
+
+TEST_CASE(if_op)
+{
+    migraphx::shape s{migraphx::shape::bool_type, {1}};
+    std::vector<char> data = {1};
+    migraphx::argument cond(s, data.data());
+
+    migraphx::shape sd{migraphx::shape::float_type, {2, 1}};
+    std::vector<float> data1(2, 2.0f);
+    std::vector<float> data2(2, 3.0f);
+    migraphx::argument a1(sd, data1.data());
+    migraphx::argument a2(sd, data2.data());
+
+    migraphx::module m("name");
+    auto l = m.add_literal(migraphx::literal(sd, data1));
+    m.add_return({l});
+
+    auto op = migraphx::make_op("add");
+    EXPECT(test::throws([&] { op.compute(s, {cond, a1, a2}, {&m, &m}, {}); }));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }