Unverified Commit e96d2b9a authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Module operations (#741)



* code backup

* clang format

* code backup

* change the print function to support print instruction from other modules

* clang format

* fix cppcheck error

* fix cppcheck error

* chang to make submodule to be owned by program instead of modules

* clang format

* add an unit test for copy of a program with sub_modules

* clang format

* remove the parent_module member variable from the module class

* clang format

* add unit test for serialization of program with submodules

* clang format

* Fix bug where instructions were not printed when doing TRACE_EVAL

* clang storage of modules from map to list

* clang format

* Formatting

* change the program assign function

* clang format

* code cleanup

* clang format

* backup code

* clang format

* remove unnecessary code

* clang format

* add module print function

* code backup

* refine the module::print function

* refine the module:to_value() function

* code backup

* backup code changes

* code backup

* remove to_value and from_value function from the module class

* rename a function

* rename the if operator

* refine the if operator

* refine the print function of module and program

* code backup

* code backup

* fix a build warning

* fix overload of compute_shape function

* code backup

* fix unit test error

* fix cppcheck error

* fix the issue related to the overload of compute_shape

* fix review comments

* fix cppcheck error

* change the return name of if_op to be if

* clang format

* fix two unit tests

* clang format

* remove the unused compute_op function

* clang format

* fix clang tidy format

* clang format

* enhance the validate function and uncomment a unit test

* clang format

* remove unnecessary code

* clang format

* fix a hang issue related to the valid function

* fix an issue in replace_refs

* clang format

* fix review comments

* clang format

* fix cppcheck error

* add a unit test for more code coverage

* clang format

* fix review comments and add test for more code coverage

* clang format

* fix cppcheck error

* fix a cppcheck error

* clang format

* fix cppcheck error

* clang format

* fix review comments

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 7b22d210
......@@ -94,6 +94,7 @@ register_migraphx_ops(
greater
gru
identity
if_op
im2col
leaky_relu
less
......
......@@ -34,19 +34,29 @@ typedef enum {
} migraphx_status;
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
typedef enum {
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
/// Options to be passed when compiling
typedef struct
{
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// offloaded memory and to copy the final result from the offloaded
/// memory back to main memory.
bool offload_copy;
/// Optimize math functions to use faster approximate versions. There may
/// be slight accuracy degredation when enabled.
bool fast_math;
} migraphx_compile_options;
/// Options for saving and loading files
typedef struct
{
/// Format to be used for file. It can either be json or msgpack
const char* format;
} migraphx_file_options;
......
......@@ -4,6 +4,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/erase.hpp>
#include <migraphx/config.hpp>
......@@ -14,6 +15,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
shape compute_shape(const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods);
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
struct instruction
......@@ -22,6 +26,11 @@ struct instruction
instruction(operation o, shape r, std::vector<instruction_ref> args);
instruction(operation o,
shape r,
std::vector<instruction_ref> args,
std::vector<module_ref> modules);
instruction(literal l);
void replace(operation o);
......@@ -32,7 +41,7 @@ struct instruction
friend bool operator==(const instruction& i, instruction_ref ref);
bool valid(instruction_ref start) const;
bool valid(instruction_ref start, bool check_order = false) const;
bool valid() const;
......@@ -45,6 +54,8 @@ struct instruction
const std::vector<instruction_ref>& inputs() const;
const std::vector<module_ref>& module_inputs() const;
const std::vector<instruction_ref>& outputs() const;
friend bool operator==(const instruction& x, const instruction& y);
......@@ -65,13 +76,25 @@ struct instruction
migraphx::erase(output, ins);
}
static void replace_refs(instruction_ref ins,
const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
const std::unordered_map<module_ref, module_ref>& map_mods);
static void backreference(instruction_ref ref);
static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins);
static void replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod);
static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
static void replace(instruction_ref ins,
operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args);
bool can_eval() const;
argument eval(bool check_eval = true) const;
......@@ -97,18 +120,31 @@ struct instruction
// internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args);
// internal
void replace(operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> mdl_args);
// internal
void replace(std::vector<instruction_ref> args);
// internal
void replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args);
// internal
void replace_argument(instruction_ref old, instruction_ref new_ins);
// internal
void replace_mod_argument(module_ref old, module_ref new_mod);
void replace(const shape& r);
operation op;
shape result{};
std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments;
std::vector<module_ref> module_args;
literal lit;
bool normalized = false;
};
......
......@@ -2,12 +2,14 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP
#include <list>
#include <unordered_set>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
......@@ -43,14 +45,19 @@ struct module
std::string name() const;
template <class... Ts>
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref add_instruction(operation op, Ts... args)
{
return add_instruction(op, {args...});
}
instruction_ref add_instruction(const operation& op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref add_instruction(const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args)
{
return insert_instruction(ins, op, {args...});
......@@ -58,7 +65,12 @@ struct module
instruction_ref
insert_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args)
{
return replace_instruction(ins, op, {args...});
......@@ -67,6 +79,11 @@ struct module
const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST;
instruction_ref replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST;
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
instruction_ref remove_instruction(instruction_ref ins);
......@@ -109,21 +126,30 @@ struct module
void finalize(context& ctx);
value to_value() const;
void from_value(const value& v);
void debug_print() const;
void debug_print(instruction_ref ins) const;
void debug_print(instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names) const;
void debug_print(const std::vector<instruction_ref>& inss) const;
std::unordered_map<instruction_ref, std::string> print(
const std::function<void(
instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>& print_func,
std::unordered_map<instruction_ref, std::string> names) const;
void print(const std::function<void(instruction_ref,
const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
std::vector<module_ref> get_sub_modules() const;
module& sort();
friend std::ostream& operator<<(std::ostream& os, const module& m);
......@@ -132,6 +158,7 @@ struct module
private:
void assign(const module& m);
std::unique_ptr<module_impl> impl;
};
......
#ifndef MIGRAPHX_GUARD_MODULE_REF_HPP
#define MIGRAPHX_GUARD_MODULE_REF_HPP
#include <list>
#include <functional>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
using module_ref = module*;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/module.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct if_op
{
std::string name() const { return "if"; }
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
check_shapes{inputs, *this}.has(1).standard();
if(mods.size() != 2)
{
MIGRAPHX_THROW("IF: operator should have two submodules.");
}
auto out_shapes0 = mods[0]->get_output_shapes();
auto out_shapes1 = mods[1]->get_output_shapes();
if(not std::equal(
out_shapes1.begin(), out_shapes1.end(), out_shapes0.begin(), out_shapes0.end()))
{
MIGRAPHX_THROW("IF: output shapes of submodules must be the same.");
}
return out_shapes0.front();
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -11,6 +11,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>
......@@ -100,13 +101,81 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators
template <class T>
shape normalize_compute_shape_op(T&& x, std::vector<shape> inputs)
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens());
return any_cast<T>(y).normalize_compute_shape(inputs);
}
template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs);
}
template <class T>
auto compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(x.compute_shape(inputs, mod_args))
{
return x.compute_shape(inputs, mod_args);
}
template <class T>
shape
compute_shape_op(rank<0>, const T& x, const std::vector<shape>&, const std::vector<module_ref>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
auto normalize_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
-> decltype(x.normalize_compute_shape(inputs, mod_args))
{
return x.normalize_compute_shape(inputs, mod_args);
}
template <class T>
shape normalize_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>&,
const std::vector<module_ref>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
auto compute_op(rank<2>,
const T& x,
......@@ -278,11 +347,10 @@ void from_value_op(T& x, const value& v)
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* value to_value() const;
* void from_value(const value& v) ;
* value attributes() const;
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
* mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
* input) const; argument compute(const shape& output,const std::vector<argument>& input)
* const; value to_value() const; void from_value(const value& v) ; value attributes() const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
......@@ -394,6 +462,13 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(input);
}
shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute_shape(inputs, mod_args);
}
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -457,6 +532,8 @@ struct operation
virtual void
finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const = 0;
virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
......@@ -561,6 +638,25 @@ struct operation
return detail::normalize_compute_shape_op(private_detail_te_self, input);
}
template <class T>
static auto private_detail_te_default_compute_shape(char,
T&& private_detail_te_self,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(private_detail_te_self.compute_shape(inputs, mod_args))
{
return private_detail_te_self.compute_shape(inputs, mod_args);
}
template <class T>
static shape private_detail_te_default_compute_shape(float,
T&& private_detail_te_self,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return detail::compute_shape_op(private_detail_te_self, inputs, mod_args);
}
template <class T>
static auto private_detail_te_default_compute(char,
T&& private_detail_te_self,
......@@ -709,6 +805,14 @@ struct operation
return private_detail_te_default_compute_shape(char(0), private_detail_te_value, input);
}
shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const override
{
return private_detail_te_default_compute_shape(
char(0), private_detail_te_value, inputs, mod_args);
}
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input) const override
......@@ -841,6 +945,31 @@ inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
return detail::normalize_compute_shape_op(op, inputs);
}
inline shape compute_shape(const operation& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return op.compute_shape(inputs, mod_args);
}
template <class T>
inline auto compute_shape(const T& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(op.compute_shape(inputs, mod_args))
{
return op.compute_shape(inputs, mod_args);
}
template <class T>
inline auto compute_shape(const T& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(op.normalize_compute_shape(inputs, mod_args))
{
return detail::normalize_compute_shape_op(op, inputs, mod_args);
}
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T>
......
......@@ -38,6 +38,7 @@
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/less.hpp>
......
......@@ -72,8 +72,9 @@ struct program
void debug_print() const;
void debug_print(instruction_ref ins) const;
void print(const std::function<void(instruction_ref,
const std::unordered_map<instruction_ref, std::string>&)>&
void print(std::unordered_map<instruction_ref, std::string>& names,
const std::function<void(instruction_ref,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const;
......@@ -89,9 +90,16 @@ struct program
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
// module related api
module* create_module(const std::string& name);
module* get_module(const std::string& name);
const module* get_module(const std::string& name) const;
module* get_main_module();
const module* get_main_module() const;
std::vector<const module*> get_modules() const;
private:
void assign(const program& p);
std::unique_ptr<program_impl> impl;
......
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/erase.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -10,6 +12,17 @@ instruction::instruction(operation o, shape r, std::vector<instruction_ref> args
{
}
instruction::instruction(operation o,
shape r,
std::vector<instruction_ref> args,
std::vector<module_ref> modules)
: op(std::move(o)),
result(std::move(r)),
arguments(std::move(args)),
module_args(std::move(modules))
{
}
instruction::instruction(literal l)
: op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
{
......@@ -38,7 +51,7 @@ void instruction::replace(operation o)
recompute_shape();
}
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
void instruction::recompute_shape() { replace(compute_shape(op, arguments, module_args)); }
void instruction::clear_arguments()
{
......@@ -47,6 +60,7 @@ void instruction::clear_arguments()
arg->remove_output(*this);
}
arguments.clear();
module_args.clear();
}
bool operator==(const instruction& i, instruction_ref ref)
......@@ -54,12 +68,16 @@ bool operator==(const instruction& i, instruction_ref ref)
return std::addressof(i) == std::addressof(*ref);
}
bool instruction::valid(instruction_ref start) const
bool instruction::valid(instruction_ref start, bool check_order) const
{
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
return self != i->outputs().end() &&
std::distance(start, i) < std::distance(start, *self);
bool ret = self != i->outputs().end();
if(check_order)
{
ret = ret and (std::distance(start, i) < std::distance(start, *self));
}
return ret;
});
}
......@@ -82,7 +100,7 @@ bool instruction::valid() const
{
try
{
computed = compute_shape(op, arguments);
computed = compute_shape(op, arguments, module_args);
}
catch(migraphx::exception&)
{
......@@ -90,7 +108,8 @@ bool instruction::valid() const
}
}
return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return (result == computed) &&
std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
});
}
......@@ -108,6 +127,8 @@ std::string instruction::name() const { return op.name(); }
const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }
const std::vector<module_ref>& instruction::module_inputs() const { return module_args; }
const std::vector<instruction_ref>& instruction::outputs() const { return output; }
bool operator==(const instruction& x, const instruction& y)
......@@ -148,6 +169,13 @@ void instruction::replace_argument(instruction_ref ins,
ins->recompute_shape();
}
void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod)
{
ins->replace_mod_argument(old, new_mod);
backreference(ins);
ins->recompute_shape();
}
void instruction::replace(instruction_ref ins,
operation o,
const shape& r,
......@@ -157,6 +185,16 @@ void instruction::replace(instruction_ref ins,
backreference(ins);
}
void instruction::replace(instruction_ref ins,
operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
ins->replace(std::move(o), r, std::move(args), std::move(module_args));
backreference(ins);
}
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
normalized = false;
......@@ -165,12 +203,56 @@ void instruction::replace(operation o, const shape& r, std::vector<instruction_r
replace(std::move(args));
}
void instruction::replace(operation o,
const shape& r,
std::vector<instruction_ref> args,
std::vector<module_ref> mdl_args)
{
op = std::move(o);
replace(r);
replace(std::move(args), std::move(mdl_args));
}
void instruction::replace_refs(
instruction_ref ins,
const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
const std::unordered_map<module_ref, module_ref>& map_mods)
{
const auto& args = ins->inputs();
for(const auto& arg : args)
{
if(contains(map_insts, arg))
{
instruction::replace_argument(ins, arg, map_insts.at(arg));
}
}
const auto& module_args = ins->module_inputs();
if(module_args.empty())
return;
for(const auto& mod : module_args)
{
if(contains(map_mods, mod))
{
instruction::replace_mod_argument(ins, mod, map_mods.at(mod));
}
}
}
void instruction::replace(std::vector<instruction_ref> args)
{
clear_arguments();
arguments = std::move(args);
}
void instruction::replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args)
{
clear_arguments();
arguments = std::move(args);
module_args = std::move(mdl_args);
}
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; }));
......@@ -178,6 +260,12 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this);
}
void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
{
assert(std::any_of(module_args.begin(), module_args.end(), [&](auto i) { return i == old; }));
std::replace(module_args.begin(), module_args.end(), old, new_mod);
}
bool instruction::can_eval() const
{
if(op.name() == "@literal")
......@@ -242,12 +330,25 @@ void instruction::print(std::ostream& os,
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
std::string arg_name = contains(names, arg) ? names.at(arg) : "?";
os << delim << arg_name;
delim = ',';
}
os << ")";
}
// print module inputs
if(!ins->module_inputs().empty())
{
std::string delim = ", [";
for(auto&& mod_arg : ins->module_inputs())
{
os << delim << mod_arg->name();
delim = ", ";
}
os << "]";
}
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
......@@ -328,5 +429,18 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg
return op.compute_shape(to_shapes(args));
}
shape compute_shape(const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods)
{
if(mods.empty())
{
return op.compute_shape(to_shapes(args));
}
else
{
return op.compute_shape(to_shapes(args), mods);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -88,14 +88,13 @@ void module::assign(const module& m)
}
else
{
// if there are sub_module inputs, need to make a copy of the submodule
auto module_args = ins->module_inputs();
// retrieve its mapped input
auto inputs = ins->inputs();
// ensure all inputs have its corresponding copy instructions
assert(std::all_of(
inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; }));
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return ins_map[i];
return contains(ins_map, i) ? ins_map[i] : i;
});
if(ins->name() == "@return")
{
......@@ -103,7 +102,14 @@ void module::assign(const module& m)
}
else
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
if(module_args.empty())
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
else
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
}
}
}
......@@ -119,9 +125,6 @@ instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
......@@ -130,13 +133,32 @@ instruction_ref module::insert_instruction(instruction_ref ins,
return result;
}
instruction_ref module::add_instruction(const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
return insert_instruction(
impl->instructions.end(), op, std::move(args), std::move(module_args));
}
instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
auto result =
impl->instructions.insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
......@@ -145,6 +167,18 @@ instruction_ref module::replace_instruction(instruction_ref ins,
return ins;
}
instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST
{
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args));
assert(ins->valid(begin()));
return ins;
}
instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep)
{
assert(has_instruction(ins));
......@@ -320,10 +354,18 @@ std::unordered_map<std::string, shape> module::get_parameter_shapes() const
bool module::has_instruction(instruction_ref ins) const
{
return std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
return std::addressof(*ins) == std::addressof(x);
}) != impl->instructions.end();
if(std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
return std::addressof(*ins) == std::addressof(x);
}) != impl->instructions.end())
{
return true;
}
auto parent_modules = get_sub_modules();
return std::any_of(parent_modules.begin(), parent_modules.end(), [&](auto mod) {
return mod->has_instruction(ins);
});
}
std::size_t module::size() const { return impl->instructions.size(); }
......@@ -353,9 +395,15 @@ std::vector<shape> module::get_output_shapes() const
instruction_ref module::validate() const
{
return std::find_if(impl->instructions.begin(),
impl->instructions.end(),
[&](const instruction& i) { return !i.valid(impl->instructions.begin()); });
return std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& i) {
auto inputs = i.inputs();
bool check_order = std::all_of(inputs.begin(), inputs.end(), [&](auto in) {
return contains(impl->instructions, *in);
});
return !i.valid(impl->instructions.begin(), check_order);
});
}
void module::finalize(context& ctx)
......@@ -363,7 +411,12 @@ void module::finalize(context& ctx)
for(auto ins : iterator_for(*this))
{
ins->finalize(ctx);
for(const auto& smod : ins->module_inputs())
{
smod->finalize(ctx);
}
}
// Warn when an instruction is not normalized
auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); });
if(ins != end())
......@@ -371,69 +424,10 @@ void module::finalize(context& ctx)
<< std::endl;
}
value module::to_value() const
{
value result;
value nodes;
this->print([&](auto ins, const auto& names) {
value node;
node["output"] = names.at(ins);
node["name"] = ins->name();
node["shape"] = migraphx::to_value(ins->get_shape());
node["normalized"] = ins->is_normalized();
if(ins->name() == "@literal")
node["literal"] = migraphx::to_value(ins->get_literal());
node["operator"] = ins->get_operator().to_value();
std::vector<std::string> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) { return names.at(i); });
node["inputs"] = inputs;
nodes.push_back(node);
});
result["nodes"] = nodes;
return result;
}
void module::from_value(const value& v)
{
std::unordered_map<std::string, instruction_ref> instructions;
for(const value& node : v.at("nodes"))
{
instruction_ref output;
auto name = node.at("name").to<std::string>();
auto fields = node.at("operator");
auto normalized = node.at("normalized").to<bool>();
if(name == "@param")
{
output = this->add_parameter(fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape")));
}
else if(name == "@literal")
{
output = this->add_literal(migraphx::from_value<literal>(node.at("literal")));
}
else
{
auto op = make_op(name, fields);
std::vector<instruction_ref> inputs;
std::transform(node.at("inputs").begin(),
node.at("inputs").end(),
std::back_inserter(inputs),
[&](const value& i) { return instructions[i.to<std::string>()]; });
if(name == "@return")
output = this->add_return(inputs);
else
output = this->add_instruction(op, inputs);
}
output->set_normalized(normalized);
instructions[node.at("output").to<std::string>()] = output;
}
}
void module::debug_print() const { std::cout << *this << std::endl; }
void module::debug_print(instruction_ref ins) const
void module::debug_print(instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names) const
{
if(ins == this->end())
{
......@@ -446,14 +440,23 @@ void module::debug_print(instruction_ref ins) const
return;
}
std::stringstream ss;
this->print([&](auto x, const auto& names) {
if(x == ins)
{
instruction::print(std::cout, x, names);
std::cout << std::endl;
}
});
this->print(
[&](auto x, auto ins_names) {
if(x == ins)
{
instruction::print(std::cout, x, ins_names);
std::cout << std::endl;
}
},
names);
}
void module::debug_print(instruction_ref ins) const
{
std::unordered_map<instruction_ref, std::string> names;
this->debug_print(ins, names);
}
void module::debug_print(const std::vector<instruction_ref>& inss) const
{
for(auto ins : inss)
......@@ -461,13 +464,12 @@ void module::debug_print(const std::vector<instruction_ref>& inss) const
std::cout << std::endl;
}
void module::print(const std::function<
void(instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const
std::unordered_map<instruction_ref, std::string> module::print(
const std::function<void(instruction_ref,
const std::unordered_map<instruction_ref, std::string>&)>& print_func,
std::unordered_map<instruction_ref, std::string> names) const
{
std::unordered_map<instruction_ref, std::string> names;
int count = 0;
for(auto ins : iterator_for(*this))
{
std::string var_name;
......@@ -477,18 +479,21 @@ void module::print(const std::function<
}
else
{
var_name = "@" + std::to_string(count);
var_name = this->name() + ":@" + std::to_string(count);
count++;
}
names.emplace(ins, var_name);
assert(std::all_of(ins->inputs().begin(),
ins->inputs().end(),
[&](auto arg) { return this->has_instruction(arg); }) &&
"DEBUG_PRINT: Instruction not found");
print_func(ins, names);
}
return names;
}
void module::print(const std::function<
void(instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const
{
this->print(print_func, {});
}
static std::string enclose_name(const std::string& name)
......@@ -500,19 +505,20 @@ void module::print_graph(std::ostream& os, bool brief) const
{
os << "digraph {" << std::endl;
os << "\trankdir=LR;" << std::endl;
this->print([&](auto ins, const auto& names) {
this->print([&](auto ins, auto ins_names) {
std::string label;
if(brief)
label = ins->name();
else
label = to_string(ins->get_operator());
os << "\t" << enclose_name(names.at(ins)) << "[label=" << enclose_name(label) << "]";
os << "\t" << enclose_name(ins_names.at(ins)) << "[label=" << enclose_name(label) << "]";
os << ";" << std::endl;
if(!ins->inputs().empty())
{
for(auto&& arg : ins->inputs())
{
os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins));
os << "\t" << enclose_name(ins_names.at(arg)) << " -> "
<< enclose_name(ins_names.at(ins));
if(not brief)
os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]";
os << ";" << std::endl;
......@@ -571,70 +577,100 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os << "}";
}
void module::print_cpp(std::ostream& os) const
std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const
{
os << "migraphx::module p;" << std::endl;
// cppcheck-suppress variableScope
unsigned long seed = 0;
this->print([&](auto ins, const auto& names) {
auto op = cpp_op_var(names.at(ins), ins);
if(ins->name().front() != '@')
{
os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl;
print_op_attributes(os, op, ins->get_operator());
}
os << "auto " << cpp_var_name(names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << "p.add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
if(use_abs)
os << "migraphx::abs(";
os << "migraphx::generate_literal(";
print_cpp_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ");" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl;
}
else
{
os << "p.add_instruction(" << op;
for(auto input : ins->inputs())
names = this->print(
[&](auto ins, auto ins_names) {
auto op = cpp_op_var(ins_names.at(ins), ins);
if(ins->name().front() != '@')
{
os << ", " << cpp_var_name(names.at(input));
os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl;
print_op_attributes(os, op, ins->get_operator());
}
os << ");" << std::endl;
}
});
os << "auto " << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << "p.add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
if(use_abs)
os << "migraphx::abs(";
os << "migraphx::generate_literal(";
print_cpp_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ");" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl;
}
else
{
os << "p.add_instruction(" << op;
for(auto input : ins->inputs())
{
os << ", " << cpp_var_name(ins_names.at(input));
}
os << ");" << std::endl;
}
},
names);
return names;
}
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{
this->print([&](auto ins, const auto& names) {
instruction::print(os, ins, names);
this->print([&](auto ins, auto ins_names) {
instruction::print(os, ins, ins_names);
a(ins);
os << std::endl;
});
}
std::vector<module_ref> module::get_sub_modules() const
{
std::vector<module_ref> vec_modules;
for(auto ins : iterator_for(*this))
{
const auto& mod_args = ins->module_inputs();
vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end());
for(const auto& smod : mod_args)
{
auto sub_mods = smod->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end());
}
}
return vec_modules;
}
module& module::sort()
{
fix([&](auto self, auto ins) {
this->move_instruction(ins, this->begin());
for(auto child : ins->inputs())
{
if(!contains(this->impl->instructions, child))
{
continue;
}
self(child);
}
})(std::prev(this->end()));
assert(this->validate() == this->end());
return *this;
......@@ -644,10 +680,11 @@ bool operator==(const module& x, const module& y) { return to_string(x) == to_st
std::ostream& operator<<(std::ostream& os, const module& m)
{
m.print([&](auto ins, const auto& names) {
instruction::print(os, ins, names);
m.print([&](auto ins, auto ins_names) {
instruction::print(os, ins, ins_names);
os << std::endl;
});
return os;
}
......
......@@ -9,6 +9,7 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
......@@ -25,12 +26,13 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program_impl
{
// A map is used to keep references to modules of the program
std::map<std::string, module> modules;
// all the modules are store in the depth-first order
std::list<module> modules;
context ctx;
std::string target_name;
};
program::program() : impl(std::make_unique<program_impl>()) { impl->modules["main"] = {"main"}; }
program::program() : impl(std::make_unique<program_impl>()) { impl->modules.push_back({"main"}); }
program::program(program&&) noexcept = default;
program::~program() noexcept = default;
......@@ -55,9 +57,38 @@ void program::assign(const program& p)
{
impl->modules.clear();
}
impl->ctx = p.impl->ctx;
impl->target_name = p.impl->target_name;
impl->modules = p.impl->modules;
// build a map from old ins to new ins
// Build a map from old module to new module
std::unordered_map<module_ref, module_ref> mod_map;
std::transform(impl->modules.begin(),
impl->modules.end(),
p.impl->modules.begin(),
std::inserter(mod_map, mod_map.begin()),
[](auto&& x, auto&& y) { return std::make_pair(&y, &x); });
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto&& pp : mod_map)
{
auto old_ins = iterator_for(*pp.first);
auto new_ins = iterator_for(*pp.second);
std::transform(old_ins.begin(),
old_ins.end(),
new_ins.begin(),
std::inserter(ins_map, ins_map.begin()),
[](auto x, auto y) { return std::make_pair(x, y); });
}
// Update all references from all modules
for(auto&& mp : impl->modules)
{
for(auto ins : iterator_for(mp))
instruction::replace_refs(ins, ins_map, mod_map);
}
}
shape program::get_parameter_shape(std::string name) const
......@@ -114,28 +145,23 @@ void program::compile(const target& t, compile_options options)
options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options);
for(auto& mp : impl->modules)
auto* modl = get_main_module();
assert(modl->validate() == modl->end());
run_passes(*modl, passes, options.trace);
auto invalid = this->validate();
if(invalid != modl->end())
{
auto& modl = mp.second;
assert(modl.validate() == modl.end());
run_passes(modl, passes, options.trace);
auto invalid = this->validate();
if(invalid != modl.end())
{
auto index = std::distance(modl.begin(), invalid);
MIGRAPHX_THROW("Invalid module " + mp.first + " from compilation at instruction " +
std::to_string(index));
}
modl.finalize(this->impl->ctx);
auto index = std::distance(modl->begin(), invalid);
MIGRAPHX_THROW("Invalid module " + modl->name() + " from compilation at instruction " +
std::to_string(index));
}
modl->finalize(this->impl->ctx);
}
void program::finalize()
{
for(auto& mp : this->impl->modules)
{
mp.second.finalize(this->impl->ctx);
}
auto* mm = this->get_main_module();
mm->finalize(this->impl->ctx);
}
template <class F>
......@@ -263,15 +289,128 @@ value program::to_value() const
if(not this->impl->target_name.empty())
result["context"] = this->impl->ctx.to_value();
result["modules"] = value::object{};
auto& module_val = result.at("modules");
for(auto& m : impl->modules)
value module_vals = value::array{};
std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : this->impl->modules)
{
module_val[m.first] = m.second.to_value();
value mod_val;
value nodes;
mod_val["name"] = mod.name();
names = mod.print(
[&](auto ins, auto ins_names) {
value node;
node["output"] = ins_names.at(ins);
node["name"] = ins->name();
node["shape"] = migraphx::to_value(ins->get_shape());
node["normalized"] = ins->is_normalized();
if(ins->name() == "@literal")
node["literal"] = migraphx::to_value(ins->get_literal());
node["operator"] = ins->get_operator().to_value();
std::vector<std::string> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) {
assert(contains(ins_names, i));
return ins_names.at(i);
});
node["inputs"] = inputs;
auto module_args = ins->module_inputs();
if(not module_args.empty())
{
std::vector<std::string> module_inputs;
std::transform(module_args.begin(),
module_args.end(),
std::back_inserter(module_inputs),
[&](auto mod_ref) { return mod_ref->name(); });
node["module_inputs"] = module_inputs;
}
nodes.push_back(node);
},
names);
mod_val["nodes"] = nodes;
module_vals.push_back(mod_val);
}
result["modules"] = module_vals;
return result;
}
static void mod_from_val(module_ref mod,
const value& v,
std::unordered_map<std::string, instruction_ref>& instructions,
const std::unordered_map<std::string, module_ref>& map_mods)
{
const auto* it = std::find_if(v.begin(), v.end(), [&](auto& mv) {
return mv.at("name").template to<std::string>() == mod->name();
});
assert(it != v.end());
const auto& module_val = *it;
for(const value& node : module_val.at("nodes"))
{
instruction_ref output;
auto name = node.at("name").to<std::string>();
auto fields = node.at("operator");
auto normalized = node.at("normalized").to<bool>();
if(name == "@param")
{
output = mod->add_parameter(fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape")));
}
else if(name == "@literal")
{
output = mod->add_literal(migraphx::from_value<literal>(node.at("literal")));
}
else
{
auto op = make_op(name, fields);
std::vector<instruction_ref> inputs;
std::transform(node.at("inputs").begin(),
node.at("inputs").end(),
std::back_inserter(inputs),
[&](const value& i) {
auto i_name = i.to<std::string>();
assert(contains(instructions, i_name));
return instructions.at(i_name);
});
std::vector<module_ref> module_inputs;
if(node.contains("module_inputs"))
{
std::transform(node.at("module_inputs").begin(),
node.at("module_inputs").end(),
std::back_inserter(module_inputs),
[&](const value& i) { return map_mods.at(i.to<std::string>()); });
for(auto& smod : module_inputs)
{
mod_from_val(smod, v, instructions, map_mods);
}
}
if(name == "@return")
{
output = mod->add_return(inputs);
}
else if(module_inputs.empty())
{
output = mod->add_instruction(op, inputs);
}
else
{
output = mod->add_instruction(op, inputs, module_inputs);
}
}
output->set_normalized(normalized);
instructions[node.at("output").to<std::string>()] = output;
}
}
void program::from_value(const value& v)
{
auto version = v.at("version").to<int>();
......@@ -288,15 +427,21 @@ void program::from_value(const value& v)
this->impl->ctx.from_value(v.at("context"));
}
auto val_modules = v.at("modules");
for(const auto& vv : val_modules)
auto module_vals = v.at("modules");
std::unordered_map<std::string, module_ref> map_mods;
for(const auto& vv : module_vals)
{
const auto& key = vv.get_key();
auto val = vv.without_key();
module modl{key};
modl.from_value(val);
impl->modules[key] = modl;
const auto& name = vv.at("name").to<std::string>();
if(name == "main")
continue;
impl->modules.push_back({name});
map_mods[name] = &impl->modules.back();
}
std::unordered_map<std::string, instruction_ref> map_insts;
auto* mm = get_main_module();
mod_from_val(mm, module_vals, map_insts, map_mods);
this->finalize();
}
......@@ -368,8 +513,9 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
double calculate_overhead_time = total_time - total_instruction_time;
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
this->print([&](auto ins, auto names) {
instruction::print(std::cout, ins, names);
std::unordered_map<instruction_ref, std::string> names;
this->print(names, [&](auto ins, auto ins_names) {
instruction::print(std::cout, ins, ins_names);
// skip return instruction
if(ins->name() == "@return")
......@@ -411,8 +557,9 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const
{
std::unordered_map<instruction_ref, std::string> names;
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& it) {
return (it.second.end() == ins);
return (it.end() == ins);
}))
{
std::cout << "End instruction" << std::endl;
......@@ -420,29 +567,31 @@ void program::debug_print(instruction_ref ins) const
}
else if(std::none_of(this->impl->modules.begin(),
this->impl->modules.end(),
[&](const auto& it) { return it.second.has_instruction(ins); }))
[&](const auto& it) { return it.has_instruction(ins); }))
{
std::cout << "Instruction not part of program" << std::endl;
return;
}
std::stringstream ss;
this->print([&](auto x, const auto& names) {
this->print(names, [&](auto x, auto ins_names) {
if(x == ins)
{
instruction::print(std::cout, x, names);
instruction::print(std::cout, x, ins_names);
std::cout << std::endl;
}
});
}
void program::print(const std::function<
void(instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const
void program::print(
std::unordered_map<instruction_ref, std::string>& names,
const std::function<void(instruction_ref, std::unordered_map<instruction_ref, std::string>)>&
print_func) const
{
for(const auto& mdl : this->impl->modules)
for(const auto& mod : this->impl->modules)
{
mdl.second.print(print_func);
std::cout << mod.name() << ":" << std::endl;
mod.print(print_func, names);
}
}
......@@ -454,9 +603,14 @@ void program::print_graph(std::ostream& os, bool brief) const
void program::print_cpp(std::ostream& os) const
{
os << "migraphx::program p;" << std::endl;
const auto* mm = this->get_main_module();
mm->print_cpp(os);
auto vec_modules = this->get_modules();
std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : vec_modules)
{
os << "module: \"" << mod->name() << "\"" << std::endl;
names = mod->print_cpp(os, names);
os << std::endl;
}
}
void program::dry_run(std::unordered_map<std::string, argument> params) const
......@@ -467,22 +621,63 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
{
for(auto& modl : this->impl->modules)
for(auto& mod : this->impl->modules)
{
std::cout << modl.first << ":" << std::endl;
modl.second.annotate(os, a);
std::cout << mod.name() << ":" << std::endl;
mod.annotate(os, a);
}
}
module* program::get_main_module() { return &impl->modules.at("main"); }
const module* program::get_module(const std::string& name) const
{
auto it = std::find_if(
impl->modules.begin(), impl->modules.end(), [&](auto& m) { return (m.name() == name); });
if(it == impl->modules.end())
{
return nullptr;
}
const module* program::get_main_module() const { return &impl->modules.at("main"); }
return &(*it);
}
module* program::create_module(const std::string& name)
{
auto it = impl->modules.insert(impl->modules.end(), {name});
return &(*it);
}
module* program::get_module(const std::string& name)
{
auto it = std::find_if(
impl->modules.begin(), impl->modules.end(), [&](auto& m) { return (m.name() == name); });
if(it == impl->modules.end())
{
return nullptr;
}
return &(*it);
}
module* program::get_main_module() { return get_module("main"); }
const module* program::get_main_module() const { return get_module("main"); }
std::vector<const module*> program::get_modules() const
{
const module* mm = get_main_module();
std::vector<const module*> vec_modules;
vec_modules.push_back(mm);
auto sub_modules = mm->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end());
return vec_modules;
}
program& program::sort()
{
for(auto& modl : this->impl->modules)
for(auto& mod : this->impl->modules)
{
modl.second.sort();
mod.sort();
}
return *this;
......@@ -492,10 +687,17 @@ bool operator==(const program& x, const program& y) { return to_string(x) == to_
std::ostream& operator<<(std::ostream& os, const program& p)
{
for(auto& mp : p.impl->modules)
auto vec_modules = p.get_modules();
std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : vec_modules)
{
os << "Module " << mp.first << ": " << std::endl;
os << mp.second;
os << "module: \"" << mod->name() << "\"" << std::endl;
names = mod->print(
[&](auto ins, auto ins_names) {
instruction::print(os, ins, ins_names);
os << std::endl;
},
names);
os << std::endl;
}
......
......@@ -79,6 +79,30 @@ struct pass_op
return {};
return inputs.front();
}
int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};
struct mod_pass_op
{
std::string name() const { return "mod_pass"; }
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs,
std::vector<migraphx::module_ref> mods) const
{
if(!mods.empty())
{
auto out_shapes = mods[0]->get_output_shapes();
return out_shapes[0];
}
if(!inputs.empty())
{
return inputs.front();
}
return {};
}
int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};
......
......@@ -96,6 +96,7 @@ TEST_CASE(module_name)
m3 = m1;
EXPECT(m3.name() == "name");
}
TEST_CASE(module_name_main)
{
migraphx::program p;
......@@ -103,4 +104,103 @@ TEST_CASE(module_name_main)
EXPECT(mm->name() == "main");
}
TEST_CASE(program_module_assign)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", sd);
std::vector<float> one(sd.elements(), 1);
std::vector<float> two(sd.elements(), 2);
auto* then_smod = p.create_module("then_smod");
auto l1 = then_smod->add_literal(migraphx::literal{sd, one});
auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1);
then_smod->add_return({r1});
auto* else_smod = p.create_module("else_smod");
auto l2 = else_smod->add_literal(migraphx::literal{sd, two});
auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2);
else_smod->add_return({r2});
migraphx::shape s_cond{migraphx::shape::bool_type, {1}};
auto cond = mm->add_parameter("cond", s_cond);
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod});
mm->add_return({ret});
migraphx::program p1 = p;
EXPECT(p == p1);
}
TEST_CASE(program_module_replace)
{
auto create_program = [](bool use_if) {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", sd);
std::vector<float> one(sd.elements(), 1);
std::vector<float> two(sd.elements(), 2);
auto* then_smod = p.create_module("then_smod");
auto l1 = then_smod->add_literal(migraphx::literal{sd, one});
auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1);
then_smod->add_return({r1});
auto* else_smod = p.create_module("else_smod");
auto l2 = else_smod->add_literal(migraphx::literal{sd, two});
auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2);
else_smod->add_return({r2});
migraphx::shape s_cond{migraphx::shape::bool_type, {1}};
auto cond = mm->add_parameter("cond", s_cond);
migraphx::instruction_ref ret{};
if(use_if)
{
ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod});
}
else
{
ret = mm->add_instruction(mod_pass_op{}, {cond}, {then_smod, else_smod});
}
mm->add_return({ret});
return p;
};
migraphx::program p1 = create_program(false);
migraphx::program p2 = create_program(true);
EXPECT(p1 != p2);
auto* m1 = p1.get_main_module();
auto ins_pass = std::prev(std::prev(m1->end()));
const auto& inputs = ins_pass->inputs();
const auto& mod_inputs = ins_pass->module_inputs();
m1->replace_instruction(ins_pass, migraphx::make_op("if"), inputs, mod_inputs);
EXPECT(p1 == p2);
}
TEST_CASE(submodule_copy)
{
migraphx::module mm("main");
auto x = mm.add_parameter("x", {migraphx::shape::int64_type});
migraphx::module sm("sub");
sm.add_instruction(migraphx::make_op("sin"), x);
mm.add_instruction(migraphx::make_op("if"), {x}, {&sm, &sm});
auto mm2 = mm;
EXPECT(mm == mm2);
EXPECT(mm.get_sub_modules() == mm2.get_sub_modules());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -27,17 +27,19 @@ TEST_CASE(basic_graph_test)
std::stringstream ss;
p.print_graph(ss);
std::string test = ss.str();
std::cout << "test = " << test << std::endl;
EXPECT(migraphx::contains(test, "digraph"));
EXPECT(migraphx::contains(test, "rankdir=LR"));
EXPECT(migraphx::contains(test, "\"@0\"[label=\"@literal\"]"));
EXPECT(migraphx::contains(test, "\"main:@0\"[label=\"@literal\"]"));
EXPECT(migraphx::contains(test, "\"y\"[label=\"@param:y\"]"));
EXPECT(migraphx::contains(test, "\"x\"[label=\"@param:x\"]"));
EXPECT(migraphx::contains(test, "\"@1\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"@2\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"x\" -> \"@1\""));
EXPECT(migraphx::contains(test, "\"y\" -> \"@1\""));
EXPECT(migraphx::contains(test, "\"@1\" -> \"@2\""));
EXPECT(migraphx::contains(test, "\"@0\" -> \"@2\""));
EXPECT(migraphx::contains(test, "\"main:@1\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"main:@2\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"x\" -> \"main:@1\""));
EXPECT(migraphx::contains(test, "\"y\" -> \"main:@1\""));
EXPECT(migraphx::contains(test, "\"main:@1\" -> \"main:@2\""));
EXPECT(migraphx::contains(test, "\"main:@0\" -> \"main:@2\""));
EXPECT(migraphx::contains(test, "[label=\"int64_type, {1}, {0}\"]"));
}
......
......@@ -74,4 +74,45 @@ TEST_CASE(unknown_format)
EXPECT(test::throws([&] { migraphx::load_buffer(std::vector<char>{}, options); }));
}
TEST_CASE(program_with_module)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", sd);
std::vector<float> one(sd.elements(), 1);
std::vector<float> two(sd.elements(), 2);
auto* then_smod = p.create_module("then_smod");
auto l1 = then_smod->add_literal(migraphx::literal{sd, one});
auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1);
then_smod->add_return({r1});
auto* else_smod = p.create_module("else_smod");
auto l2 = else_smod->add_literal(migraphx::literal{sd, two});
auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2);
else_smod->add_return({r2});
migraphx::shape s_cond{migraphx::shape::bool_type, {1}};
auto cond = mm->add_parameter("cond", s_cond);
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod});
mm->add_return({ret});
migraphx::program p1 = p;
auto v = p.to_value();
auto v1 = p1.to_value();
EXPECT(v == v1);
std::stringstream ss;
p.print_cpp(ss);
std::stringstream ss1;
p1.print_cpp(ss1);
EXPECT(ss.str() == ss1.str());
migraphx::program p2;
p2.from_value(v);
EXPECT(p1.sort() == p2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -11,6 +11,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>
......@@ -100,13 +101,81 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators
template <class T>
shape normalize_compute_shape_op(T&& x, std::vector<shape> inputs)
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens());
return any_cast<T>(y).normalize_compute_shape(inputs);
}
template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs);
}
template <class T>
auto compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(x.compute_shape(inputs, mod_args))
{
return x.compute_shape(inputs, mod_args);
}
template <class T>
shape
compute_shape_op(rank<0>, const T& x, const std::vector<shape>&, const std::vector<module_ref>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
auto normalize_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
-> decltype(x.normalize_compute_shape(inputs, mod_args))
{
return x.normalize_compute_shape(inputs, mod_args);
}
template <class T>
shape normalize_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>&,
const std::vector<module_ref>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
auto compute_op(rank<2>,
const T& x,
......@@ -292,6 +361,12 @@ void from_value_op(T& x, const value& v)
input = 'const std::vector<shape>&',
const = True,
default = 'detail::normalize_compute_shape_op'),
virtual('compute_shape',
returns = 'shape',
inputs = 'const std::vector<shape>&',
mod_args = 'const std::vector<module_ref>&',
const = True,
default = 'detail::compute_shape_op'),
virtual('compute',
returns = 'argument',
ctx = 'context&',
......@@ -343,6 +418,31 @@ inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
return detail::normalize_compute_shape_op(op, inputs);
}
inline shape compute_shape(const operation& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return op.compute_shape(inputs, mod_args);
}
template <class T>
inline auto compute_shape(const T& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(op.compute_shape(inputs, mod_args))
{
return op.compute_shape(inputs, mod_args);
}
template <class T>
inline auto compute_shape(const T& op,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(op.normalize_compute_shape(inputs, mod_args))
{
return detail::normalize_compute_shape_op(op, inputs, mod_args);
}
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment