"test/vscode:/vscode.git/clone" did not exist on "60468da4e2d7bda65ee3ad04857d7e29db9396af"
Unverified commit bc52a8a8, authored by Shucai Xiao, committed by GitHub

Inline subgraph (#802)



* Add definitions for all pointwise operators

* Formatting

* Add cpp generator class

* Formatting

* Move compilation to core

* Formatting

* Add clock to tmp name

* Add dynamic loader

* Formatting

* Add tests for code gen

* Formatting

* Add test for literals

* Formatting

* Use with_char

* Add missing header

* Fix mismerge

* Ignore tidy warning

* Fix gcc 5 errors

* Apply fixits

* Skip signed bitwise of status

* Remove unused parameters

* Explicitly add c++14 flag

* Fix tidy warning

* unify the compute function signature

* clang format

* make another change

* unify the compute function

* clang format

* remove unnecessary code

* more refinement of the operator compute function

* clang format

* add an overload function

* clang format

* add support for axes inputs for squeeze/unsqueeze/reduce_sum

* clang format

* fix build problems

* backup code changes

* clang format

* Add tuple type to shape class

* Formatting

* fix a bug in parsing quantizelinear operator

* clang format

* fix a cppcheck error

* disable different versions of unit tests for different onnx versions

* clang format

* upgrade onnx to 1.8

* update onnx to 1.8.1

* disable two more real models

* clang format

* Make data member private

* Formatting

* Add sub arguments

* Formatting

* Turn clang format off

* Disable clang-format

* fix review comments

* fix the assignment of axes when parsing the squeeze operator

* add unit tests and fix a bug

* clang format

* fix review comments

* clang format

* fix a build error

* backup code changes

* clang format

* add more unit tests and add parsing opset version

* clang format

* Improve visiting tuples

* Formatting

* fix cppcheck error

* add installation of the onnx package

* resolve missing protobuf compiler

* add an inline subgraph pass

* clang format

* Add more argument tests

* Formatting

* Handle tuple in load

* Formatting

* code backup

* clang format

* Remove .o files

* Add tuple type to api

* Formatting

* fix build errors

* clang format

* code backup

* code backup

* add unit tests for the inline subgraph

* clang format

* refine the inline subgraph and parse if operator

* clang format

* fix cppcheck issue

* clang format

* add unit test for inline subgraph pass

* clang format

* fix format issue

* remove the context from the if operator

* clang format

* simplify the compute functions

* Fix tidy warnings

* fix cppcheck error

* clang format

* fix cppcheck error

* Fix tidy warnings

* fix a cppcheck error

* clang format

* Add a test for share method

* Formatting

* Add a test for cpp_type

* add unit tests for more code coverage

* clang format

* add unit tests to have more code coverage

* clang format

* try a comment in jenkins build

* include the install onnx line

* code backup

* reorder the installed dependencies

* refine dockerfile

* fix review comments

* clang format

* remove unnecessary overload function

* fix cppcheck error

* change back the argument test

* Suppress tidy warning

* add the operator get_tuple_elem

* clang format

* add get_tuple_elem to operator include file

* change if to support multiple operation outputs

* clang format

* optimize inline subgraph

* clang format

* code backup

* clang format

* fix bug

* refine unit tests for tuple output of the if operator

* clang format

* refine the instruction replacement code

* add a unit test and sort all the unit tests alphabetically

* fix cppcheck error

* add more unit tests for multiple op outputs

* clang format

* fix cppcheck error

* Update pass manager to get modules after every pass

* more unit tests to cover more scenarios

* clang format

* fixed a bug in a unit test

* add more tests

* clang format

* add more unit tests to have more code coverage

* fix a bug in a unit test

* Add program overload for module

* Formatting

* Hash modules for quicker lookup of modules

* Bump file version

* Add methods to remove modules

* Formatting

* add the tuple type to the support list

* Eliminate unused modules

* Formatting

* Fix test errors

* Formatting

* Fix tidy issues

* fix problem related to inline subgraph

* clang format

* fix review comments

* fix review comments

* fix review comments

* fix review comments

* clang format

* fix a unit test

* one more code change

* remove an optimization related to the if operator

* clang format

* fix review comments
Co-authored-by: Paul <pfultz2@yahoo.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e00479af
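
Taken together, the diff below makes three coordinated changes: the `if` operator's compute_shape now returns a tuple shape of all submodule outputs instead of only the first one, a new `get_tuple_elem` operator extracts a single element from such a tuple, and a new `inline_module` pass replaces an `if` whose condition is a compile-time constant with the body of the taken branch. A minimal sketch of the resulting API, condensed from the unit tests in this diff (shapes, values, and module names here are illustrative, not part of the change itself):

    // Sketch only: build an "if" whose two branches each return one value,
    // then select output 0 from the tuple the "if" now produces.
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape cond_s{migraphx::shape::bool_type, {1}};
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    auto cond = mm->add_parameter("cond", cond_s);
    auto x    = mm->add_parameter("x", s);

    auto* then_mod = p.create_module("then_smod");
    auto l1 = then_mod->add_literal(migraphx::literal{s, std::vector<float>(s.elements(), 1)});
    then_mod->add_return({then_mod->add_instruction(migraphx::make_op("add"), x, l1)});

    auto* else_mod = p.create_module("else_smod");
    auto l2 = else_mod->add_literal(migraphx::literal{s, std::vector<float>(s.elements(), 2)});
    else_mod->add_return({else_mod->add_instruction(migraphx::make_op("mul"), x, l2)});

    auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
    auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
    mm->add_return({r});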
@@ -30,6 +30,7 @@ add_library(migraphx
     rewrite_pooling.cpp
     env.cpp
     generate.cpp
+    inline_module.cpp
     instruction.cpp
     load_save.cpp
     make_op.cpp
@@ -100,6 +101,7 @@ register_migraphx_ops(
     flatten
     floor
     gather
+    get_tuple_elem
     greater
     gru
     identity
......
@@ -158,6 +158,13 @@ struct check_shapes
         return *this;
     }

+    const check_shapes& tuple_type() const
+    {
+        if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
+            MIGRAPHX_THROW(prefix() + "Shapes are not tuple!");
+        return *this;
+    }
+
     const check_shapes& not_transposed() const
     {
         if(!this->all_of([](const shape& s) { return not s.transposed(); }))
......
#ifndef MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP

#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

struct inline_module
{
    std::string name() const { return "inline_module"; }
    void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP
#define MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP

#include "migraphx/errors.hpp"
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct get_tuple_elem
{
    std::size_t index = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.index, "index"));
    }

    std::string name() const { return "get_tuple_elem"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).tuple_type();
        const auto& sub_shapes = inputs.at(0).sub_shapes();
        if(index >= sub_shapes.size())
        {
            MIGRAPHX_THROW("GET_TUPLE_ELEM: index " + std::to_string(index) +
                           " is out of range " + std::to_string(sub_shapes.size()));
        }
        return sub_shapes.at(index);
    }

    argument compute(const shape&, std::vector<argument> args) const
    {
        assert(args.size() == 1);
        auto vec_args = args.at(0).get_sub_objects();
        assert(index < vec_args.size());
        return vec_args.at(index);
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
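
The semantics are simple: compute_shape requires exactly one tuple-typed input and returns the sub-shape at `index` (throwing when out of range), and compute forwards the matching sub-argument. A hedged sketch, assuming the tuple-shape constructor this PR adds (the `shape(out_shapes0)` call in the `if_op` hunk below):

    // Illustrative only, not part of the diff.
    migraphx::shape s0{migraphx::shape::float_type, {1, 4}};
    migraphx::shape s1{migraphx::shape::float_type, {3, 4}};
    migraphx::shape tup{{s0, s1}}; // tuple shape built from its sub-shapes
    auto op  = migraphx::make_op("get_tuple_elem", {{"index", 1}});
    auto out = op.compute_shape({tup}); // == s1; index 2 would throw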
@@ -35,14 +35,14 @@ struct if_op
             MIGRAPHX_THROW("IF: output shapes of submodules must be the same.");
         }
-        return out_shapes0.front();
+        return shape(out_shapes0);
     }

-    argument compute(
+    argument compute(const shape&,
                      const std::vector<argument>& args,
                      const std::vector<module_ref>& mods,
                      const std::function<std::vector<argument>(
-                          module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)>& run) const
+                          module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
     {
         auto cond      = args.front().at<bool>();
         module_ref mod = cond ? mods[0] : mods[1];
@@ -63,7 +63,7 @@ struct if_op
             [](auto&& name, auto&& arg) { return std::make_pair(name, arg); });

         auto results = run(mod, params);
-        return results[0];
+        return argument{results};
     }
 };
......
@@ -178,7 +178,7 @@ shape normalize_compute_shape_op(const T& x,
 }

 template <class T>
-auto compute_op(rank<2>,
+auto compute_op(rank<1>,
                 const T& x,
                 context& ctx,
                 const shape& output_shape,
@@ -188,14 +188,6 @@ auto compute_op(rank<2>,
     return x.compute(auto_any_cast(ctx), output_shape, input);
 }

-template <class T>
-auto compute_op(
-    rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-    -> decltype(x.compute(output_shape, input))
-{
-    return x.compute(output_shape, input);
-}
-
 template <class T>
 argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
 {
@@ -207,50 +199,106 @@ template <class T>
 argument
 compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
 {
-    return compute_op(rank<2>{}, x, ctx, output_shape, input);
+    return compute_op(rank<1>{}, x, ctx, output_shape, input);
 }

 template <class T>
-auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
+auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
     -> decltype(x.compute(output_shape, input))
 {
     return x.compute(output_shape, input);
 }

 template <class T>
-auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-    -> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
+argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
 {
     std::string name = x.name();
-    MIGRAPHX_THROW("Not computable without a context: " + name);
+    MIGRAPHX_THROW("Not computable: " + name);
 }

 template <class T>
-argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
+argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
+{
+    return compute_op(rank<1>{}, x, output_shape, input);
+}
+
+template <class T, class F>
+auto compute_op(rank<1>,
+                const T& x,
+                const shape& output,
+                const std::vector<argument>& inputs,
+                const std::vector<module_ref>& module_args,
+                F f) -> decltype(x.compute(output, inputs, module_args, f))
+{
+    return x.compute(output, inputs, module_args, f);
+}
+
+template <class T, class F>
+argument compute_op(rank<0>,
+                    const T& x,
+                    const shape&,
+                    const std::vector<argument>&,
+                    const std::vector<module_ref>&,
+                    F)
 {
     std::string name = x.name();
     MIGRAPHX_THROW("Not computable: " + name);
 }

-template <class T>
-argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
+template <class T, class F>
+argument compute_op(const T& x,
+                    const shape& output,
+                    const std::vector<argument>& inputs,
+                    const std::vector<module_ref>& module_args,
+                    F f)
 {
-    return compute_op(rank<2>{}, x, output_shape, input);
+    return compute_op(rank<1>{}, x, output, inputs, module_args, f);
 }

 template <class T, class F>
-auto compute_op(rank<1>,
+auto compute_op(rank<3>,
                 const T& x,
+                context&,
+                const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>& module_args,
-                F f) -> decltype(x.compute(inputs, module_args, f))
+                F f) -> decltype(x.compute(output, inputs, module_args, f))
 {
-    return x.compute(inputs, module_args, f);
+    return x.compute(output, inputs, module_args, f);
 }

 template <class T, class F>
-argument
-compute_op(rank<0>, const T& x, const std::vector<argument>&, const std::vector<module_ref>&, F)
+auto compute_op(rank<2>,
+                const T& x,
+                context&,
+                const shape& output,
+                const std::vector<argument>& inputs,
+                const std::vector<module_ref>&,
+                F) -> decltype(x.compute(output, inputs))
+{
+    return x.compute(output, inputs);
+}
+
+template <class T, class F>
+auto compute_op(rank<1>,
+                const T& x,
+                context& ctx,
+                const shape& output,
+                const std::vector<argument>& inputs,
+                const std::vector<module_ref>&,
+                F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
+{
+    return x.compute(auto_any_cast(ctx), output, inputs);
+}
+
+template <class T, class F>
+argument compute_op(rank<0>,
+                    const T& x,
+                    context&,
+                    const shape&,
+                    const std::vector<argument>&,
+                    const std::vector<module_ref>&,
+                    F)
 {
     std::string name = x.name();
     MIGRAPHX_THROW("Not computable: " + name);
 }
@@ -258,11 +306,13 @@ argument
 template <class T, class F>
 argument compute_op(const T& x,
+                    context& ctx,
+                    const shape& output,
                     const std::vector<argument>& inputs,
                     const std::vector<module_ref>& module_args,
                     F f)
 {
-    return compute_op(rank<1>{}, x, inputs, module_args, f);
+    return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
 }

 template <class T>
@@ -409,9 +459,12 @@ bool is_borrowed_op(const T&)
  * shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
  * mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
  * input) const; argument compute(const shape& output,const std::vector<argument>& input)
- * const; argument compute(const std::vector<argument>& input,const std::vector<module_ref>&
- * module_args,std::function<std::vector<argument>(module_ref& mdl, const
- * std::unordered_map<std::string, argument>& inputs)> run) const; value to_value() const; void
+ * const; argument compute(const shape& output,const std::vector<argument>& input,const
+ * std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const
+ * std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const
+ * shape& output,const std::vector<argument>& input,const std::vector<module_ref>&
+ * module_args,std::function<std::vector<argument>(module_ref&, const
+ * std::unordered_map<std::string, argument>&)> run) const; value to_value() const; void
  * from_value(const value& v) ; value attributes() const; friend std::ostream &
  * operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation &
  * x,const operation & y) ;
@@ -555,14 +608,27 @@ struct operation
         return (*this).private_detail_te_get_handle().compute(output, input);
     }

-    argument compute(
+    argument compute(const shape& output,
                      const std::vector<argument>& input,
                      const std::vector<module_ref>& module_args,
                      std::function<std::vector<argument>(
-                         module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run) const
+                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const
     {
         assert((*this).private_detail_te_handle_mem_var);
-        return (*this).private_detail_te_get_handle().compute(input, module_args, std::move(run));
+        return (*this).private_detail_te_get_handle().compute(
+            output, input, module_args, std::move(run));
+    }
+
+    argument compute(context& ctx,
+                     const shape& output,
+                     const std::vector<argument>& input,
+                     const std::vector<module_ref>& module_args,
+                     std::function<std::vector<argument>(
+                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        return (*this).private_detail_te_get_handle().compute(
+            ctx, output, input, module_args, std::move(run));
     }

     value to_value() const
@@ -625,16 +691,23 @@ struct operation
         compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
     virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
     virtual argument
-    compute(const std::vector<argument>& input,
+    compute(const shape& output,
+            const std::vector<argument>& input,
             const std::vector<module_ref>& module_args,
             std::function<std::vector<argument>(
-                module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-        const = 0;
+                module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
+    virtual argument
+    compute(context& ctx,
+            const shape& output,
+            const std::vector<argument>& input,
+            const std::vector<module_ref>& module_args,
+            std::function<std::vector<argument>(
+                module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
     virtual value to_value() const = 0;
     virtual void from_value(const value& v) = 0;
     virtual value attributes() const = 0;
     virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
     virtual bool operator==(const operation& y) const = 0;
 };
 template <class T>
@@ -828,25 +901,58 @@ struct operation
     static auto private_detail_te_default_compute(
         char,
         T&& private_detail_te_self,
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
+        -> decltype(private_detail_te_self.compute(output, input, module_args, std::move(run)))
+    {
+        return private_detail_te_self.compute(output, input, module_args, std::move(run));
+    }
+
+    template <class T>
+    static argument private_detail_te_default_compute(
+        float,
+        T&& private_detail_te_self,
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
+    {
+        return detail::compute_op(
+            private_detail_te_self, output, input, module_args, std::move(run));
+    }
+
+    template <class T>
+    static auto private_detail_te_default_compute(
+        char,
+        T&& private_detail_te_self,
+        context& ctx,
+        const shape& output,
         const std::vector<argument>& input,
         const std::vector<module_ref>& module_args,
-        std::function<std::vector<argument>(
-            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-        -> decltype(private_detail_te_self.compute(input, module_args, std::move(run)))
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
+        -> decltype(private_detail_te_self.compute(ctx, output, input, module_args, std::move(run)))
     {
-        return private_detail_te_self.compute(input, module_args, std::move(run));
+        return private_detail_te_self.compute(ctx, output, input, module_args, std::move(run));
     }

     template <class T>
     static argument private_detail_te_default_compute(
         float,
         T&& private_detail_te_self,
+        context& ctx,
+        const shape& output,
         const std::vector<argument>& input,
         const std::vector<module_ref>& module_args,
-        std::function<std::vector<argument>(
-            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
+        std::function<std::vector<argument>(module_ref&,
+                                            const std::unordered_map<std::string, argument>&)> run)
     {
-        return detail::compute_op(private_detail_te_self, input, module_args, std::move(run));
+        return detail::compute_op(
+            private_detail_te_self, ctx, output, input, module_args, std::move(run));
     }

     template <class T>
@@ -994,16 +1100,29 @@ struct operation
             char(0), private_detail_te_value, output, input);
     }

-    argument
-    compute(const std::vector<argument>& input,
-            const std::vector<module_ref>& module_args,
-            std::function<std::vector<argument>(
-                module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-        const override
+    argument compute(
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(
+            module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
+    {
+        return private_detail_te_default_compute(
+            char(0), private_detail_te_value, output, input, module_args, std::move(run));
+    }
+
+    argument compute(
+        context& ctx,
+        const shape& output,
+        const std::vector<argument>& input,
+        const std::vector<module_ref>& module_args,
+        std::function<std::vector<argument>(
+            module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
     {
         return private_detail_te_default_compute(
-            char(0), private_detail_te_value, input, module_args, std::move(run));
+            char(0), private_detail_te_value, ctx, output, input, module_args, std::move(run));
     }

     value to_value() const override
......
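For readers unfamiliar with the `rank<N>` arguments threaded through `compute_op` above: this is ordinary tag dispatch. `rank<N>` derives from `rank<N-1>`, so a call made with the highest rank prefers the overload whose trailing `decltype(...)` expression is well-formed for the concrete op, and decays toward the `rank<0>` fallback that throws "Not computable". A generic sketch of the idiom (illustrative, not MIGraphX code):

    #include <cstddef>
    #include <stdexcept>

    // rank<3> converts to rank<2>, which converts to rank<1>, and so on, so
    // overload resolution prefers the highest rank that is SFINAE-viable.
    template <std::size_t N>
    struct rank : rank<N - 1> {};
    template <>
    struct rank<0> {};

    // Preferred overload: participates only if x.compute() is a valid expression.
    template <class T>
    auto try_compute(rank<1>, T& x) -> decltype(x.compute()) { return x.compute(); }

    // Fallback chosen when the check above fails.
    template <class T>
    int try_compute(rank<0>, T&) { throw std::runtime_error("not computable"); }

    template <class T>
    auto try_compute(T& x) { return try_compute(rank<1>{}, x); }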
@@ -35,6 +35,7 @@
 #include <migraphx/op/flatten.hpp>
 #include <migraphx/op/floor.hpp>
 #include <migraphx/op/gather.hpp>
+#include <migraphx/op/get_tuple_elem.hpp>
 #include <migraphx/op/greater.hpp>
 #include <migraphx/op/gru.hpp>
 #include <migraphx/op/identity.hpp>
......
#include <migraphx/inline_module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
    const auto& mod_inputs = ins->module_inputs();
    const auto* smod       = cond ? mod_inputs.at(0) : mod_inputs.at(1);

    std::unordered_map<instruction_ref, instruction_ref> map_ins;
    std::vector<instruction_ref> mod_outputs;
    for(auto sins : iterator_for(*smod))
    {
        instruction_ref copy_ins{};
        if(sins->name() == "@literal")
        {
            auto l   = sins->get_literal();
            copy_ins = m.add_literal(l);
        }
        else if(sins->name() == "@param")
        {
            auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
            auto s      = sins->get_shape();
            copy_ins    = m.add_parameter(name, s);
        }
        else if(sins->name() == "@outline")
        {
            auto s   = sins->get_shape();
            copy_ins = m.add_outline(s);
        }
        else
        {
            auto mod_args = sins->module_inputs();
            auto inputs   = sins->inputs();
            std::vector<instruction_ref> copy_inputs(inputs.size());
            std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
                return contains(map_ins, i) ? map_ins[i] : i;
            });

            if(sins->name() == "@return")
            {
                mod_outputs = copy_inputs;
                break;
            }

            copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
        }
        map_ins[sins] = copy_ins;
        mod_outputs   = {copy_ins};
    }

    auto ins_outputs = ins->outputs();
    assert(mod_outputs.size() >= ins_outputs.size());
    for(const auto& out : ins_outputs)
    {
        auto val = out->get_operator().to_value();
        assert(val.contains("index"));
        auto index = val.at("index").to<std::size_t>();
        m.replace_instruction(out, mod_outputs.at(index));
    }
}

void inline_module::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "if")
            continue;

        auto arg_cond = ins->inputs().front()->eval();
        if(not arg_cond.empty())
        {
            bool cond = arg_cond.at<bool>();
            inline_submodule(m, ins, cond);
        }
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
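
The pass fires only when the `if` condition folds to a constant (`ins->inputs().front()->eval()` is non-empty); it then copies the taken submodule's instructions into the parent module and rewires each `get_tuple_elem` user of the `if` to the corresponding inlined output. Minimal usage, mirroring the unit tests added below (paired with dead_code_elimination so the dead branch and the `if` itself are removed):

    // Sketch: p is a migraphx::program containing a constant-condition "if".
    migraphx::run_passes(p, {migraphx::inline_module{}, migraphx::dead_code_elimination{}});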
@@ -150,14 +150,7 @@ void module::assign(const module& m)
         }
         else
         {
-            if(module_args.empty())
-            {
-                copy_ins = add_instruction(ins->get_operator(), copy_inputs);
-            }
-            else
-            {
-                copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
-            }
+            copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
         }
     }
......
+#include <migraphx/instruction_ref.hpp>
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/onnx/onnx_parser.hpp>
 #include <migraphx/onnx/checks.hpp>
@@ -26,59 +27,41 @@ struct parse_if : op_parser<parse_if>
             MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
         }

-        migraphx::argument cond_arg = args.front()->eval();
-        // cond is not constant, need to create sub_modules
-        if(cond_arg.empty())
-        {
-            std::string then_name = info.name + "_if";
-            module_ref then_mdl   = parser.prog.create_module(then_name);
-            std::string else_name = info.name + "_else";
-            module_ref else_mdl   = parser.prog.create_module(else_name);
+        std::string then_name = info.name + "_if";
+        module_ref then_mdl   = parser.prog.create_module(then_name);
+        std::string else_name = info.name + "_else";
+        module_ref else_mdl   = parser.prog.create_module(else_name);

-            // parse the then sub_graph
-            parser.parse_graph(then_mdl, then_graph);
-            // parse_the else sub_graph
-            parser.parse_graph(else_mdl, else_graph);
+        // parse the then sub_graph
+        parser.parse_graph(then_mdl, then_graph);
+        // parse_the else sub_graph
+        parser.parse_graph(else_mdl, else_graph);

-            auto then_out_shapes = then_mdl->get_output_shapes();
-            auto else_out_shapes = else_mdl->get_output_shapes();
-            if(not std::equal(then_out_shapes.begin(),
-                              then_out_shapes.end(),
-                              else_out_shapes.begin(),
-                              else_out_shapes.end()))
-            {
-                MIGRAPHX_THROW("PARSE_IF: then and else sub_grahps must have same output shapes!");
-            }
+        auto then_out_shapes = then_mdl->get_output_shapes();
+        auto else_out_shapes = else_mdl->get_output_shapes();
+        if(not std::equal(then_out_shapes.begin(),
+                          then_out_shapes.end(),
+                          else_out_shapes.begin(),
+                          else_out_shapes.end()))
+        {
+            MIGRAPHX_THROW("PARSE_IF: then and else sub_grahps must have same output shapes!");
+        }

-            auto ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
-            return {ret};
-        }
-        else
-        {
-            auto* mod = info.mod;
-            // then branch
-            if(cond_arg.at<bool>())
-            {
-                parser.parse_graph(mod, then_graph);
-            }
-            // else branch
-            else
-            {
-                parser.parse_graph(mod, else_graph);
-            }
-
-            // inputs of the return instruction are that of the output of the
-            // if instruction
-            instruction_ref ret_ins = std::prev(mod->end());
-            auto outputs            = ret_ins->inputs();
-            assert(ret_ins->name() == "@return");
-            mod->remove_instruction(ret_ins);
-
-            return outputs;
-        }
+        auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
+        auto out_s  = if_ret->get_shape();
+        assert(out_s.type() == shape::tuple_type);
+
+        const auto& vec_shapes = out_s.sub_shapes();
+        std::vector<instruction_ref> out_inss;
+        for(std::size_t i = 0; i < vec_shapes.size(); ++i)
+        {
+            auto ret = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), if_ret);
+            out_inss.push_back(ret);
+        }
+        return out_inss;
     }
 };
......
@@ -243,20 +243,10 @@ std::vector<argument> generic_eval(const module* mod,
                     return generic_eval(smod, ctx, inputs, results, trace);
                 };

-                if(not mod_args.empty())
-                {
-                    results.emplace(ins, trace(ins, [&] {
-                        return ins->normalized_operator().compute(
-                            values, mod_args, module_eval);
-                    }));
-                }
-                else
-                {
-                    results.emplace(ins, trace(ins, [&] {
-                        return ins->normalized_operator().compute(
-                            ctx, ins->get_shape(), values);
-                    }));
-                }
+                results.emplace(ins, trace(ins, [&] {
+                    return ins->normalized_operator().compute(
+                        ctx, ins->get_shape(), values, mod_args, module_eval);
+                }));
             }
             assert(results.find(ins) != results.end());
         }
......
@@ -10,6 +10,7 @@
 #include <migraphx/eliminate_data_type.hpp>
 #include <migraphx/eliminate_identity.hpp>
 #include <migraphx/eliminate_pad.hpp>
+#include <migraphx/inline_module.hpp>
 #include <migraphx/insert_pad.hpp>
 #include <migraphx/memory_coloring.hpp>
 #include <migraphx/normalize_ops.hpp>
@@ -51,6 +52,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
     unsupported_types.erase(shape::type_t::bool_type);
     unsupported_types.erase(shape::type_t::int8_type);
     unsupported_types.erase(shape::type_t::uint8_type);
+    unsupported_types.erase(shape::type_t::tuple_type);
     // clang-format off
     return
     {
@@ -68,6 +70,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         rewrite_rnn{},
         dead_code_elimination{},
+        inline_module{},
         rewrite_pooling{},
         dead_code_elimination{},
         eliminate_common_subexpression{},
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>

void run_pass(migraphx::program& p)
{
    migraphx::run_passes(p, {migraphx::inline_module{}, migraphx::dead_code_elimination{}});
}

TEST_CASE(cannot_inline_both)
{
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sd{migraphx::shape::float_type, {2, 3}};
        auto x = mm->add_parameter("x", sd);
        std::vector<float> one(sd.elements(), 1);
        std::vector<float> two(sd.elements(), 2);

        auto* then_smod = p.create_module("then_smod");
        auto l1 = then_smod->add_literal(migraphx::literal{sd, one});
        auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1);
        then_smod->add_return({r1});

        auto* else_smod = p.create_module("else_smod");
        auto l2 = else_smod->add_literal(migraphx::literal{sd, two});
        auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2);
        else_smod->add_return({r2});

        migraphx::shape s_cond{migraphx::shape::bool_type, {1}};
        auto cond = mm->add_parameter("cond", s_cond);
        auto ret  = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod});
        mm->add_return({ret});

        return p;
    };

    auto p = create_program();
    run_pass(p);
    EXPECT(p == create_program());
}

TEST_CASE(cannot_inline_one)
{
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape cond_s{migraphx::shape::bool_type};
        migraphx::shape s{migraphx::shape::float_type, {5}};
        auto cond = mm->add_parameter("cond", cond_s);
        auto x    = mm->add_parameter("x", s);

        auto* then_mod = p.create_module("If_0_if");
        std::vector<float> data1 = {1, 2, 3, 4, 5};
        auto l1 = then_mod->add_literal(migraphx::literal(s, data1));
        then_mod->add_return({l1, x});

        auto* else_mod = p.create_module("If_0_else");
        std::vector<float> data2 = {5, 4, 3, 2, 1};
        auto l2 = else_mod->add_literal(migraphx::literal(s, data2));
        auto s2 = else_mod->add_instruction(migraphx::make_op("add"), x, l2);
        else_mod->add_return({s2, l2});

        auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
        mm->add_return({ret});

        return p;
    };

    auto p = create_program();
    run_pass(p);
    EXPECT(p == create_program());
}

TEST_CASE(inline_if_test)
{
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sc{migraphx::shape::bool_type, {1}};
        auto cond = mm->add_literal(migraphx::literal(sc, {1}));
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        std::vector<float> ones(s.elements(), 1.0f);
        auto l1 = mm->add_literal(s, ones);
        std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
        auto l2 = mm->add_literal(s, rand);
        auto x  = mm->add_parameter("x", s);
        auto sm = mm->add_instruction(migraphx::make_op("add"), l1, x);
        auto y  = mm->add_parameter("y", s);

        auto* then_mod = p.create_module("If_5_if");
        auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, sm);
        then_mod->add_outline(s);
        then_mod->add_return({rt});

        auto* else_mod = p.create_module("If_5_else");
        auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
        else_mod->add_return({re});

        auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
        auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
        mm->add_return({r});

        return p;
    };

    auto create_inline = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        std::vector<float> ones(s.elements(), 1.0f);
        auto l1 = mm->add_literal(s, ones);
        std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
        mm->add_literal(s, rand);
        auto x  = mm->add_parameter("x", s);
        auto sm = mm->add_instruction(migraphx::make_op("add"), l1, x);
        mm->add_parameter("y", s);
        auto r = mm->add_instruction(migraphx::make_op("add"), x, sm);
        mm->add_return({r});

        return p;
    };

    auto p = create_program();
    run_pass(p);
    EXPECT(p == create_inline());
}

TEST_CASE(inline_else_test)
{
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sc{migraphx::shape::bool_type, {1}};
        auto cond = mm->add_literal(migraphx::literal(sc, {0}));
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        std::vector<float> ones(s.elements(), 1.0f);
        auto l1 = mm->add_literal(s, ones);
        std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
        auto l2 = mm->add_literal(s, rand);
        auto x  = mm->add_parameter("x", s);
        auto y  = mm->add_parameter("y", s);

        auto* then_mod = p.create_module("If_5_if");
        auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
        then_mod->add_return({rt});

        auto* else_mod = p.create_module("If_5_else");
        else_mod->add_parameter("e", s);
        else_mod->add_literal(migraphx::literal(s, ones));
        auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
        else_mod->add_return({re});

        auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
        auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
        mm->add_return({r});

        return p;
    };

    auto create_inline = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        std::vector<float> ones(s.elements(), 1.0f);
        mm->add_literal(s, ones);
        std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
        auto l2 = mm->add_literal(s, rand);
        mm->add_parameter("x", s);
        auto y = mm->add_parameter("y", s);
        auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2);
        mm->add_return({r});

        return p;
    };

    auto p = create_program();
    run_pass(p);
    EXPECT(p == create_inline());
}

TEST_CASE(if_recursive_test)
{
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape cond_s{migraphx::shape::bool_type};
        migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
        migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
        std::vector<float> datax = {1, 2, 3, 4, 5, 6};
        std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};

        auto lx    = mm->add_literal(migraphx::literal(xs, datax));
        auto ly    = mm->add_literal(migraphx::literal(ys, datay));
        auto cond  = mm->add_literal(migraphx::literal(cond_s, {0}));
        auto x1    = mm->add_parameter("x1", xs);
        auto x2    = mm->add_parameter("x2", xs);
        auto y2    = mm->add_parameter("y2", ys);
        auto cond1 = mm->add_parameter("cond", cond_s);

        auto* then_mod = p.create_module("If_5_if");
        auto l1        = then_mod->add_literal(migraphx::literal(ys, datay));
        auto a1        = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
        then_mod->add_return({a1, l1});

        auto* then_mod1 = p.create_module("If_6_if");
        auto l11        = then_mod1->add_literal(migraphx::literal(ys, datay));
        auto a11        = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
        then_mod1->add_return({a11, l11});

        auto* else_mod1 = p.create_module("If_6_else");
        auto l21        = else_mod1->add_literal(migraphx::literal(xs, datax));
        auto a21        = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
        else_mod1->add_return({l21, a21});

        auto* else_mod = p.create_module("If_5_else");
        auto l2        = else_mod->add_literal(migraphx::literal(xs, datax));
        auto a2 =
            else_mod->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1});
        auto a3 =
            else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), a2);
        else_mod->add_return({l2, a3});

        auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
        auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
        mm->add_return({r});

        return p;
    };

    auto create_inline = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape cond_s{migraphx::shape::bool_type};
        migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
        migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
        std::vector<float> datax = {1, 2, 3, 4, 5, 6};
        std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};

        auto lx = mm->add_literal(migraphx::literal(xs, datax));
        auto ly = mm->add_literal(migraphx::literal(ys, datay));
        mm->add_parameter("x1", xs);
        auto x2    = mm->add_parameter("x2", xs);
        auto y2    = mm->add_parameter("y2", ys);
        auto cond1 = mm->add_parameter("cond", cond_s);

        auto* then_mod1 = p.create_module("If_6_if");
        auto l11        = then_mod1->add_literal(migraphx::literal(ys, datay));
        auto a11        = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
        then_mod1->add_return({a11, l11});

        auto* else_mod1 = p.create_module("If_6_else");
        auto l21        = else_mod1->add_literal(migraphx::literal(xs, datax));
        auto a21        = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
        else_mod1->add_return({l21, a21});

        auto ret = mm->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1});
        auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
        mm->add_return({r});

        return p;
    };

    auto p = create_program();
    run_pass(p);
    EXPECT(p == create_inline());
}

TEST_CASE(if_recursive_cond0_test)
{
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape cond_s{migraphx::shape::bool_type};
        migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
        migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
        std::vector<float> datax = {1, 2, 3, 4, 5, 6};
        std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};

        auto lx   = mm->add_literal(migraphx::literal(xs, datax));
        auto ly   = mm->add_literal(migraphx::literal(ys, datay));
        auto cond = mm->add_literal(migraphx::literal(cond_s, {0}));
        auto x1   = mm->add_parameter("x1", xs);
        auto x2   = mm->add_parameter("x2", xs);
        auto y2   = mm->add_parameter("y2", ys);

        auto* then_mod = p.create_module("If_5_if");
        auto l1        = then_mod->add_literal(migraphx::literal(ys, datay));
        auto a1        = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
        then_mod->add_return({a1, l1});

        auto* then_mod1 = p.create_module("If_6_if");
        auto l11        = then_mod1->add_literal(migraphx::literal(ys, datay));
        auto a11        = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
        then_mod1->add_return({a11, l11});

        auto* else_mod1 = p.create_module("If_6_else");
        auto l21        = else_mod1->add_literal(migraphx::literal(xs, datax));
        auto a21        = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
        else_mod1->add_return({l21, a21});

        auto* else_mod = p.create_module("If_5_else");
        auto l2        = else_mod->add_literal(migraphx::literal(xs, datax));
        auto a2 =
            else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
        auto a3 =
            else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), a2);
        else_mod->add_return({l2, a3});

        auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
        auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
        mm->add_return({r});

        return p;
    };

    auto create_inline = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape cond_s{migraphx::shape::bool_type};
        migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
        migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
        std::vector<float> datax = {1, 2, 3, 4, 5, 6};
        std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};

        mm->add_literal(migraphx::literal(xs, datax));
        auto ly = mm->add_literal(migraphx::literal(ys, datay));
        mm->add_parameter("x1", xs);
        mm->add_parameter("x2", xs);
        auto y2 = mm->add_parameter("y2", ys);

        auto m = mm->add_instruction(migraphx::make_op("mul"), y2, ly);
        mm->add_return({m});

        return p;
    };

    auto p = create_program();
    run_pass(p);
    EXPECT(p == create_inline());
}

TEST_CASE(inline_tuple_true_test)
{
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sc{migraphx::shape::bool_type, {1}};
        auto cond = mm->add_literal(migraphx::literal(sc, {1}));
        migraphx::shape sd{migraphx::shape::float_type, {1}};
        auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
        auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
        auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
        migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
        migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
        auto x = mm->add_parameter("x", sx);
        auto y = mm->add_parameter("y", sy);

        auto* then_mod = p.create_module("If_6_if");
        auto m1        = then_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
        auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
        auto m2   = then_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
        auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
        then_mod->add_return({add0, mul0});

        auto* else_mod = p.create_module("If_6_else");
        auto me1       = else_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
        auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
        auto me2  = else_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
        auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
        else_mod->add_return({mul1, add1});

        auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
        auto r0  = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
        auto r1  = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
        mm->add_return({r0, r1});

        return p;
    };

    auto create_inline = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sd{migraphx::shape::float_type, {1}};
        auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
        auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
        mm->add_literal(migraphx::literal(sd, {3}));
        migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
        migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
        auto x = mm->add_parameter("x", sx);
        auto y = mm->add_parameter("y", sy);

        auto m1 =
            mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
        auto add = mm->add_instruction(migraphx::make_op("add"), x, m1);
        auto m2 =
            mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
        auto mul = mm->add_instruction(migraphx::make_op("mul"), y, m2);
        mm->add_return({add, mul});

        return p;
    };

    auto p = create_program();
    run_pass(p);
    EXPECT(p == create_inline());
}

TEST_CASE(inline_tuple_false_test)
{
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sc{migraphx::shape::bool_type, {1}};
        auto cond = mm->add_literal(migraphx::literal(sc, {0}));
        migraphx::shape sd{migraphx::shape::float_type, {1}};
        auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
        auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
        auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
        migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
        migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
        auto x = mm->add_parameter("x", sx);
        auto y = mm->add_parameter("y", sy);

        auto* then_mod = p.create_module("If_6_if");
        auto m1        = then_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
        auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
        auto m2   = then_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
        auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
        then_mod->add_return({add0, mul0});

        auto* else_mod = p.create_module("If_6_else");
        auto me1       = else_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
        auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
        auto me2  = else_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
        auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
        else_mod->add_return({mul1, add1});

        auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
        auto r0  = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
        auto r1  = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
        mm->add_return({r0, r1});

        return p;
    };

    auto create_inline = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sc{migraphx::shape::bool_type, {1}};
        migraphx::shape sd{migraphx::shape::float_type, {1}};
        mm->add_literal(migraphx::literal(sd, {1}));
        mm->add_literal(migraphx::literal(sd, {2}));
        auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
        migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
        migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
        auto x = mm->add_parameter("x", sx);
        auto y = mm->add_parameter("y", sy);

        auto m1 =
            mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
        auto mul = mm->add_instruction(migraphx::make_op("mul"), x, m1);
        auto m2 =
            mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
        auto add = mm->add_instruction(migraphx::make_op("add"), y, m2);
        mm->add_return({mul, add});

        return p;
    };

    auto p = create_program();
    run_pass(p);
    EXPECT(p == create_inline());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -24,33 +24,102 @@ migraphx::program create_program()
     return p;
 }

-TEST_CASE(module_ins_clear)
+TEST_CASE(calc_implict_deps)
 {
-    migraphx::program p1 = create_program();
-    migraphx::program p2;
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape cond_s{migraphx::shape::bool_type};
+    migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
+    migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
+    std::vector<float> datax = {1, 2, 3, 4, 5, 6};
+    std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};

-    p2 = p1;
+    auto lx   = mm->add_literal(migraphx::literal(xs, datax));
+    auto ly   = mm->add_literal(migraphx::literal(ys, datay));
+    auto cond = mm->add_parameter("cond", cond_s);
+    auto x1   = mm->add_parameter("x1", xs);
+    auto x2   = mm->add_parameter("x2", xs);
+    auto y2   = mm->add_parameter("y2", ys);

-    EXPECT(p1 == p2);
+    auto* then_mod = p.create_module("If_5_if");
+    auto l1        = then_mod->add_literal(migraphx::literal(ys, datay));
+    auto a1        = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
+    then_mod->add_return({a1, l1});
+
+    auto* then_mod1 = p.create_module("If_6_if");
+    auto l11        = then_mod1->add_literal(migraphx::literal(ys, datay));
+    auto a11        = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
+    then_mod1->add_return({a11, l11});
+
+    auto* else_mod1 = p.create_module("If_6_else");
+    auto l21        = else_mod1->add_literal(migraphx::literal(xs, datax));
+    auto a21        = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
+    else_mod1->add_return({l21, a21});
+
+    auto* else_mod = p.create_module("If_5_else");
+    auto l2        = else_mod->add_literal(migraphx::literal(ys, datay));
+    auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
+    auto a3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), a2);
+    else_mod->add_return({a3, l2});
+
+    auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
+    auto r   = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
+    mm->add_return({r});
+
+    auto implicit_deps = mm->calc_implicit_deps();
+    EXPECT(migraphx::contains(implicit_deps, ret));
+    EXPECT(migraphx::contains(implicit_deps.at(ret), x1));
+    EXPECT(migraphx::contains(implicit_deps.at(ret), x2));
+    EXPECT(migraphx::contains(implicit_deps.at(ret), y2));
 }

-TEST_CASE(module_print_graph)
+TEST_CASE(module_annotate)
 {
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();

     auto* mm1 = p1.get_main_module();
     auto* mm2 = p2.get_main_module();
+    EXPECT(*mm1 == *mm2);

     std::stringstream ss1;
-    mm1->print_graph(ss1, true);
+    mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });

     std::stringstream ss2;
-    mm2->print_graph(ss2, true);
+    mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });

     EXPECT(ss1.str() == ss2.str());
 }

+TEST_CASE(module_ins_clear)
+{
+    migraphx::program p1 = create_program();
+    migraphx::program p2;
+
+    p2 = p1;
+
+    EXPECT(p1 == p2);
+}
+
+TEST_CASE(module_name)
+{
+    migraphx::module m1("name");
+    EXPECT(m1.name() == "name");
+
+    auto m2 = m1; // NOLINT
+    EXPECT(m2.name() == "name");
+
+    migraphx::module m3;
+    m3 = m1;
+    EXPECT(m3.name() == "name");
+}
+
+TEST_CASE(module_name_main)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    EXPECT(mm->name() == "main");
+}
+
 TEST_CASE(module_print_cpp)
 {
     migraphx::program p1 = create_program();
@@ -68,43 +137,23 @@ TEST_CASE(module_print_cpp)
     EXPECT(ss1.str() == ss2.str());
 }

-TEST_CASE(module_annotate)
+TEST_CASE(module_print_graph)
 {
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();

     auto* mm1 = p1.get_main_module();
     auto* mm2 = p2.get_main_module();
+    EXPECT(*mm1 == *mm2);

     std::stringstream ss1;
-    mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
+    mm1->print_graph(ss1, true);

     std::stringstream ss2;
-    mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
+    mm2->print_graph(ss2, true);

     EXPECT(ss1.str() == ss2.str());
 }

-TEST_CASE(module_name)
-{
-    migraphx::module m1("name");
-    EXPECT(m1.name() == "name");
-
-    auto m2 = m1; // NOLINT
-    EXPECT(m2.name() == "name");
-
-    migraphx::module m3;
-    m3 = m1;
-    EXPECT(m3.name() == "name");
-}
-
-TEST_CASE(module_name_main)
-{
-    migraphx::program p;
-    auto* mm = p.get_main_module();
-    EXPECT(mm->name() == "main");
-}
-
 TEST_CASE(program_module_assign)
 {
     migraphx::program p;
@@ -204,51 +253,4 @@ TEST_CASE(submodule_copy)
     EXPECT(mm.get_sub_modules() == mm2.get_sub_modules());
 }

-TEST_CASE(calc_implict_deps)
-{
-    migraphx::program p;
-    auto* mm = p.get_main_module();
-    migraphx::shape cond_s{migraphx::shape::bool_type};
-    migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
-    migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
-    std::vector<float> datax = {1, 2, 3, 4, 5, 6};
-    std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
-
-    auto lx   = mm->add_literal(migraphx::literal(xs, datax));
-    auto ly   = mm->add_literal(migraphx::literal(ys, datay));
-    auto cond = mm->add_parameter("cond", cond_s);
-    auto x1   = mm->add_parameter("x1", xs);
-    auto x2   = mm->add_parameter("x2", xs);
-    auto y2   = mm->add_parameter("y2", ys);
-
-    auto* then_mod = p.create_module("If_5_if");
-    auto l1        = then_mod->add_literal(migraphx::literal(ys, datay));
-    auto a1        = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
-    then_mod->add_return({a1, l1});
-
-    auto* then_mod1 = p.create_module("If_6_if");
-    auto l11        = then_mod1->add_literal(migraphx::literal(ys, datay));
-    auto a11        = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
-    then_mod1->add_return({a11, l11});
-
-    auto* else_mod1 = p.create_module("If_6_else");
-    auto l21        = else_mod1->add_literal(migraphx::literal(xs, datax));
-    auto a21        = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
-    else_mod1->add_return({l21, a21});
-
-    auto* else_mod = p.create_module("If_5_else");
-    auto l2        = else_mod->add_literal(migraphx::literal(ys, datay));
-    auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
-    else_mod->add_return({a2, l2});
-
-    auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
-    mm->add_return({ret});
-
-    auto implicit_deps = mm->calc_implicit_deps();
-    EXPECT(migraphx::contains(implicit_deps, ret));
-    EXPECT(migraphx::contains(implicit_deps.at(ret), x1));
-    EXPECT(migraphx::contains(implicit_deps.at(ret), x2));
-    EXPECT(migraphx::contains(implicit_deps.at(ret), y2));
-}
-
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1900,6 +1900,77 @@ def if_then_test(): ...@@ -1900,6 +1900,77 @@ def if_then_test():
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor]) return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test
def if_tuple_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [1, 4])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [3, 4])
cond_input = onnx.helper.make_tensor_value_info('cond',
onnx.TensorProto.BOOL, [])
then_out0 = onnx.helper.make_tensor_value_info('then_out0',
onnx.TensorProto.FLOAT,
[1, 4])
then_out1 = onnx.helper.make_tensor_value_info('then_out1',
onnx.TensorProto.FLOAT,
[3, 4])
else_out0 = onnx.helper.make_tensor_value_info('else_out0',
onnx.TensorProto.FLOAT,
[1, 4])
else_out1 = onnx.helper.make_tensor_value_info('else_out1',
onnx.TensorProto.FLOAT,
[3, 4])
one = np.ones([1]).astype(np.float32)
one_tensor = helper.make_tensor(name='one',
data_type=TensorProto.FLOAT,
dims=one.shape,
vals=one.flatten().astype(np.float32))
two = np.array([2]).astype(np.float32)
two_tensor = helper.make_tensor(name='two',
data_type=TensorProto.FLOAT,
dims=two.shape,
vals=two.flatten().astype(np.float32))
three = np.array([3]).astype(np.float32)
three_tensor = helper.make_tensor(name='three',
data_type=TensorProto.FLOAT,
dims=three.shape,
vals=three.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'one'],
outputs=['then_out0'])
then_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'two'],
outputs=['then_out1'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['x', 'three'],
outputs=['else_out0'])
else_add_node = onnx.helper.make_node('Add',
inputs=['y', 'three'],
outputs=['else_out1'])
then_body = onnx.helper.make_graph([then_add_node, then_mul_node],
'then_body', [], [then_out0, then_out1])
else_body = onnx.helper.make_graph([else_mul_node, else_add_node],
'else_body', [], [else_out0, else_out1])
res0 = onnx.helper.make_tensor_value_info('res0', TensorProto.FLOAT, [])
res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res0', 'res1'],
then_branch=then_body,
else_branch=else_body)
return ([node], [cond_input, x,
y], [res0, res1], [one_tensor, two_tensor, three_tensor])
@onnx_test
def imagescaler_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16])
...
@@ -1396,17 +1396,25 @@ TEST_CASE(if_else_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-0.583375, 0.633757, 0.0668345, -0.479422, -0.604634, 0.0388589};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
then_mod->add_return({rt});
auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
std::ifstream ifs("if_else_test.onnx", std::ios::binary);
@@ -1418,7 +1426,6 @@ TEST_CASE(if_else_test)
ifs.close();
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});
EXPECT(p == prog);
}
@@ -1444,7 +1451,8 @@ TEST_CASE(if_literal_test)
else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_literal_test.onnx");
EXPECT(p == prog);
@@ -1483,7 +1491,8 @@ TEST_CASE(if_param_test)
else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_param_test.onnx");
EXPECT(p == prog);
@@ -1516,7 +1525,9 @@ TEST_CASE(if_pl_test)
else_mod->add_return({l2, a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_pl_test.onnx");
EXPECT(p == prog);
@@ -1527,21 +1538,70 @@ TEST_CASE(if_then_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
then_mod->add_return({rt});
auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_then_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_tuple_test)
{
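// Expected graph for if_tuple_test.onnx: each branch returns two values, which the
// "if" packs into a tuple that is unpacked below with get_tuple_elem.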
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
auto prog = migraphx::parse_onnx("if_tuple_test.onnx");
EXPECT(p == prog);
}
...
@@ -76,6 +76,7 @@ TEST_CASE(if_else_test)
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data()); pp["y"] = migraphx::argument(s_data, data.data());
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
@@ -160,6 +161,55 @@ TEST_CASE(if_pl_test)
}
}
TEST_CASE(if_tuple_test)
{
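// Parse and run if_tuple_test.onnx on the ref target with both condition values,
// checking the two tuple outputs against hand-computed golds.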
auto run_prog = [](bool cond) {
migraphx::program p = migraphx::parse_onnx("if_tuple_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape xs{migraphx::shape::float_type, {1, 4}};
migraphx::shape ys{migraphx::shape::float_type, {3, 4}};
migraphx::shape cond_s{migraphx::shape::bool_type};
std::vector<float> x_data(xs.elements(), 1.0f);
std::vector<float> y_data(ys.elements(), 2.0f);
std::vector<char> cond_data{static_cast<char>(cond)};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(xs, x_data.data());
pp["y"] = migraphx::argument(ys, y_data.data());
pp["cond"] = migraphx::argument(cond_s, cond_data.data());
auto results = p.eval(pp);
std::vector<std::vector<float>> rets;
for(const auto& arg : results)
{
std::vector<float> vec;
arg.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
rets.push_back(vec);
}
return rets;
};
// then branch
{
auto results = run_prog(true);
std::vector<float> gold0(4, 2.0f);
std::vector<float> gold1(12, 4.0f);
EXPECT(migraphx::verify_range(results.at(0), gold0));
EXPECT(migraphx::verify_range(results.at(1), gold1));
}
// else branch
{
auto results = run_prog(false);
std::vector<float> gold0(4, 3.0f);
std::vector<float> gold1(12, 5.0f);
EXPECT(migraphx::verify_range(results.at(0), gold0));
EXPECT(migraphx::verify_range(results.at(1), gold1));
}
}
TEST_CASE(instance_norm_test)
{
migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx");
...
@@ -68,6 +68,43 @@ TEST_CASE(batch_norm_inference_shape)
throws_shape(migraphx::make_op("batch_norm_inference"), s, vars, vars, vars, vars, vars); throws_shape(migraphx::make_op("batch_norm_inference"), s, vars, vars, vars, vars, vars);
} }
TEST_CASE(broadcast)
{
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}),
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}),
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), input);
}
}
TEST_CASE(convolution_shape)
{
migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
@@ -106,6 +143,24 @@ TEST_CASE(convolution_shape)
throws_shape(migraphx::make_op("convolution"), input_3d, weights_3d); throws_shape(migraphx::make_op("convolution"), input_3d, weights_3d);
} }
TEST_CASE(contiguous_shape)
{
migraphx::shape output{migraphx::shape::float_type, {2, 2}};
migraphx::shape input{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraphx::make_op("contiguous"), input);
throws_shape(migraphx::make_op("contiguous"), input, input);
migraphx::shape single{migraphx::shape::float_type, {2}};
expect_shape(single, migraphx::make_op("contiguous"), single);
}
TEST_CASE(contiguous_shape_scalar)
{
migraphx::shape output{migraphx::shape::float_type};
migraphx::shape input{migraphx::shape::float_type};
expect_shape(output, migraphx::make_op("contiguous"), input);
}
TEST_CASE(deconvolution_shape)
{
migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}};
@@ -137,141 +192,6 @@ TEST_CASE(deconvolution_shape)
weights_3d);
}
TEST_CASE(quant_convolution_shape)
{
migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
expect_shape(output, migraphx::make_op("quant_convolution"), input, weights);
throws_shape(migraphx::make_op("quant_convolution"), input);
throws_shape(migraphx::make_op("quant_convolution",
{{"padding", {0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input,
weights);
throws_shape(migraphx::make_op("quant_convolution",
{{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input,
weights);
migraphx::shape input2{migraphx::shape::int32_type, {3, 3}};
migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraphx::make_op("quant_convolution"), input2, weights2);
throws_shape(migraphx::make_op("quant_convolution"), input2, weights);
migraphx::shape input3{migraphx::shape::int32_type, {4, 3, 3, 3}};
migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::make_op("quant_convolution"), input3, weights);
throws_shape(migraphx::make_op("quant_convolution"), input, weight3);
throws_shape(migraphx::make_op("quant_convolution"), input3, weight3);
}
TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
expect_shape(output, migraphx::make_op("transpose"), input);
throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input);
}
TEST_CASE(reshape_shape)
{
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
{
std::vector<std::size_t> lens(new_shape.size());
std::copy(new_shape.begin(), new_shape.end(), lens.begin());
migraphx::shape output{migraphx::shape::float_type, lens};
expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
{
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
{{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
{{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
{{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
{{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
{{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
{{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};
for(auto& it : minus1_tests)
{
expect_shape(it.second, migraphx::make_op("reshape", {{"dims", it.first}}), input);
}
}
TEST_CASE(flatten_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
@@ -300,173 +220,48 @@ TEST_CASE(flatten_shape)
throws_shape(migraphx::make_op("flatten", {{"axis", -5}}), input);
}
TEST_CASE(gather)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type};
@@ -512,219 +307,546 @@ TEST_CASE(gather)
}
}
// 3 input arguments
TEST_CASE(gemm)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
}
TEST_CASE(get_tuple_elem_test)
{
migraphx::shape s0{migraphx::shape::bool_type, {1, 1}};
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {5, 6}};
migraphx::shape s_tuple({s0, s1, s2});
expect_shape(s0, migraphx::make_op("get_tuple_elem", {{"index", 0}}), s_tuple);
expect_shape(s1, migraphx::make_op("get_tuple_elem", {{"index", 1}}), s_tuple);
expect_shape(s2, migraphx::make_op("get_tuple_elem", {{"index", 2}}), s_tuple);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 3}}), s_tuple);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 0}}), s0);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 1}}), s1);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 0}}), s2);
}
TEST_CASE(gru)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size + 1},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
}
TEST_CASE(inconsistent_attr_shape)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(migraphx::make_op("deconvolution",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(
migraphx::make_op(
"pooling", {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1, 1}}}),
input);
}
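// Shared template: softmax-style ops keep the input shape for any in-range axis
// of a rank-4 input and throw once the axis equals the rank.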
template <class T>
void test_softmax_variations()
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
throws_shape(T{axis}, input);
}
}
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
TEST_CASE(lstm)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size + 1},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
}
// 2 inputs arguments
TEST_CASE(matmul)
{
@@ -825,523 +947,269 @@ TEST_CASE(matmul)
}
}
TEST_CASE(multibroadcast)
{
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {5, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 1, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
}
TEST_CASE(pooling_shape)
{
    migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
    migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
    throws_shape(
        migraphx::make_op("pooling",
                          {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1}}}),
        input);
    expect_shape(
        output,
        migraphx::make_op(
            "pooling",
            {{"mode", "max"}, {"padding", {0, 0}}, {"stride", {3, 3}}, {"lengths", {1, 1}}}),
        input);

    migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}};
    expect_shape(output1,
                 migraphx::make_op("pooling",
                                   {{"mode", "max"},
                                    {"padding", {0, 0}},
                                    {"stride", {3, 3}},
                                    {"lengths", {1, 1}},
                                    {"ceil_mode", true}}),
                 input);
}
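// Sketch of the output-size rule the two pooling cases above exercise; the
// helper name is ours, but the formula is the standard one: with ceil_mode
// the window count rounds up instead of down.
#include <cmath>
#include <cstddef>

std::size_t pooled_dim(
    std::size_t in, std::size_t pad, std::size_t kernel, std::size_t stride, bool ceil_mode)
{
    double span = static_cast<double>(in + 2 * pad - kernel) / stride;
    return static_cast<std::size_t>(ceil_mode ? std::ceil(span) : std::floor(span)) + 1;
}

// pooled_dim(3, 0, 1, 3, false) == 1 and pooled_dim(3, 0, 1, 3, true) == 2,
// matching the {4, 3, 1, 1} and {4, 3, 2, 2} shapes checked above.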
TEST_CASE(prefix_scan_sum)
{
    {
        migraphx::shape s{migraphx::shape::float_type, {1, 2, 3}};
        throws_shape(
            migraphx::make_op("prefix_scan_sum", {{"axis", 3}, {"exclusive", 0}, {"reverse", 0}}),
            s);
    }
    {
        migraphx::shape s{migraphx::shape::float_type, {1, 2}};
        throws_shape(
            migraphx::make_op("prefix_scan_sum", {{"axis", -3}, {"exclusive", 0}, {"reverse", 0}}),
            s);
    }
}
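// Sketch of the axis validation both throws_shape cases above depend on: a
// rank-r tensor accepts axes in [-r, r - 1], negatives counting from the back
// (hypothetical helper, mirroring the usual ONNX convention).
#include <stdexcept>

int normalize_axis(int axis, int rank)
{
    if(axis < -rank or axis >= rank)
        throw std::out_of_range("axis out of range");
    return axis < 0 ? axis + rank : axis;
}

// normalize_axis(3, 3) and normalize_axis(-3, 2) both throw, which is exactly
// why axis 3 on a rank-3 shape and axis -3 on a rank-2 shape fail above.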
TEST_CASE(quant_convolution_shape)
{
    migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
    migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}};
    migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
    expect_shape(output, migraphx::make_op("quant_convolution"), input, weights);
    throws_shape(migraphx::make_op("quant_convolution"), input);
    throws_shape(migraphx::make_op("quant_convolution",
                                   {{"padding", {0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
                 input,
                 weights);
    throws_shape(migraphx::make_op("quant_convolution",
                                   {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
                 input,
                 weights);

    migraphx::shape input2{migraphx::shape::int32_type, {3, 3}};
    migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
    throws_shape(migraphx::make_op("quant_convolution"), input2, weights2);
    throws_shape(migraphx::make_op("quant_convolution"), input2, weights);

    migraphx::shape input3{migraphx::shape::int32_type, {4, 3, 3, 3}};
    migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}};
    throws_shape(migraphx::make_op("quant_convolution"), input3, weights);
    throws_shape(migraphx::make_op("quant_convolution"), input, weight3);
    throws_shape(migraphx::make_op("quant_convolution"), input3, weight3);
}
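// Sketch of why quant_convolution takes int8 tensors but produces int32: the
// per-window products are accumulated in a wider type so they cannot overflow
// (standalone illustration, not the operator's actual kernel).
#include <cstddef>
#include <cstdint>
#include <vector>

int32_t accumulate_window(const std::vector<int8_t>& xs, const std::vector<int8_t>& ws)
{
    int32_t acc = 0;
    for(std::size_t i = 0; i < xs.size(); ++i)
        acc += static_cast<int32_t>(xs[i]) * static_cast<int32_t>(ws[i]);
    return acc;
}

// A saturated 3x3x3 window gives 27 * 127 * 127 = 435483, far past int16
// range, hence the int32_type output shape expected above.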
// quant_dot
TEST_CASE(quant_dot_2args)
{
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
                     migraphx::make_op("quant_dot"),
                     s_m1,
                     s_m2);
    }
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
                     migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}),
                     s_m1,
                     s_m2);
    }
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
        throws_shape(migraphx::make_op("quant_dot"), s_m1, s_m2);
    }
}
TEST_CASE(quant_dot_3args)
{
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
        migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}};
        expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
                     migraphx::make_op("quant_dot"),
                     s_m1,
                     s_m2,
                     s_m3);
    }
    {
        migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
        migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
        migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
        throws_shape(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), s_m1, s_m2, s_m3);
    }
}
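// Sketch of the contract the 3-arg cases check: quant_dot computes
// alpha * (A x B) + beta * C with the int8 product accumulated into int32, so
// C must itself be int32 -- an int8 C throws, as asserted above (illustrative
// element-wise arithmetic only).
#include <cstdint>

int32_t quant_dot_elem(int32_t alpha, int32_t ab_acc, int32_t beta, int32_t c)
{
    return alpha * ab_acc + beta * c; // ab_acc: int32 dot product of an int8 row/column pair
}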
template <class T>
void test_reduce_ops()
{
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
    }
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(
            migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
    }
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
    }
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
    }
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
    }
    {
        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
        throws_shape(T{{4}}, input);
    }
}

TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
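// Sketch of the lens rule test_reduce_ops checks: every reduced axis collapses
// to 1 and negative axes count from the back (hypothetical helper).
#include <cstddef>
#include <vector>

std::vector<std::size_t> reduced_lens(std::vector<std::size_t> lens, const std::vector<int>& axes)
{
    for(int a : axes)
        lens[a < 0 ? a + static_cast<int>(lens.size()) : a] = 1;
    return lens;
}

// reduced_lens({2, 3, 4, 5}, {-1}) == {2, 3, 4, 1} and
// reduced_lens({2, 3, 4, 5}, {2, 3}) == {2, 3, 1, 1}, matching the cases above.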
TEST_CASE(reshape_shape)
{
    migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
    for(auto&& new_shape :
        std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
    {
        std::vector<std::size_t> lens(new_shape.size());
        std::copy(new_shape.begin(), new_shape.end(), lens.begin());
        migraphx::shape output{migraphx::shape::float_type, lens};
        expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
    }

    for(auto&& new_shape :
        std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
    {
        throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
    }

    std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
        {{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
        {{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
        {{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
        {{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
        {{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
        {{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
        {{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};
    for(auto& it : minus1_tests)
    {
        expect_shape(it.second, migraphx::make_op("reshape", {{"dims", it.first}}), input);
    }
}
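// Sketch of the dims conventions the reshape cases above rely on (ONNX-style):
// 0 copies the corresponding input dimension and a single -1 is inferred so
// the element count is preserved (hypothetical helper, minimal checking).
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

std::vector<std::size_t> reshape_dims(const std::vector<std::size_t>& in,
                                      const std::vector<long>& dims)
{
    std::size_t total =
        std::accumulate(in.begin(), in.end(), std::size_t{1}, std::multiplies<>{});
    std::vector<std::size_t> out(dims.size());
    std::size_t known = 1;
    std::size_t infer = dims.size();
    for(std::size_t i = 0; i < dims.size(); ++i)
    {
        if(dims[i] == -1)
        {
            infer = i;
            continue;
        }
        out[i] = (dims[i] == 0) ? in[i] : static_cast<std::size_t>(dims[i]);
        known *= out[i];
    }
    if(infer != dims.size())
        out[infer] = total / known;
    return out;
}

// reshape_dims({24, 1, 1, 1}, {2, -1, 0}) == {2, 12, 1}, matching the
// minus1_tests table above.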
TEST_CASE(rnn)
{
    {
        std::size_t batch_size = 2;
@@ -1352,16 +1220,16 @@ TEST_CASE(lstm)
        float clip = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
@@ -1369,7 +1237,9 @@ TEST_CASE(lstm)
                 {"clip", clip}}),
            in_shape,
            w_shape,
            r_shape,
            b_shape,
            ih_shape);
    }
    {
@@ -1381,18 +1251,16 @@ TEST_CASE(lstm)
        float clip = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
@@ -1414,18 +1282,16 @@ TEST_CASE(lstm)
        float clip = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
        expect_shape(
            migraphx::shape{migraphx::shape::float_type,
                            {seq_len, num_dirct, batch_size, hidden_size}},
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
@@ -1447,16 +1313,14 @@ TEST_CASE(lstm)
        float clip = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
        throws_shape(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size + 1},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
@@ -1478,16 +1342,14 @@ TEST_CASE(lstm)
        float clip = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
        throws_shape(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
@@ -1509,16 +1371,14 @@ TEST_CASE(lstm)
        float clip = 0.0f;

        migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
        migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
        throws_shape(
            migraphx::make_op(
                "rnn",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
@@ -1532,21 +1392,176 @@ TEST_CASE(lstm)
    }
}
TEST_CASE(slice_shape)
{
migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}),
input);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraphx::make_op(
"slice", {{"axes", {0, 1, 2}}, {"starts", {0, 0, 1}}, {"ends", {2, 2, 3}}}),
input);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {10}}}),
input);
}
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(test_argmax)
{
    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
                     migraphx::make_op("argmax", {{"axis", 0}}),
                     input);
    }
    {
        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
                     migraphx::make_op("argmax", {{"axis", 1}}),
                     input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmax", {{"axis", 3}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input);
}
}
TEST_CASE(test_argmin)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("argmin", {{"axis", 0}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("argmin", {{"axis", 1}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmin", {{"axis", 2}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmin", {{"axis", 3}}),
input);
    }
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmin", {{"axis", 4}}), input);
}
}
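// Standalone illustration of the argmax/argmin shape rule tested above: the
// reduced axis collapses to 1 and the output holds int64 indices regardless
// of the input type (here half_type inputs yield int64_type outputs).
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

int64_t argmax_1d(const std::vector<float>& v)
{
    return std::distance(v.begin(), std::max_element(v.begin(), v.end()));
}

// argmax_1d({2.f, 7.f, 3.f}) == 1; reducing a {2, 3, 4, 5} input over axis 2
// produces one such index per remaining coordinate, giving lens {2, 3, 1, 5}.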
TEST_CASE(test_scalar)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 5}, {0, 0, 0, 0}};
expect_shape(s2, migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), s1);
}
TEST_CASE(test_scalar_nelemnts)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), input);
}
TEST_CASE(test_squeeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1);
}
TEST_CASE(test_squeeze_all)
{
migraphx::shape s1{migraphx::shape::float_type, {1}};
migraphx::shape s2{migraphx::shape::float_type};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_squeeze_negative_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1);
}
TEST_CASE(test_squeeze_wrong_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_unsqueeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {0}};
migraphx::shape s2{migraphx::shape::float_type, {1}, {1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar_tensor1)
{
migraphx::shape s{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
}
TEST_CASE(test_unsqueeze_scalar_tensor2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
}
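// Sketch of the squeeze rule the cases above check: only a size-1 axis may be
// removed, and negative axes count from the back (hypothetical helper).
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> squeeze_lens(const std::vector<std::size_t>& in, int axis)
{
    const int rank = static_cast<int>(in.size());
    if(axis < 0)
        axis += rank;
    if(in[axis] != 1)
        throw std::invalid_argument("cannot squeeze a non-unit axis");
    std::vector<std::size_t> out;
    for(int i = 0; i < rank; ++i)
        if(i != axis)
            out.push_back(in[i]);
    return out;
}

// squeeze_lens({4, 1, 3, 1, 3}, -2) == {4, 1, 3, 3}, as in
// test_squeeze_negative_axis; squeeze_lens({4, 1, 3, 1, 3}, 0) throws, as in
// test_squeeze_wrong_axis.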
TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
expect_shape(output, migraphx::make_op("transpose"), input);
throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input);
}
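// Standalone sketch of why transposing a {2, 2} tensor with dims {1, 0}
// yields strides {1, 2}: a transpose permutes lens and strides together and
// never moves data (helper is illustrative only).
#include <cstddef>
#include <vector>

void permute_layout(std::vector<std::size_t>& lens,
                    std::vector<std::size_t>& strides,
                    const std::vector<std::size_t>& perm)
{
    std::vector<std::size_t> new_lens(lens.size());
    std::vector<std::size_t> new_strides(strides.size());
    for(std::size_t i = 0; i < perm.size(); ++i)
    {
        new_lens[i]    = lens[perm[i]];
        new_strides[i] = strides[perm[i]];
    }
    lens    = new_lens;
    strides = new_strides;
}

// For lens {2, 2} with row-major strides {2, 1}, perm {1, 0} keeps lens
// {2, 2} but makes strides {1, 2} -- the non-standard layout checked above.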
TEST_CASE(step_test)
...
@@ -3,6 +3,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/module.hpp>
#include <sstream>
#include <string>
#include <migraphx/make_op.hpp>
@@ -115,4 +116,36 @@ TEST_CASE(ops)
    EXPECT(names.size() > 1);
}
TEST_CASE(rnn)
{
migraphx::shape s{migraphx::shape::float_type, {2, 1}};
std::vector<float> data1(2, 2.0f);
std::vector<float> data2(2, 3.0f);
migraphx::argument a1(s, data1.data());
migraphx::argument a2(s, data2.data());
auto op = migraphx::make_op("rnn");
EXPECT(test::throws([&] { op.compute(s, {a1, a2}); }));
}
TEST_CASE(if_op)
{
migraphx::shape s{migraphx::shape::bool_type, {1}};
std::vector<char> data = {1};
migraphx::argument cond(s, data.data());
migraphx::shape sd{migraphx::shape::float_type, {2, 1}};
std::vector<float> data1(2, 2.0f);
std::vector<float> data2(2, 3.0f);
migraphx::argument a1(sd, data1.data());
migraphx::argument a2(sd, data2.data());
migraphx::module m("name");
auto l = m.add_literal(migraphx::literal(sd, data1));
m.add_return({l});
auto op = migraphx::make_op("add");
EXPECT(test::throws([&] { op.compute(s, {cond, a1, a2}, {&m, &m}, {}); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }