Unverified Commit e96d2b9a authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Module operations (#741)



* code backup

* clang format

* code backup

* change the print function to support print instruction from other modules

* clang format

* fix cppcheck error

* fix cppcheck error

* chang to make submodule to be owned by program instead of modules

* clang format

* add an unit test for copy of a program with sub_modules

* clang format

* remove the parent_module member variable from the module class

* clang format

* add unit test for serialization of program with submodules

* clang format

* Fix bug where instructions were not printed when doing TRACE_EVAL

* clang storage of modules from map to list

* clang format

* Formatting

* change the program assign function

* clang format

* code cleanup

* clang format

* backup code

* clang format

* remove unnecessary code

* clang format

* add module print function

* code backup

* refine the module::print function

* refine the module:to_value() function

* code backup

* backup code changes

* code backup

* remove to_value and from_value function from the module class

* rename a function

* rename the if operator

* refine the if operator

* refine the print function of module and program

* code backup

* code backup

* fix a build warning

* fix overload of compute_shape function

* code backup

* fix unit test error

* fix cppcheck error

* fix the issue related to the overload of compute_shape

* fix review comments

* fix cppcheck error

* change the return name of if_op to be if

* clang format

* fix two unit tests

* clang format

* remove the unused compute_op function

* clang format

* fix clang tidy format

* clang format

* enhance the validate function and uncomment a unit test

* clang format

* remove unnecessary code

* clang format

* fix a hang issue related to the valid function

* fix an issue in replace_refs

* clang format

* fix review comments

* clang format

* fix cppcheck error

* add a unit test for more code coverage

* clang format

* fix review comments and add test for more code coverage

* clang format

* fix cppcheck error

* fix a cppcheck error

* clang format

* fix cppcheck error

* clang format

* fix review comments

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 7b22d210
...@@ -94,6 +94,7 @@ register_migraphx_ops( ...@@ -94,6 +94,7 @@ register_migraphx_ops(
greater greater
gru gru
identity identity
if_op
im2col im2col
leaky_relu leaky_relu
less less
......
...@@ -34,19 +34,29 @@ typedef enum { ...@@ -34,19 +34,29 @@ typedef enum {
} migraphx_status; } migraphx_status;
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
typedef enum { typedef enum {
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t; } migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
/// Options to be passed when compiling
typedef struct typedef struct
{ {
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// offloaded memory and to copy the final result from the offloaded
/// memory back to main memory.
bool offload_copy; bool offload_copy;
/// Optimize math functions to use faster approximate versions. There may
/// be slight accuracy degredation when enabled.
bool fast_math; bool fast_math;
} migraphx_compile_options; } migraphx_compile_options;
/// Options for saving and loading files
typedef struct typedef struct
{ {
/// Format to be used for file. It can either be json or msgpack
const char* format; const char* format;
} migraphx_file_options; } migraphx_file_options;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/erase.hpp> #include <migraphx/erase.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -14,6 +15,9 @@ namespace migraphx { ...@@ -14,6 +15,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args); shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
shape compute_shape(const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods);
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args); std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
struct instruction struct instruction
...@@ -22,6 +26,11 @@ struct instruction ...@@ -22,6 +26,11 @@ struct instruction
instruction(operation o, shape r, std::vector<instruction_ref> args); instruction(operation o, shape r, std::vector<instruction_ref> args);
instruction(operation o,
shape r,
std::vector<instruction_ref> args,
std::vector<module_ref> modules);
instruction(literal l); instruction(literal l);
void replace(operation o); void replace(operation o);
...@@ -32,7 +41,7 @@ struct instruction ...@@ -32,7 +41,7 @@ struct instruction
friend bool operator==(const instruction& i, instruction_ref ref); friend bool operator==(const instruction& i, instruction_ref ref);
bool valid(instruction_ref start) const; bool valid(instruction_ref start, bool check_order = false) const;
bool valid() const; bool valid() const;
...@@ -45,6 +54,8 @@ struct instruction ...@@ -45,6 +54,8 @@ struct instruction
const std::vector<instruction_ref>& inputs() const; const std::vector<instruction_ref>& inputs() const;
const std::vector<module_ref>& module_inputs() const;
const std::vector<instruction_ref>& outputs() const; const std::vector<instruction_ref>& outputs() const;
friend bool operator==(const instruction& x, const instruction& y); friend bool operator==(const instruction& x, const instruction& y);
...@@ -65,13 +76,25 @@ struct instruction ...@@ -65,13 +76,25 @@ struct instruction
migraphx::erase(output, ins); migraphx::erase(output, ins);
} }
static void replace_refs(instruction_ref ins,
const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
const std::unordered_map<module_ref, module_ref>& map_mods);
static void backreference(instruction_ref ref); static void backreference(instruction_ref ref);
static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins); static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins);
static void replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod);
static void static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args); replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
static void replace(instruction_ref ins,
operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args);
bool can_eval() const; bool can_eval() const;
argument eval(bool check_eval = true) const; argument eval(bool check_eval = true) const;
...@@ -97,18 +120,31 @@ struct instruction ...@@ -97,18 +120,31 @@ struct instruction
// internal // internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args); void replace(operation o, const shape& r, std::vector<instruction_ref> args);
// internal
void replace(operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> mdl_args);
// internal // internal
void replace(std::vector<instruction_ref> args); void replace(std::vector<instruction_ref> args);
// internal
void replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args);
// internal // internal
void replace_argument(instruction_ref old, instruction_ref new_ins); void replace_argument(instruction_ref old, instruction_ref new_ins);
// internal
void replace_mod_argument(module_ref old, module_ref new_mod);
void replace(const shape& r); void replace(const shape& r);
operation op; operation op;
shape result{}; shape result{};
std::vector<instruction_ref> output; std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments; std::vector<instruction_ref> arguments;
std::vector<module_ref> module_args;
literal lit; literal lit;
bool normalized = false; bool normalized = false;
}; };
......
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP
#include <list> #include <list>
#include <unordered_set>
#include <unordered_map> #include <unordered_map>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp> #include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp> #include <migraphx/target.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -43,14 +45,19 @@ struct module ...@@ -43,14 +45,19 @@ struct module
std::string name() const; std::string name() const;
template <class... Ts> template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref add_instruction(operation op, Ts... args) instruction_ref add_instruction(operation op, Ts... args)
{ {
return add_instruction(op, {args...}); return add_instruction(op, {args...});
} }
instruction_ref add_instruction(const operation& op, std::vector<instruction_ref> args); instruction_ref add_instruction(const operation& op, std::vector<instruction_ref> args);
template <class... Ts> instruction_ref add_instruction(const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args) instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args)
{ {
return insert_instruction(ins, op, {args...}); return insert_instruction(ins, op, {args...});
...@@ -58,7 +65,12 @@ struct module ...@@ -58,7 +65,12 @@ struct module
instruction_ref instruction_ref
insert_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args); insert_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args);
template <class... Ts> instruction_ref insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args) instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args)
{ {
return replace_instruction(ins, op, {args...}); return replace_instruction(ins, op, {args...});
...@@ -67,6 +79,11 @@ struct module ...@@ -67,6 +79,11 @@ struct module
const operation& op, const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST; std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST;
instruction_ref replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST;
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep); instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
instruction_ref remove_instruction(instruction_ref ins); instruction_ref remove_instruction(instruction_ref ins);
...@@ -109,21 +126,30 @@ struct module ...@@ -109,21 +126,30 @@ struct module
void finalize(context& ctx); void finalize(context& ctx);
value to_value() const;
void from_value(const value& v);
void debug_print() const; void debug_print() const;
void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins) const;
void debug_print(instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names) const;
void debug_print(const std::vector<instruction_ref>& inss) const; void debug_print(const std::vector<instruction_ref>& inss) const;
std::unordered_map<instruction_ref, std::string> print(
const std::function<void(
instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>& print_func,
std::unordered_map<instruction_ref, std::string> names) const;
void print(const std::function<void(instruction_ref, void print(const std::function<void(instruction_ref,
const std::unordered_map<instruction_ref, std::string>&)>& const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const; print_func) const;
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const; void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
std::vector<module_ref> get_sub_modules() const;
module& sort(); module& sort();
friend std::ostream& operator<<(std::ostream& os, const module& m); friend std::ostream& operator<<(std::ostream& os, const module& m);
...@@ -132,6 +158,7 @@ struct module ...@@ -132,6 +158,7 @@ struct module
private: private:
void assign(const module& m); void assign(const module& m);
std::unique_ptr<module_impl> impl; std::unique_ptr<module_impl> impl;
}; };
......
#ifndef MIGRAPHX_GUARD_MODULE_REF_HPP
#define MIGRAPHX_GUARD_MODULE_REF_HPP
#include <list>
#include <functional>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
using module_ref = module*;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/module.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct if_op
{
std::string name() const { return "if"; }
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
check_shapes{inputs, *this}.has(1).standard();
if(mods.size() != 2)
{
MIGRAPHX_THROW("IF: operator should have two submodules.");
}
auto out_shapes0 = mods[0]->get_output_shapes();
auto out_shapes1 = mods[1]->get_output_shapes();
if(not std::equal(
out_shapes1.begin(), out_shapes1.end(), out_shapes0.begin(), out_shapes0.end()))
{
MIGRAPHX_THROW("IF: output shapes of submodules must be the same.");
}
return out_shapes0.front();
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp> #include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -100,13 +101,81 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -100,13 +101,81 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators } // namespace operation_operators
template <class T> template <class T>
shape normalize_compute_shape_op(T&& x, std::vector<shape> inputs) auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{ {
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens()); normalize_attributes(y, inputs[0].lens());
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs);
}
template <class T>
auto compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(x.compute_shape(inputs, mod_args))
{
return x.compute_shape(inputs, mod_args);
}
template <class T>
shape
compute_shape_op(rank<0>, const T& x, const std::vector<shape>&, const std::vector<module_ref>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
auto normalize_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
-> decltype(x.normalize_compute_shape(inputs, mod_args))
{
return x.normalize_compute_shape(inputs, mod_args);
}
template <class T>
shape normalize_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>&,
const std::vector<module_ref>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T> template <class T>
auto compute_op(rank<2>, auto compute_op(rank<2>,
const T& x, const T& x,
...@@ -278,11 +347,10 @@ void from_value_op(T& x, const value& v) ...@@ -278,11 +347,10 @@ void from_value_op(T& x, const value& v)
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const; * std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; * void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const; * shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
* argument compute(const shape& output,const std::vector<argument>& input) const; * mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
* value to_value() const; * input) const; argument compute(const shape& output,const std::vector<argument>& input)
* void from_value(const value& v) ; * const; value to_value() const; void from_value(const value& v) ; value attributes() const;
* value attributes() const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ; * friend bool operator==(const operation & x,const operation & y) ;
* }; * };
...@@ -394,6 +462,13 @@ struct operation ...@@ -394,6 +462,13 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(input); return (*this).private_detail_te_get_handle().compute_shape(input);
} }
shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute_shape(inputs, mod_args);
}
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -457,6 +532,8 @@ struct operation ...@@ -457,6 +532,8 @@ struct operation
virtual void virtual void
finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0; finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0; virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const = 0;
virtual argument virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0; compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0; virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
...@@ -561,6 +638,25 @@ struct operation ...@@ -561,6 +638,25 @@ struct operation
return detail::normalize_compute_shape_op(private_detail_te_self, input); return detail::normalize_compute_shape_op(private_detail_te_self, input);
} }
template <class T>
static auto private_detail_te_default_compute_shape(char,
T&& private_detail_te_self,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(private_detail_te_self.compute_shape(inputs, mod_args))
{
return private_detail_te_self.compute_shape(inputs, mod_args);
}
template <class T>
static shape private_detail_te_default_compute_shape(float,
T&& private_detail_te_self,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return detail::compute_shape_op(private_detail_te_self, inputs, mod_args);
}
template <class T> template <class T>
static auto private_detail_te_default_compute(char, static auto private_detail_te_default_compute(char,
T&& private_detail_te_self, T&& private_detail_te_self,
...@@ -709,6 +805,14 @@ struct operation ...@@ -709,6 +805,14 @@ struct operation
return private_detail_te_default_compute_shape(char(0), private_detail_te_value, input); return private_detail_te_default_compute_shape(char(0), private_detail_te_value, input);
} }
shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const override
{
return private_detail_te_default_compute_shape(
char(0), private_detail_te_value, inputs, mod_args);
}
argument compute(context& ctx, argument compute(context& ctx,
const shape& output, const shape& output,
const std::vector<argument>& input) const override const std::vector<argument>& input) const override
...@@ -841,6 +945,31 @@ inline auto compute_shape(const T& op, const std::vector<shape>& inputs) ...@@ -841,6 +945,31 @@ inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
return detail::normalize_compute_shape_op(op, inputs); return detail::normalize_compute_shape_op(op, inputs);
} }
inline shape compute_shape(const operation& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return op.compute_shape(inputs, mod_args);
}
template <class T>
inline auto compute_shape(const T& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(op.compute_shape(inputs, mod_args))
{
return op.compute_shape(inputs, mod_args);
}
template <class T>
inline auto compute_shape(const T& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(op.normalize_compute_shape(inputs, mod_args))
{
return detail::normalize_compute_shape_op(op, inputs, mod_args);
}
inline bool is_context_free(const operation& op) { return op.is_context_free(); } inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T> template <class T>
......
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include <migraphx/op/greater.hpp> #include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp> #include <migraphx/op/gru.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/im2col.hpp> #include <migraphx/op/im2col.hpp>
#include <migraphx/op/leaky_relu.hpp> #include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/less.hpp> #include <migraphx/op/less.hpp>
......
...@@ -72,8 +72,9 @@ struct program ...@@ -72,8 +72,9 @@ struct program
void debug_print() const; void debug_print() const;
void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins) const;
void print(const std::function<void(instruction_ref, void print(std::unordered_map<instruction_ref, std::string>& names,
const std::unordered_map<instruction_ref, std::string>&)>& const std::function<void(instruction_ref,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const; print_func) const;
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
...@@ -89,9 +90,16 @@ struct program ...@@ -89,9 +90,16 @@ struct program
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
// module related api
module* create_module(const std::string& name);
module* get_module(const std::string& name);
const module* get_module(const std::string& name) const;
module* get_main_module(); module* get_main_module();
const module* get_main_module() const; const module* get_main_module() const;
std::vector<const module*> get_modules() const;
private: private:
void assign(const program& p); void assign(const program& p);
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
......
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp> #include <migraphx/builtin.hpp>
#include <migraphx/erase.hpp> #include <migraphx/erase.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -10,6 +12,17 @@ instruction::instruction(operation o, shape r, std::vector<instruction_ref> args ...@@ -10,6 +12,17 @@ instruction::instruction(operation o, shape r, std::vector<instruction_ref> args
{ {
} }
instruction::instruction(operation o,
shape r,
std::vector<instruction_ref> args,
std::vector<module_ref> modules)
: op(std::move(o)),
result(std::move(r)),
arguments(std::move(args)),
module_args(std::move(modules))
{
}
instruction::instruction(literal l) instruction::instruction(literal l)
: op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
{ {
...@@ -38,7 +51,7 @@ void instruction::replace(operation o) ...@@ -38,7 +51,7 @@ void instruction::replace(operation o)
recompute_shape(); recompute_shape();
} }
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); } void instruction::recompute_shape() { replace(compute_shape(op, arguments, module_args)); }
void instruction::clear_arguments() void instruction::clear_arguments()
{ {
...@@ -47,6 +60,7 @@ void instruction::clear_arguments() ...@@ -47,6 +60,7 @@ void instruction::clear_arguments()
arg->remove_output(*this); arg->remove_output(*this);
} }
arguments.clear(); arguments.clear();
module_args.clear();
} }
bool operator==(const instruction& i, instruction_ref ref) bool operator==(const instruction& i, instruction_ref ref)
...@@ -54,12 +68,16 @@ bool operator==(const instruction& i, instruction_ref ref) ...@@ -54,12 +68,16 @@ bool operator==(const instruction& i, instruction_ref ref)
return std::addressof(i) == std::addressof(*ref); return std::addressof(i) == std::addressof(*ref);
} }
bool instruction::valid(instruction_ref start) const bool instruction::valid(instruction_ref start, bool check_order) const
{ {
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) { return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
auto self = std::find(i->outputs().begin(), i->outputs().end(), *this); auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
return self != i->outputs().end() && bool ret = self != i->outputs().end();
std::distance(start, i) < std::distance(start, *self); if(check_order)
{
ret = ret and (std::distance(start, i) < std::distance(start, *self));
}
return ret;
}); });
} }
...@@ -82,7 +100,7 @@ bool instruction::valid() const ...@@ -82,7 +100,7 @@ bool instruction::valid() const
{ {
try try
{ {
computed = compute_shape(op, arguments); computed = compute_shape(op, arguments, module_args);
} }
catch(migraphx::exception&) catch(migraphx::exception&)
{ {
...@@ -90,7 +108,8 @@ bool instruction::valid() const ...@@ -90,7 +108,8 @@ bool instruction::valid() const
} }
} }
return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) { return (result == computed) &&
std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end(); return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
}); });
} }
...@@ -108,6 +127,8 @@ std::string instruction::name() const { return op.name(); } ...@@ -108,6 +127,8 @@ std::string instruction::name() const { return op.name(); }
const std::vector<instruction_ref>& instruction::inputs() const { return arguments; } const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }
const std::vector<module_ref>& instruction::module_inputs() const { return module_args; }
const std::vector<instruction_ref>& instruction::outputs() const { return output; } const std::vector<instruction_ref>& instruction::outputs() const { return output; }
bool operator==(const instruction& x, const instruction& y) bool operator==(const instruction& x, const instruction& y)
...@@ -148,6 +169,13 @@ void instruction::replace_argument(instruction_ref ins, ...@@ -148,6 +169,13 @@ void instruction::replace_argument(instruction_ref ins,
ins->recompute_shape(); ins->recompute_shape();
} }
void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod)
{
ins->replace_mod_argument(old, new_mod);
backreference(ins);
ins->recompute_shape();
}
void instruction::replace(instruction_ref ins, void instruction::replace(instruction_ref ins,
operation o, operation o,
const shape& r, const shape& r,
...@@ -157,6 +185,16 @@ void instruction::replace(instruction_ref ins, ...@@ -157,6 +185,16 @@ void instruction::replace(instruction_ref ins,
backreference(ins); backreference(ins);
} }
void instruction::replace(instruction_ref ins,
operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
ins->replace(std::move(o), r, std::move(args), std::move(module_args));
backreference(ins);
}
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args) void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{ {
normalized = false; normalized = false;
...@@ -165,12 +203,56 @@ void instruction::replace(operation o, const shape& r, std::vector<instruction_r ...@@ -165,12 +203,56 @@ void instruction::replace(operation o, const shape& r, std::vector<instruction_r
replace(std::move(args)); replace(std::move(args));
} }
void instruction::replace(operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> mdl_args)
{
op = std::move(o);
replace(r);
replace(std::move(args), std::move(mdl_args));
}
void instruction::replace_refs(
instruction_ref ins,
const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
const std::unordered_map<module_ref, module_ref>& map_mods)
{
const auto& args = ins->inputs();
for(const auto& arg : args)
{
if(contains(map_insts, arg))
{
instruction::replace_argument(ins, arg, map_insts.at(arg));
}
}
const auto& module_args = ins->module_inputs();
if(module_args.empty())
return;
for(const auto& mod : module_args)
{
if(contains(map_mods, mod))
{
instruction::replace_mod_argument(ins, mod, map_mods.at(mod));
}
}
}
void instruction::replace(std::vector<instruction_ref> args) void instruction::replace(std::vector<instruction_ref> args)
{ {
clear_arguments(); clear_arguments();
arguments = std::move(args); arguments = std::move(args);
} }
void instruction::replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args)
{
clear_arguments();
arguments = std::move(args);
module_args = std::move(mdl_args);
}
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins) void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{ {
assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; })); assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; }));
...@@ -178,6 +260,12 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins) ...@@ -178,6 +260,12 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this); old->remove_output(*this);
} }
void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
{
assert(std::any_of(module_args.begin(), module_args.end(), [&](auto i) { return i == old; }));
std::replace(module_args.begin(), module_args.end(), old, new_mod);
}
bool instruction::can_eval() const bool instruction::can_eval() const
{ {
if(op.name() == "@literal") if(op.name() == "@literal")
...@@ -242,12 +330,25 @@ void instruction::print(std::ostream& os, ...@@ -242,12 +330,25 @@ void instruction::print(std::ostream& os,
char delim = '('; char delim = '(';
for(auto&& arg : ins->inputs()) for(auto&& arg : ins->inputs())
{ {
os << delim << names.at(arg); std::string arg_name = contains(names, arg) ? names.at(arg) : "?";
os << delim << arg_name;
delim = ','; delim = ',';
} }
os << ")"; os << ")";
} }
// print module inputs
if(!ins->module_inputs().empty())
{
std::string delim = ", [";
for(auto&& mod_arg : ins->module_inputs())
{
os << delim << mod_arg->name();
delim = ", ";
}
os << "]";
}
// skip return instruction shape // skip return instruction shape
if(ins->name() != "@return") if(ins->name() != "@return")
os << " -> " << ins->get_shape(); os << " -> " << ins->get_shape();
...@@ -328,5 +429,18 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg ...@@ -328,5 +429,18 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg
return op.compute_shape(to_shapes(args)); return op.compute_shape(to_shapes(args));
} }
shape compute_shape(const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods)
{
if(mods.empty())
{
return op.compute_shape(to_shapes(args));
}
else
{
return op.compute_shape(to_shapes(args), mods);
}
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -88,14 +88,13 @@ void module::assign(const module& m) ...@@ -88,14 +88,13 @@ void module::assign(const module& m)
} }
else else
{ {
// if there are sub_module inputs, need to make a copy of the submodule
auto module_args = ins->module_inputs();
// retrieve its mapped input // retrieve its mapped input
auto inputs = ins->inputs(); auto inputs = ins->inputs();
// ensure all inputs have its corresponding copy instructions
assert(std::all_of(
inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; }));
std::vector<instruction_ref> copy_inputs(inputs.size()); std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) { std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return ins_map[i]; return contains(ins_map, i) ? ins_map[i] : i;
}); });
if(ins->name() == "@return") if(ins->name() == "@return")
{ {
...@@ -103,7 +102,14 @@ void module::assign(const module& m) ...@@ -103,7 +102,14 @@ void module::assign(const module& m)
} }
else else
{ {
copy_ins = add_instruction(ins->get_operator(), copy_inputs); if(module_args.empty())
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
else
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
}
} }
} }
...@@ -119,9 +125,6 @@ instruction_ref module::insert_instruction(instruction_ref ins, ...@@ -119,9 +125,6 @@ instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op, const operation& op,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args); shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)}); auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
...@@ -130,13 +133,32 @@ instruction_ref module::insert_instruction(instruction_ref ins, ...@@ -130,13 +133,32 @@ instruction_ref module::insert_instruction(instruction_ref ins,
return result; return result;
} }
instruction_ref module::add_instruction(const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
return insert_instruction(
impl->instructions.end(), op, std::move(args), std::move(module_args));
}
instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
auto result =
impl->instructions.insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op, const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{ {
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args); shape r = compute_shape(op, args);
...@@ -145,6 +167,18 @@ instruction_ref module::replace_instruction(instruction_ref ins, ...@@ -145,6 +167,18 @@ instruction_ref module::replace_instruction(instruction_ref ins,
return ins; return ins;
} }
instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST
{
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args));
assert(ins->valid(begin()));
return ins;
}
instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep) instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep)
{ {
assert(has_instruction(ins)); assert(has_instruction(ins));
...@@ -320,10 +354,18 @@ std::unordered_map<std::string, shape> module::get_parameter_shapes() const ...@@ -320,10 +354,18 @@ std::unordered_map<std::string, shape> module::get_parameter_shapes() const
bool module::has_instruction(instruction_ref ins) const bool module::has_instruction(instruction_ref ins) const
{ {
return std::find_if( if(std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) { impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
return std::addressof(*ins) == std::addressof(x); return std::addressof(*ins) == std::addressof(x);
}) != impl->instructions.end(); }) != impl->instructions.end())
{
return true;
}
auto parent_modules = get_sub_modules();
return std::any_of(parent_modules.begin(), parent_modules.end(), [&](auto mod) {
return mod->has_instruction(ins);
});
} }
std::size_t module::size() const { return impl->instructions.size(); } std::size_t module::size() const { return impl->instructions.size(); }
...@@ -353,9 +395,15 @@ std::vector<shape> module::get_output_shapes() const ...@@ -353,9 +395,15 @@ std::vector<shape> module::get_output_shapes() const
instruction_ref module::validate() const instruction_ref module::validate() const
{ {
return std::find_if(impl->instructions.begin(), return std::find_if(
impl->instructions.end(), impl->instructions.begin(), impl->instructions.end(), [&](const instruction& i) {
[&](const instruction& i) { return !i.valid(impl->instructions.begin()); }); auto inputs = i.inputs();
bool check_order = std::all_of(inputs.begin(), inputs.end(), [&](auto in) {
return contains(impl->instructions, *in);
});
return !i.valid(impl->instructions.begin(), check_order);
});
} }
void module::finalize(context& ctx) void module::finalize(context& ctx)
...@@ -363,7 +411,12 @@ void module::finalize(context& ctx) ...@@ -363,7 +411,12 @@ void module::finalize(context& ctx)
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
ins->finalize(ctx); ins->finalize(ctx);
for(const auto& smod : ins->module_inputs())
{
smod->finalize(ctx);
}
} }
// Warn when an instruction is not normalized // Warn when an instruction is not normalized
auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); }); auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); });
if(ins != end()) if(ins != end())
...@@ -371,69 +424,10 @@ void module::finalize(context& ctx) ...@@ -371,69 +424,10 @@ void module::finalize(context& ctx)
<< std::endl; << std::endl;
} }
value module::to_value() const
{
value result;
value nodes;
this->print([&](auto ins, const auto& names) {
value node;
node["output"] = names.at(ins);
node["name"] = ins->name();
node["shape"] = migraphx::to_value(ins->get_shape());
node["normalized"] = ins->is_normalized();
if(ins->name() == "@literal")
node["literal"] = migraphx::to_value(ins->get_literal());
node["operator"] = ins->get_operator().to_value();
std::vector<std::string> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) { return names.at(i); });
node["inputs"] = inputs;
nodes.push_back(node);
});
result["nodes"] = nodes;
return result;
}
void module::from_value(const value& v)
{
std::unordered_map<std::string, instruction_ref> instructions;
for(const value& node : v.at("nodes"))
{
instruction_ref output;
auto name = node.at("name").to<std::string>();
auto fields = node.at("operator");
auto normalized = node.at("normalized").to<bool>();
if(name == "@param")
{
output = this->add_parameter(fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape")));
}
else if(name == "@literal")
{
output = this->add_literal(migraphx::from_value<literal>(node.at("literal")));
}
else
{
auto op = make_op(name, fields);
std::vector<instruction_ref> inputs;
std::transform(node.at("inputs").begin(),
node.at("inputs").end(),
std::back_inserter(inputs),
[&](const value& i) { return instructions[i.to<std::string>()]; });
if(name == "@return")
output = this->add_return(inputs);
else
output = this->add_instruction(op, inputs);
}
output->set_normalized(normalized);
instructions[node.at("output").to<std::string>()] = output;
}
}
void module::debug_print() const { std::cout << *this << std::endl; } void module::debug_print() const { std::cout << *this << std::endl; }
void module::debug_print(instruction_ref ins) const
void module::debug_print(instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names) const
{ {
if(ins == this->end()) if(ins == this->end())
{ {
...@@ -446,14 +440,23 @@ void module::debug_print(instruction_ref ins) const ...@@ -446,14 +440,23 @@ void module::debug_print(instruction_ref ins) const
return; return;
} }
std::stringstream ss; std::stringstream ss;
this->print([&](auto x, const auto& names) { this->print(
if(x == ins) [&](auto x, auto ins_names) {
{ if(x == ins)
instruction::print(std::cout, x, names); {
std::cout << std::endl; instruction::print(std::cout, x, ins_names);
} std::cout << std::endl;
}); }
},
names);
} }
void module::debug_print(instruction_ref ins) const
{
std::unordered_map<instruction_ref, std::string> names;
this->debug_print(ins, names);
}
void module::debug_print(const std::vector<instruction_ref>& inss) const void module::debug_print(const std::vector<instruction_ref>& inss) const
{ {
for(auto ins : inss) for(auto ins : inss)
...@@ -461,13 +464,12 @@ void module::debug_print(const std::vector<instruction_ref>& inss) const ...@@ -461,13 +464,12 @@ void module::debug_print(const std::vector<instruction_ref>& inss) const
std::cout << std::endl; std::cout << std::endl;
} }
void module::print(const std::function< std::unordered_map<instruction_ref, std::string> module::print(
void(instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>& const std::function<void(instruction_ref,
print_func) const const std::unordered_map<instruction_ref, std::string>&)>& print_func,
std::unordered_map<instruction_ref, std::string> names) const
{ {
std::unordered_map<instruction_ref, std::string> names;
int count = 0; int count = 0;
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
std::string var_name; std::string var_name;
...@@ -477,18 +479,21 @@ void module::print(const std::function< ...@@ -477,18 +479,21 @@ void module::print(const std::function<
} }
else else
{ {
var_name = "@" + std::to_string(count); var_name = this->name() + ":@" + std::to_string(count);
count++; count++;
} }
names.emplace(ins, var_name); names.emplace(ins, var_name);
assert(std::all_of(ins->inputs().begin(),
ins->inputs().end(),
[&](auto arg) { return this->has_instruction(arg); }) &&
"DEBUG_PRINT: Instruction not found");
print_func(ins, names); print_func(ins, names);
} }
return names;
}
void module::print(const std::function<
void(instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const
{
this->print(print_func, {});
} }
static std::string enclose_name(const std::string& name) static std::string enclose_name(const std::string& name)
...@@ -500,19 +505,20 @@ void module::print_graph(std::ostream& os, bool brief) const ...@@ -500,19 +505,20 @@ void module::print_graph(std::ostream& os, bool brief) const
{ {
os << "digraph {" << std::endl; os << "digraph {" << std::endl;
os << "\trankdir=LR;" << std::endl; os << "\trankdir=LR;" << std::endl;
this->print([&](auto ins, const auto& names) { this->print([&](auto ins, auto ins_names) {
std::string label; std::string label;
if(brief) if(brief)
label = ins->name(); label = ins->name();
else else
label = to_string(ins->get_operator()); label = to_string(ins->get_operator());
os << "\t" << enclose_name(names.at(ins)) << "[label=" << enclose_name(label) << "]"; os << "\t" << enclose_name(ins_names.at(ins)) << "[label=" << enclose_name(label) << "]";
os << ";" << std::endl; os << ";" << std::endl;
if(!ins->inputs().empty()) if(!ins->inputs().empty())
{ {
for(auto&& arg : ins->inputs()) for(auto&& arg : ins->inputs())
{ {
os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins)); os << "\t" << enclose_name(ins_names.at(arg)) << " -> "
<< enclose_name(ins_names.at(ins));
if(not brief) if(not brief)
os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]"; os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]";
os << ";" << std::endl; os << ";" << std::endl;
...@@ -571,70 +577,100 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) ...@@ -571,70 +577,100 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os << "}"; os << "}";
} }
void module::print_cpp(std::ostream& os) const std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const
{ {
os << "migraphx::module p;" << std::endl; os << "migraphx::module p;" << std::endl;
// cppcheck-suppress variableScope // cppcheck-suppress variableScope
unsigned long seed = 0; unsigned long seed = 0;
this->print([&](auto ins, const auto& names) { names = this->print(
auto op = cpp_op_var(names.at(ins), ins); [&](auto ins, auto ins_names) {
if(ins->name().front() != '@') auto op = cpp_op_var(ins_names.at(ins), ins);
{ if(ins->name().front() != '@')
os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl;
print_op_attributes(os, op, ins->get_operator());
}
os << "auto " << cpp_var_name(names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << "p.add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
if(use_abs)
os << "migraphx::abs(";
os << "migraphx::generate_literal(";
print_cpp_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ");" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl;
}
else
{
os << "p.add_instruction(" << op;
for(auto input : ins->inputs())
{ {
os << ", " << cpp_var_name(names.at(input)); os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl;
print_op_attributes(os, op, ins->get_operator());
} }
os << ");" << std::endl; os << "auto " << cpp_var_name(ins_names.at(ins)) << " = ";
} if(ins->name() == "@literal")
}); {
os << "p.add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
if(use_abs)
os << "migraphx::abs(";
os << "migraphx::generate_literal(";
print_cpp_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ");" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl;
}
else
{
os << "p.add_instruction(" << op;
for(auto input : ins->inputs())
{
os << ", " << cpp_var_name(ins_names.at(input));
}
os << ");" << std::endl;
}
},
names);
return names;
} }
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{ {
this->print([&](auto ins, const auto& names) { this->print([&](auto ins, auto ins_names) {
instruction::print(os, ins, names); instruction::print(os, ins, ins_names);
a(ins); a(ins);
os << std::endl; os << std::endl;
}); });
} }
std::vector<module_ref> module::get_sub_modules() const
{
std::vector<module_ref> vec_modules;
for(auto ins : iterator_for(*this))
{
const auto& mod_args = ins->module_inputs();
vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end());
for(const auto& smod : mod_args)
{
auto sub_mods = smod->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end());
}
}
return vec_modules;
}
module& module::sort() module& module::sort()
{ {
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
this->move_instruction(ins, this->begin()); this->move_instruction(ins, this->begin());
for(auto child : ins->inputs()) for(auto child : ins->inputs())
{
if(!contains(this->impl->instructions, child))
{
continue;
}
self(child); self(child);
}
})(std::prev(this->end())); })(std::prev(this->end()));
assert(this->validate() == this->end()); assert(this->validate() == this->end());
return *this; return *this;
...@@ -644,10 +680,11 @@ bool operator==(const module& x, const module& y) { return to_string(x) == to_st ...@@ -644,10 +680,11 @@ bool operator==(const module& x, const module& y) { return to_string(x) == to_st
std::ostream& operator<<(std::ostream& os, const module& m) std::ostream& operator<<(std::ostream& os, const module& m)
{ {
m.print([&](auto ins, const auto& names) { m.print([&](auto ins, auto ins_names) {
instruction::print(os, ins, names); instruction::print(os, ins, ins_names);
os << std::endl; os << std::endl;
}); });
return os; return os;
} }
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
...@@ -25,12 +26,13 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -25,12 +26,13 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program_impl struct program_impl
{ {
// A map is used to keep references to modules of the program // A map is used to keep references to modules of the program
std::map<std::string, module> modules; // all the modules are store in the depth-first order
std::list<module> modules;
context ctx; context ctx;
std::string target_name; std::string target_name;
}; };
program::program() : impl(std::make_unique<program_impl>()) { impl->modules["main"] = {"main"}; } program::program() : impl(std::make_unique<program_impl>()) { impl->modules.push_back({"main"}); }
program::program(program&&) noexcept = default; program::program(program&&) noexcept = default;
program::~program() noexcept = default; program::~program() noexcept = default;
...@@ -55,9 +57,38 @@ void program::assign(const program& p) ...@@ -55,9 +57,38 @@ void program::assign(const program& p)
{ {
impl->modules.clear(); impl->modules.clear();
} }
impl->ctx = p.impl->ctx; impl->ctx = p.impl->ctx;
impl->target_name = p.impl->target_name; impl->target_name = p.impl->target_name;
impl->modules = p.impl->modules; impl->modules = p.impl->modules;
// build a map from old ins to new ins
// Build a map from old module to new module
std::unordered_map<module_ref, module_ref> mod_map;
std::transform(impl->modules.begin(),
impl->modules.end(),
p.impl->modules.begin(),
std::inserter(mod_map, mod_map.begin()),
[](auto&& x, auto&& y) { return std::make_pair(&y, &x); });
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto&& pp : mod_map)
{
auto old_ins = iterator_for(*pp.first);
auto new_ins = iterator_for(*pp.second);
std::transform(old_ins.begin(),
old_ins.end(),
new_ins.begin(),
std::inserter(ins_map, ins_map.begin()),
[](auto x, auto y) { return std::make_pair(x, y); });
}
// Update all references from all modules
for(auto&& mp : impl->modules)
{
for(auto ins : iterator_for(mp))
instruction::replace_refs(ins, ins_map, mod_map);
}
} }
shape program::get_parameter_shape(std::string name) const shape program::get_parameter_shape(std::string name) const
...@@ -114,28 +145,23 @@ void program::compile(const target& t, compile_options options) ...@@ -114,28 +145,23 @@ void program::compile(const target& t, compile_options options)
options.trace(); options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options); auto&& passes = t.get_passes(this->impl->ctx, options);
for(auto& mp : impl->modules) auto* modl = get_main_module();
assert(modl->validate() == modl->end());
run_passes(*modl, passes, options.trace);
auto invalid = this->validate();
if(invalid != modl->end())
{ {
auto& modl = mp.second; auto index = std::distance(modl->begin(), invalid);
assert(modl.validate() == modl.end()); MIGRAPHX_THROW("Invalid module " + modl->name() + " from compilation at instruction " +
run_passes(modl, passes, options.trace); std::to_string(index));
auto invalid = this->validate();
if(invalid != modl.end())
{
auto index = std::distance(modl.begin(), invalid);
MIGRAPHX_THROW("Invalid module " + mp.first + " from compilation at instruction " +
std::to_string(index));
}
modl.finalize(this->impl->ctx);
} }
modl->finalize(this->impl->ctx);
} }
void program::finalize() void program::finalize()
{ {
for(auto& mp : this->impl->modules) auto* mm = this->get_main_module();
{ mm->finalize(this->impl->ctx);
mp.second.finalize(this->impl->ctx);
}
} }
template <class F> template <class F>
...@@ -263,15 +289,128 @@ value program::to_value() const ...@@ -263,15 +289,128 @@ value program::to_value() const
if(not this->impl->target_name.empty()) if(not this->impl->target_name.empty())
result["context"] = this->impl->ctx.to_value(); result["context"] = this->impl->ctx.to_value();
result["modules"] = value::object{}; value module_vals = value::array{};
auto& module_val = result.at("modules"); std::unordered_map<instruction_ref, std::string> names;
for(auto& m : impl->modules) for(auto& mod : this->impl->modules)
{ {
module_val[m.first] = m.second.to_value(); value mod_val;
value nodes;
mod_val["name"] = mod.name();
names = mod.print(
[&](auto ins, auto ins_names) {
value node;
node["output"] = ins_names.at(ins);
node["name"] = ins->name();
node["shape"] = migraphx::to_value(ins->get_shape());
node["normalized"] = ins->is_normalized();
if(ins->name() == "@literal")
node["literal"] = migraphx::to_value(ins->get_literal());
node["operator"] = ins->get_operator().to_value();
std::vector<std::string> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) {
assert(contains(ins_names, i));
return ins_names.at(i);
});
node["inputs"] = inputs;
auto module_args = ins->module_inputs();
if(not module_args.empty())
{
std::vector<std::string> module_inputs;
std::transform(module_args.begin(),
module_args.end(),
std::back_inserter(module_inputs),
[&](auto mod_ref) { return mod_ref->name(); });
node["module_inputs"] = module_inputs;
}
nodes.push_back(node);
},
names);
mod_val["nodes"] = nodes;
module_vals.push_back(mod_val);
} }
result["modules"] = module_vals;
return result; return result;
} }
static void mod_from_val(module_ref mod,
const value& v,
std::unordered_map<std::string, instruction_ref>& instructions,
const std::unordered_map<std::string, module_ref>& map_mods)
{
const auto* it = std::find_if(v.begin(), v.end(), [&](auto& mv) {
return mv.at("name").template to<std::string>() == mod->name();
});
assert(it != v.end());
const auto& module_val = *it;
for(const value& node : module_val.at("nodes"))
{
instruction_ref output;
auto name = node.at("name").to<std::string>();
auto fields = node.at("operator");
auto normalized = node.at("normalized").to<bool>();
if(name == "@param")
{
output = mod->add_parameter(fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape")));
}
else if(name == "@literal")
{
output = mod->add_literal(migraphx::from_value<literal>(node.at("literal")));
}
else
{
auto op = make_op(name, fields);
std::vector<instruction_ref> inputs;
std::transform(node.at("inputs").begin(),
node.at("inputs").end(),
std::back_inserter(inputs),
[&](const value& i) {
auto i_name = i.to<std::string>();
assert(contains(instructions, i_name));
return instructions.at(i_name);
});
std::vector<module_ref> module_inputs;
if(node.contains("module_inputs"))
{
std::transform(node.at("module_inputs").begin(),
node.at("module_inputs").end(),
std::back_inserter(module_inputs),
[&](const value& i) { return map_mods.at(i.to<std::string>()); });
for(auto& smod : module_inputs)
{
mod_from_val(smod, v, instructions, map_mods);
}
}
if(name == "@return")
{
output = mod->add_return(inputs);
}
else if(module_inputs.empty())
{
output = mod->add_instruction(op, inputs);
}
else
{
output = mod->add_instruction(op, inputs, module_inputs);
}
}
output->set_normalized(normalized);
instructions[node.at("output").to<std::string>()] = output;
}
}
void program::from_value(const value& v) void program::from_value(const value& v)
{ {
auto version = v.at("version").to<int>(); auto version = v.at("version").to<int>();
...@@ -288,15 +427,21 @@ void program::from_value(const value& v) ...@@ -288,15 +427,21 @@ void program::from_value(const value& v)
this->impl->ctx.from_value(v.at("context")); this->impl->ctx.from_value(v.at("context"));
} }
auto val_modules = v.at("modules"); auto module_vals = v.at("modules");
for(const auto& vv : val_modules) std::unordered_map<std::string, module_ref> map_mods;
for(const auto& vv : module_vals)
{ {
const auto& key = vv.get_key(); const auto& name = vv.at("name").to<std::string>();
auto val = vv.without_key(); if(name == "main")
module modl{key}; continue;
modl.from_value(val); impl->modules.push_back({name});
impl->modules[key] = modl; map_mods[name] = &impl->modules.back();
} }
std::unordered_map<std::string, instruction_ref> map_insts;
auto* mm = get_main_module();
mod_from_val(mm, module_vals, map_insts, map_mods);
this->finalize(); this->finalize();
} }
...@@ -368,8 +513,9 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -368,8 +513,9 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
double calculate_overhead_time = total_time - total_instruction_time; double calculate_overhead_time = total_time - total_instruction_time;
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time; double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
this->print([&](auto ins, auto names) { std::unordered_map<instruction_ref, std::string> names;
instruction::print(std::cout, ins, names); this->print(names, [&](auto ins, auto ins_names) {
instruction::print(std::cout, ins, ins_names);
// skip return instruction // skip return instruction
if(ins->name() == "@return") if(ins->name() == "@return")
...@@ -411,8 +557,9 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -411,8 +557,9 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
void program::debug_print() const { std::cout << *this << std::endl; } void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const void program::debug_print(instruction_ref ins) const
{ {
std::unordered_map<instruction_ref, std::string> names;
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& it) { if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& it) {
return (it.second.end() == ins); return (it.end() == ins);
})) }))
{ {
std::cout << "End instruction" << std::endl; std::cout << "End instruction" << std::endl;
...@@ -420,29 +567,31 @@ void program::debug_print(instruction_ref ins) const ...@@ -420,29 +567,31 @@ void program::debug_print(instruction_ref ins) const
} }
else if(std::none_of(this->impl->modules.begin(), else if(std::none_of(this->impl->modules.begin(),
this->impl->modules.end(), this->impl->modules.end(),
[&](const auto& it) { return it.second.has_instruction(ins); })) [&](const auto& it) { return it.has_instruction(ins); }))
{ {
std::cout << "Instruction not part of program" << std::endl; std::cout << "Instruction not part of program" << std::endl;
return; return;
} }
std::stringstream ss; std::stringstream ss;
this->print([&](auto x, const auto& names) { this->print(names, [&](auto x, auto ins_names) {
if(x == ins) if(x == ins)
{ {
instruction::print(std::cout, x, names); instruction::print(std::cout, x, ins_names);
std::cout << std::endl; std::cout << std::endl;
} }
}); });
} }
void program::print(const std::function< void program::print(
void(instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>& std::unordered_map<instruction_ref, std::string>& names,
print_func) const const std::function<void(instruction_ref, std::unordered_map<instruction_ref, std::string>)>&
print_func) const
{ {
for(const auto& mdl : this->impl->modules) for(const auto& mod : this->impl->modules)
{ {
mdl.second.print(print_func); std::cout << mod.name() << ":" << std::endl;
mod.print(print_func, names);
} }
} }
...@@ -454,9 +603,14 @@ void program::print_graph(std::ostream& os, bool brief) const ...@@ -454,9 +603,14 @@ void program::print_graph(std::ostream& os, bool brief) const
void program::print_cpp(std::ostream& os) const void program::print_cpp(std::ostream& os) const
{ {
os << "migraphx::program p;" << std::endl; auto vec_modules = this->get_modules();
const auto* mm = this->get_main_module(); std::unordered_map<instruction_ref, std::string> names;
mm->print_cpp(os); for(auto& mod : vec_modules)
{
os << "module: \"" << mod->name() << "\"" << std::endl;
names = mod->print_cpp(os, names);
os << std::endl;
}
} }
void program::dry_run(std::unordered_map<std::string, argument> params) const void program::dry_run(std::unordered_map<std::string, argument> params) const
...@@ -467,22 +621,63 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const ...@@ -467,22 +621,63 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
{ {
for(auto& modl : this->impl->modules) for(auto& mod : this->impl->modules)
{ {
std::cout << modl.first << ":" << std::endl; std::cout << mod.name() << ":" << std::endl;
modl.second.annotate(os, a); mod.annotate(os, a);
} }
} }
module* program::get_main_module() { return &impl->modules.at("main"); } const module* program::get_module(const std::string& name) const
{
auto it = std::find_if(
impl->modules.begin(), impl->modules.end(), [&](auto& m) { return (m.name() == name); });
if(it == impl->modules.end())
{
return nullptr;
}
const module* program::get_main_module() const { return &impl->modules.at("main"); } return &(*it);
}
module* program::create_module(const std::string& name)
{
auto it = impl->modules.insert(impl->modules.end(), {name});
return &(*it);
}
module* program::get_module(const std::string& name)
{
auto it = std::find_if(
impl->modules.begin(), impl->modules.end(), [&](auto& m) { return (m.name() == name); });
if(it == impl->modules.end())
{
return nullptr;
}
return &(*it);
}
module* program::get_main_module() { return get_module("main"); }
const module* program::get_main_module() const { return get_module("main"); }
std::vector<const module*> program::get_modules() const
{
const module* mm = get_main_module();
std::vector<const module*> vec_modules;
vec_modules.push_back(mm);
auto sub_modules = mm->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end());
return vec_modules;
}
program& program::sort() program& program::sort()
{ {
for(auto& modl : this->impl->modules) for(auto& mod : this->impl->modules)
{ {
modl.second.sort(); mod.sort();
} }
return *this; return *this;
...@@ -492,10 +687,17 @@ bool operator==(const program& x, const program& y) { return to_string(x) == to_ ...@@ -492,10 +687,17 @@ bool operator==(const program& x, const program& y) { return to_string(x) == to_
std::ostream& operator<<(std::ostream& os, const program& p) std::ostream& operator<<(std::ostream& os, const program& p)
{ {
for(auto& mp : p.impl->modules) auto vec_modules = p.get_modules();
std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : vec_modules)
{ {
os << "Module " << mp.first << ": " << std::endl; os << "module: \"" << mod->name() << "\"" << std::endl;
os << mp.second; names = mod->print(
[&](auto ins, auto ins_names) {
instruction::print(os, ins, ins_names);
os << std::endl;
},
names);
os << std::endl; os << std::endl;
} }
......
...@@ -79,6 +79,30 @@ struct pass_op ...@@ -79,6 +79,30 @@ struct pass_op
return {}; return {};
return inputs.front(); return inputs.front();
} }
int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};
struct mod_pass_op
{
std::string name() const { return "mod_pass"; }
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs,
std::vector<migraphx::module_ref> mods) const
{
if(!mods.empty())
{
auto out_shapes = mods[0]->get_output_shapes();
return out_shapes[0];
}
if(!inputs.empty())
{
return inputs.front();
}
return {};
}
int output_alias(const std::vector<migraphx::shape>&) const { return 0; } int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
}; };
......
...@@ -96,6 +96,7 @@ TEST_CASE(module_name) ...@@ -96,6 +96,7 @@ TEST_CASE(module_name)
m3 = m1; m3 = m1;
EXPECT(m3.name() == "name"); EXPECT(m3.name() == "name");
} }
TEST_CASE(module_name_main) TEST_CASE(module_name_main)
{ {
migraphx::program p; migraphx::program p;
...@@ -103,4 +104,103 @@ TEST_CASE(module_name_main) ...@@ -103,4 +104,103 @@ TEST_CASE(module_name_main)
EXPECT(mm->name() == "main"); EXPECT(mm->name() == "main");
} }
TEST_CASE(program_module_assign)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", sd);
std::vector<float> one(sd.elements(), 1);
std::vector<float> two(sd.elements(), 2);
auto* then_smod = p.create_module("then_smod");
auto l1 = then_smod->add_literal(migraphx::literal{sd, one});
auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1);
then_smod->add_return({r1});
auto* else_smod = p.create_module("else_smod");
auto l2 = else_smod->add_literal(migraphx::literal{sd, two});
auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2);
else_smod->add_return({r2});
migraphx::shape s_cond{migraphx::shape::bool_type, {1}};
auto cond = mm->add_parameter("cond", s_cond);
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod});
mm->add_return({ret});
migraphx::program p1 = p;
EXPECT(p == p1);
}
TEST_CASE(program_module_replace)
{
auto create_program = [](bool use_if) {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", sd);
std::vector<float> one(sd.elements(), 1);
std::vector<float> two(sd.elements(), 2);
auto* then_smod = p.create_module("then_smod");
auto l1 = then_smod->add_literal(migraphx::literal{sd, one});
auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1);
then_smod->add_return({r1});
auto* else_smod = p.create_module("else_smod");
auto l2 = else_smod->add_literal(migraphx::literal{sd, two});
auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2);
else_smod->add_return({r2});
migraphx::shape s_cond{migraphx::shape::bool_type, {1}};
auto cond = mm->add_parameter("cond", s_cond);
migraphx::instruction_ref ret{};
if(use_if)
{
ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod});
}
else
{
ret = mm->add_instruction(mod_pass_op{}, {cond}, {then_smod, else_smod});
}
mm->add_return({ret});
return p;
};
migraphx::program p1 = create_program(false);
migraphx::program p2 = create_program(true);
EXPECT(p1 != p2);
auto* m1 = p1.get_main_module();
auto ins_pass = std::prev(std::prev(m1->end()));
const auto& inputs = ins_pass->inputs();
const auto& mod_inputs = ins_pass->module_inputs();
m1->replace_instruction(ins_pass, migraphx::make_op("if"), inputs, mod_inputs);
EXPECT(p1 == p2);
}
TEST_CASE(submodule_copy)
{
migraphx::module mm("main");
auto x = mm.add_parameter("x", {migraphx::shape::int64_type});
migraphx::module sm("sub");
sm.add_instruction(migraphx::make_op("sin"), x);
mm.add_instruction(migraphx::make_op("if"), {x}, {&sm, &sm});
auto mm2 = mm;
EXPECT(mm == mm2);
EXPECT(mm.get_sub_modules() == mm2.get_sub_modules());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -27,17 +27,19 @@ TEST_CASE(basic_graph_test) ...@@ -27,17 +27,19 @@ TEST_CASE(basic_graph_test)
std::stringstream ss; std::stringstream ss;
p.print_graph(ss); p.print_graph(ss);
std::string test = ss.str(); std::string test = ss.str();
std::cout << "test = " << test << std::endl;
EXPECT(migraphx::contains(test, "digraph")); EXPECT(migraphx::contains(test, "digraph"));
EXPECT(migraphx::contains(test, "rankdir=LR")); EXPECT(migraphx::contains(test, "rankdir=LR"));
EXPECT(migraphx::contains(test, "\"@0\"[label=\"@literal\"]")); EXPECT(migraphx::contains(test, "\"main:@0\"[label=\"@literal\"]"));
EXPECT(migraphx::contains(test, "\"y\"[label=\"@param:y\"]")); EXPECT(migraphx::contains(test, "\"y\"[label=\"@param:y\"]"));
EXPECT(migraphx::contains(test, "\"x\"[label=\"@param:x\"]")); EXPECT(migraphx::contains(test, "\"x\"[label=\"@param:x\"]"));
EXPECT(migraphx::contains(test, "\"@1\"[label=\"sum\"]")); EXPECT(migraphx::contains(test, "\"main:@1\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"@2\"[label=\"sum\"]")); EXPECT(migraphx::contains(test, "\"main:@2\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"x\" -> \"@1\"")); EXPECT(migraphx::contains(test, "\"x\" -> \"main:@1\""));
EXPECT(migraphx::contains(test, "\"y\" -> \"@1\"")); EXPECT(migraphx::contains(test, "\"y\" -> \"main:@1\""));
EXPECT(migraphx::contains(test, "\"@1\" -> \"@2\"")); EXPECT(migraphx::contains(test, "\"main:@1\" -> \"main:@2\""));
EXPECT(migraphx::contains(test, "\"@0\" -> \"@2\"")); EXPECT(migraphx::contains(test, "\"main:@0\" -> \"main:@2\""));
EXPECT(migraphx::contains(test, "[label=\"int64_type, {1}, {0}\"]")); EXPECT(migraphx::contains(test, "[label=\"int64_type, {1}, {0}\"]"));
} }
......
...@@ -74,4 +74,45 @@ TEST_CASE(unknown_format) ...@@ -74,4 +74,45 @@ TEST_CASE(unknown_format)
EXPECT(test::throws([&] { migraphx::load_buffer(std::vector<char>{}, options); })); EXPECT(test::throws([&] { migraphx::load_buffer(std::vector<char>{}, options); }));
} }
TEST_CASE(program_with_module)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", sd);
std::vector<float> one(sd.elements(), 1);
std::vector<float> two(sd.elements(), 2);
auto* then_smod = p.create_module("then_smod");
auto l1 = then_smod->add_literal(migraphx::literal{sd, one});
auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1);
then_smod->add_return({r1});
auto* else_smod = p.create_module("else_smod");
auto l2 = else_smod->add_literal(migraphx::literal{sd, two});
auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2);
else_smod->add_return({r2});
migraphx::shape s_cond{migraphx::shape::bool_type, {1}};
auto cond = mm->add_parameter("cond", s_cond);
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod});
mm->add_return({ret});
migraphx::program p1 = p;
auto v = p.to_value();
auto v1 = p1.to_value();
EXPECT(v == v1);
std::stringstream ss;
p.print_cpp(ss);
std::stringstream ss1;
p1.print_cpp(ss1);
EXPECT(ss.str() == ss1.str());
migraphx::program p2;
p2.from_value(v);
EXPECT(p1.sort() == p2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp> #include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -100,13 +101,81 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -100,13 +101,81 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators } // namespace operation_operators
template <class T> template <class T>
shape normalize_compute_shape_op(T&& x, std::vector<shape> inputs) auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{ {
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens()); normalize_attributes(y, inputs[0].lens());
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs);
}
template <class T>
auto compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(x.compute_shape(inputs, mod_args))
{
return x.compute_shape(inputs, mod_args);
}
template <class T>
shape
compute_shape_op(rank<0>, const T& x, const std::vector<shape>&, const std::vector<module_ref>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
auto normalize_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
-> decltype(x.normalize_compute_shape(inputs, mod_args))
{
return x.normalize_compute_shape(inputs, mod_args);
}
template <class T>
shape normalize_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>&,
const std::vector<module_ref>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T> template <class T>
auto compute_op(rank<2>, auto compute_op(rank<2>,
const T& x, const T& x,
...@@ -292,6 +361,12 @@ void from_value_op(T& x, const value& v) ...@@ -292,6 +361,12 @@ void from_value_op(T& x, const value& v)
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
const = True, const = True,
default = 'detail::normalize_compute_shape_op'), default = 'detail::normalize_compute_shape_op'),
virtual('compute_shape',
returns = 'shape',
inputs = 'const std::vector<shape>&',
mod_args = 'const std::vector<module_ref>&',
const = True,
default = 'detail::compute_shape_op'),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
ctx = 'context&', ctx = 'context&',
...@@ -343,6 +418,31 @@ inline auto compute_shape(const T& op, const std::vector<shape>& inputs) ...@@ -343,6 +418,31 @@ inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
return detail::normalize_compute_shape_op(op, inputs); return detail::normalize_compute_shape_op(op, inputs);
} }
inline shape compute_shape(const operation& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return op.compute_shape(inputs, mod_args);
}
template <class T>
inline auto compute_shape(const T& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(op.compute_shape(inputs, mod_args))
{
return op.compute_shape(inputs, mod_args);
}
template <class T>
inline auto compute_shape(const T& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(op.normalize_compute_shape(inputs, mod_args))
{
return detail::normalize_compute_shape_op(op, inputs, mod_args);
}
inline bool is_context_free(const operation& op) { return op.is_context_free(); } inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T> template <class T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment