Unverified Commit 63c5582a authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add load/save function for program (#623)



* Add save/load functions

* Formatting

* Add loading and saving to the driver

* Formatting

* Add return

* Serialize the context with the program

* Formatting

* Add python API

* Formatting

* Add c/c++ apis

* Formatting

* Add tests

* Formatting

* Fix tidy error

* Fix python doc

* Restore python code

* Add function name to errors

* Formatting

* Use lvalue for writing

* Serialize context

* Fix convolution and pooling operator for miopen

* Formatting

* Add const ref

* Set target name to gpu

* Add target tests

* Formatting

* Move register target to cpp file

* Fix target test

* Use make_target in driver

* Formatting

* Use make_target for the API

* Formatting

* Add cpu include

* Increase timeout

* Add more tests

* Formatting
Co-authored-by: default avatarShucai Xiao <shucai.xiao@amd.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e67aa78c
...@@ -217,3 +217,26 @@ parse_tf ...@@ -217,3 +217,26 @@ parse_tf
:rtype: program :rtype: program
load
----
.. py:function:: load(filename, format='msgpack')
Load a MIGraphX program
:param str filename: Path to file.
:param str format: Format of file. Valid options are msgpack or json.
:rtype: program
save
----
.. py:function:: save(p, filename, format='msgpack')
Save a MIGraphX program
:param program p: Program to save.
:param str filename: Path to file.
:param str format: Format of file. Valid options are msgpack or json.
...@@ -20,6 +20,7 @@ add_library(migraphx ...@@ -20,6 +20,7 @@ add_library(migraphx
env.cpp env.cpp
generate.cpp generate.cpp
instruction.cpp instruction.cpp
load_save.cpp
make_op.cpp make_op.cpp
msgpack.cpp msgpack.cpp
program.cpp program.cpp
...@@ -31,6 +32,7 @@ add_library(migraphx ...@@ -31,6 +32,7 @@ add_library(migraphx
serialize.cpp serialize.cpp
pass_manager.cpp pass_manager.cpp
register_op.cpp register_op.cpp
register_target.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
value.cpp value.cpp
......
...@@ -3,14 +3,11 @@ ...@@ -3,14 +3,11 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/cpu/target.hpp>
#ifdef HAVE_GPU #include <migraphx/load_save.hpp>
#include <migraphx/gpu/target.hpp>
#endif
namespace migraphx { namespace migraphx {
...@@ -67,19 +64,7 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t) ...@@ -67,19 +64,7 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type"); MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
} }
target get_target(const std::string& name) target get_target(const std::string& name) { return make_target(name); }
{
migraphx::target t;
if(name == "cpu")
t = migraphx::cpu::target();
#ifdef HAVE_GPU
else if(name == "gpu")
t = migraphx::gpu::target();
#endif
else
MIGRAPHX_THROW(migraphx_status_unknown_target, "Unknown target: " + name);
return t;
}
migraphx::compile_options to_compile_options(const migraphx_compile_options& options) migraphx::compile_options to_compile_options(const migraphx_compile_options& options)
{ {
...@@ -88,6 +73,13 @@ migraphx::compile_options to_compile_options(const migraphx_compile_options& opt ...@@ -88,6 +73,13 @@ migraphx::compile_options to_compile_options(const migraphx_compile_options& opt
return result; return result;
} }
migraphx::file_options to_file_options(const migraphx_file_options& options)
{
migraphx::file_options result{};
result.format = options.format;
return result;
}
void set_default_dim_value(onnx_options& options, size_t value) void set_default_dim_value(onnx_options& options, size_t value)
{ {
options.default_dim_value = value; options.default_dim_value = value;
...@@ -649,6 +641,15 @@ extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t progr ...@@ -649,6 +641,15 @@ extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t progr
}); });
} }
extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program)
{
return migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
(program->object).sort();
});
}
extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out, extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params) migraphx_program_parameters_t params)
...@@ -674,6 +675,29 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap ...@@ -674,6 +675,29 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap
}); });
} }
extern "C" migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options* options)
{
return migraphx::try_([&] {
*out = allocate<migraphx_program_t>(migraphx::load(
(name),
(options == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*options))));
});
}
extern "C" migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options* options)
{
return migraphx::try_([&] {
if(p == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer");
migraphx::save(
(p->object),
(name),
(options == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*options)));
});
}
extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options) extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options)
{ {
return migraphx::try_([&] { destroy((onnx_options)); }); return migraphx::try_([&] { destroy((onnx_options)); });
......
...@@ -44,6 +44,11 @@ typedef struct ...@@ -44,6 +44,11 @@ typedef struct
bool offload_copy; bool offload_copy;
} migraphx_compile_options; } migraphx_compile_options;
typedef struct
{
const char* format;
} migraphx_file_options;
typedef struct migraphx_shape* migraphx_shape_t; typedef struct migraphx_shape* migraphx_shape_t;
typedef const struct migraphx_shape* const_migraphx_shape_t; typedef const struct migraphx_shape* const_migraphx_shape_t;
...@@ -179,6 +184,8 @@ migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out, ...@@ -179,6 +184,8 @@ migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out,
migraphx_status migraphx_program_print(const_migraphx_program_t program); migraphx_status migraphx_program_print(const_migraphx_program_t program);
migraphx_status migraphx_program_sort(migraphx_program_t program);
migraphx_status migraphx_program_run(migraphx_arguments_t* out, migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params); migraphx_program_parameters_t params);
...@@ -186,6 +193,12 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out, ...@@ -186,6 +193,12 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_status migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x); migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x);
migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options* options);
migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options* options);
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options); migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options); migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
......
...@@ -494,6 +494,12 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -494,6 +494,12 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
void print() const { call(&migraphx_program_print, this->get_handle_ptr()); } void print() const { call(&migraphx_program_print, this->get_handle_ptr()); }
program sort()
{
call(&migraphx_program_sort, this->get_handle_ptr());
return *this;
}
friend bool operator==(const program& px, const program& py) friend bool operator==(const program& px, const program& py)
{ {
bool pout; bool pout;
...@@ -504,6 +510,26 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -504,6 +510,26 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return !(px == py); }
}; };
inline program load(const char* filename, migraphx_file_options options)
{
return program(make<migraphx_program>(&migraphx_load, filename, &options), own{});
}
inline program load(const char* filename)
{
return program(make<migraphx_program>(&migraphx_load, filename, nullptr), own{});
}
inline void save(const program& p, const char* filename, migraphx_file_options options)
{
call(&migraphx_save, p.get_handle_ptr(), filename, &options);
}
inline void save(const program& p, const char* filename)
{
call(&migraphx_save, p.get_handle_ptr(), filename, nullptr);
}
struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{ {
onnx_options() { this->make_handle(&migraphx_onnx_options_create); } onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
......
...@@ -33,6 +33,17 @@ def compile_options_type_wrap(p): ...@@ -33,6 +33,17 @@ def compile_options_type_wrap(p):
p.read = '${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})' p.read = '${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})'
@api.cwrap('migraphx::file_options')
def file_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_file_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_file_options(${result})']
else:
p.add_param('migraphx_file_options *')
p.read = '${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})'
@api.cwrap('migraphx::onnx_options') @api.cwrap('migraphx::onnx_options')
def onnx_options_type_wrap(p): def onnx_options_type_wrap(p):
if p.returns: if p.returns:
...@@ -164,6 +175,7 @@ def program(h): ...@@ -164,6 +175,7 @@ def program(h):
invoke='migraphx::get_output_shapes($@)', invoke='migraphx::get_output_shapes($@)',
returns='std::vector<migraphx::shape>') returns='std::vector<migraphx::shape>')
h.method('print', invoke='migraphx::print($@)', const=True) h.method('print', invoke='migraphx::print($@)', const=True)
h.method('sort')
h.method('run', h.method('run',
api.params( api.params(
params='std::unordered_map<std::string, migraphx::argument>'), params='std::unordered_map<std::string, migraphx::argument>'),
...@@ -176,6 +188,19 @@ def program(h): ...@@ -176,6 +188,19 @@ def program(h):
const=True) const=True)
api.add_function('migraphx_load',
api.params(name='const char*',
options='migraphx::file_options'),
fname='migraphx::load',
returns='migraphx::program')
api.add_function('migraphx_save',
api.params(p='migraphx::program&',
name='const char*',
options='migraphx::file_options'),
fname='migraphx::save')
@auto_handle @auto_handle
def onnx_options(h): def onnx_options(h):
h.constructor('create') h.constructor('create')
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/json.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
...@@ -36,6 +38,9 @@ struct loader ...@@ -36,6 +38,9 @@ struct loader
unsigned trim = 0; unsigned trim = 0;
bool optimize = false; bool optimize = false;
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
bool brief = false;
std::string output_type;
std::string output;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
...@@ -43,6 +48,8 @@ struct loader ...@@ -43,6 +48,8 @@ struct loader
ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet")); ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx")); ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
ap(file_type, {"--migraphx-json"}, ap.help("Load as MIGraphX JSON"), ap.set_value("json"));
ap(batch, {"--batch"}, ap.help("Set batch size for model")); ap(batch, {"--batch"}, ap.help("Set batch size for model"));
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true)); ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(skip_unknown_operators, ap(skip_unknown_operators,
...@@ -52,6 +59,25 @@ struct loader ...@@ -52,6 +59,25 @@ struct loader
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false)); ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end")); ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true)); ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
ap(output_type,
{"--graphviz", "-g"},
ap.help("Print out a graphviz representation."),
ap.set_value("graphviz"));
ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
ap(output_type,
{"--cpp"},
ap.help("Print out the program as cpp program."),
ap.set_value("cpp"));
ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json"));
ap(output_type,
{"--text"},
ap.help("Print out program in text format."),
ap.set_value("text"));
ap(output_type,
{"--binary"},
ap.help("Print out program in binary format."),
ap.set_value("binary"));
ap(output, {"--output", "-o"}, ap.help("Output to file."));
} }
program load() program load()
...@@ -65,6 +91,10 @@ struct loader ...@@ -65,6 +91,10 @@ struct loader
file_type = "onnx"; file_type = "onnx";
else if(ends_with(file, ".pb")) else if(ends_with(file, ".pb"))
file_type = "tf"; file_type = "tf";
else if(ends_with(file, ".json"))
file_type = "json";
else
file_type = "migraphx";
} }
std::cout << "Reading: " << file << std::endl; std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx") if(file_type == "onnx")
...@@ -79,6 +109,16 @@ struct loader ...@@ -79,6 +109,16 @@ struct loader
{ {
p = parse_tf(file, tf_options{is_nhwc, batch}); p = parse_tf(file, tf_options{is_nhwc, batch});
} }
else if(file_type == "json")
{
file_options options;
options.format = "json";
p = migraphx::load(file, options);
}
else if(file_type == "migraphx")
{
p = migraphx::load(file);
}
} }
else else
{ {
...@@ -113,6 +153,42 @@ struct loader ...@@ -113,6 +153,42 @@ struct loader
}); });
return p; return p;
} }
static void write(std::ostream& os, const std::vector<char>& buffer)
{
os.write(buffer.data(), buffer.size());
}
void save(const program& p)
{
auto* os = &std::cout;
std::ofstream fs;
if(not output.empty())
{
fs.open(output);
os = &fs;
}
std::string type = output_type;
if(type.empty())
{
if(output.empty())
type = "text";
else
type = "binary";
}
if(type == "cpp")
p.print_cpp(*os);
else if(type == "graphviz")
p.print_graph(*os, brief);
else if(type == "text")
*os << p << std::endl;
else if(type == "json")
*os << to_json_string(p.to_value()) << std::endl;
else if(type == "binary")
write(*os, save_buffer(p));
}
}; };
struct program_params struct program_params
...@@ -171,6 +247,9 @@ struct compiler ...@@ -171,6 +247,9 @@ struct compiler
program compile() program compile()
{ {
auto p = l.load(); auto p = l.load();
// Dont compile if its already been compiled
if(p.is_compiled())
return p;
auto t = get_target(gpu); auto t = get_target(gpu);
if(quantize == q_fp16) if(quantize == q_fp16)
{ {
...@@ -183,6 +262,7 @@ struct compiler ...@@ -183,6 +262,7 @@ struct compiler
compile_options options; compile_options options;
options.offload_copy = offload_copy; options.offload_copy = offload_copy;
p.compile(t, options); p.compile(t, options);
l.save(p);
return p; return p;
} }
}; };
...@@ -190,40 +270,12 @@ struct compiler ...@@ -190,40 +270,12 @@ struct compiler
struct read : command<read> struct read : command<read>
{ {
loader l; loader l;
bool cpp = false; void parse(argument_parser& ap) { l.parse(ap); }
bool graphviz = false;
bool brief = false;
std::string output;
void parse(argument_parser& ap)
{
l.parse(ap);
ap(graphviz,
{"--graphviz", "-g"},
ap.help("Print out a graphviz representation."),
ap.set_value(true));
ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
ap(cpp, {"--cpp"}, ap.help("Print out the program as cpp program."), ap.set_value(true));
ap(output, {"--output", "-o"}, ap.help("Output to file."));
}
void run() void run()
{ {
auto p = l.load(); auto p = l.load();
l.save(p);
auto* os = &std::cout;
std::ofstream fs;
if(not output.empty())
{
fs.open(output);
os = &fs;
}
if(cpp)
p.print_cpp(*os);
else if(graphviz)
p.print_graph(*os, brief);
else
*os << p << std::endl;
} }
}; };
...@@ -267,6 +319,7 @@ struct verify : command<verify> ...@@ -267,6 +319,7 @@ struct verify : command<verify>
void run() void run()
{ {
auto p = l.load(); auto p = l.load();
l.save(p);
std::cout << p << std::endl; std::cout << p << std::endl;
compile_options options; compile_options options;
...@@ -297,8 +350,7 @@ struct compile : command<compile> ...@@ -297,8 +350,7 @@ struct compile : command<compile>
void run() void run()
{ {
std::cout << "Compiling ... " << std::endl; std::cout << "Compiling ... " << std::endl;
auto p = c.compile(); c.compile();
std::cout << p << std::endl;
} }
}; };
......
#include "perf.hpp" #include "perf.hpp"
#include <migraphx/cpu/target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/register_target.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#endif #endif
...@@ -54,34 +53,12 @@ program::parameter_map create_param_map(const program& p, bool gpu) ...@@ -54,34 +53,12 @@ program::parameter_map create_param_map(const program& p, bool gpu)
target get_target(bool gpu) target get_target(bool gpu)
{ {
if(gpu) if(gpu)
{ return make_target("gpu");
#ifdef HAVE_GPU
return gpu::target{};
#else
MIGRAPHX_THROW("Gpu not supported.");
#endif
}
else else
{ return make_target("cpu");
return cpu::target{};
}
} }
void compile_program(program& p, bool gpu) void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); }
{
if(gpu)
{
#ifdef HAVE_GPU
p.compile(gpu::target{});
#else
MIGRAPHX_THROW("Gpu not supported.");
#endif
}
else
{
p.compile(cpu::target{});
}
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
#ifndef MIGRAPHX_GUARD_RTGLIB_AUTO_REGISTER_HPP
#define MIGRAPHX_GUARD_RTGLIB_AUTO_REGISTER_HPP
#include <migraphx/config.hpp>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Action, class T>
int auto_register_action()
{
Action::template apply<T>();
return 0;
}
template <class Action, class T>
struct auto_register
{
static int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type =
std::integral_constant<decltype(&static_register), &static_register>;
};
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
template <class Action, class T>
int auto_register<Action, T>::static_register = auto_register_action<Action, T>(); // NOLINT
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#define MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x) migraphx_auto_register_##x
#define MIGRAPHX_AUTO_REGISTER_NAME(x) MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x)
// NOLINTNEXTLINE
#define MIGRAPHX_AUTO_REGISTER(...) \
void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)(migraphx::auto_register<__VA_ARGS__> x = \
migraphx::auto_register<__VA_ARGS__>{}) \
__attribute__((unused));
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -10,9 +11,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,9 +11,9 @@ inline namespace MIGRAPHX_INLINE_NS {
template <class T> template <class T>
struct check_context struct check_context
{ {
struct op struct op : auto_register_op<op>
{ {
std::string name() const { return "check_context"; } std::string name() const { return "check_context::" + get_type_name<T>(); }
shape compute_shape(const std::vector<shape>&) const { return {}; } shape compute_shape(const std::vector<shape>&) const { return {}; }
argument compute(context& ctx, const shape&, const std::vector<argument>&) const argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{ {
......
...@@ -44,16 +44,17 @@ make_exception(const std::string& context, unsigned int e, const std::string& me ...@@ -44,16 +44,17 @@ make_exception(const std::string& context, unsigned int e, const std::string& me
* *
* @return A string that represents the file location * @return A string that represents the file location
*/ */
inline std::string make_source_context(const std::string& file, int line) inline std::string make_source_context(const std::string& file, int line, const std::string& fname)
{ {
return file + ":" + std::to_string(line); return file + ":" + std::to_string(line) + ": " + fname;
} }
/** /**
* @brief Throw an exception with context information * @brief Throw an exception with context information
*/ */
#define MIGRAPHX_THROW(...) \ #define MIGRAPHX_THROW(...) \
throw migraphx::make_exception(migraphx::make_source_context(__FILE__, __LINE__), __VA_ARGS__) throw migraphx::make_exception(migraphx::make_source_context(__FILE__, __LINE__, __func__), \
__VA_ARGS__)
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
std::string to_json_string(const value& val); std::string to_json_string(const value& val);
value from_json_string(const std::string& str); value from_json_string(const std::string& str);
value from_json_string(const char* str, std::size_t size);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_LOAD_SAVE_HPP
#define MIGRAPHX_GUARD_RTGLIB_LOAD_SAVE_HPP
#include <migraphx/program.hpp>
#include <string>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct file_options
{
std::string format = "msgpack";
};
program load(const std::string& filename, const file_options& options = file_options{});
program load_buffer(const std::vector<char>& buffer, const file_options& options = file_options{});
program
load_buffer(const char* buffer, std::size_t size, const file_options& options = file_options{});
void save(const program& p,
const std::string& filename,
const file_options& options = file_options{});
std::vector<char> save_buffer(const program& p, const file_options& options = file_options{});
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -9,6 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -9,6 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> to_msgpack(const value& v); std::vector<char> to_msgpack(const value& v);
value from_msgpack(const std::vector<char>& buffer); value from_msgpack(const std::vector<char>& buffer);
value from_msgpack(const char* buffer, std::size_t size);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -114,10 +114,15 @@ struct program ...@@ -114,10 +114,15 @@ struct program
void compile(const target& t, compile_options options = compile_options{}); void compile(const target& t, compile_options options = compile_options{});
bool is_compiled() const;
void finalize(); void finalize();
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const; void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
value to_value() const;
void from_value(const value& v);
void debug_print() const; void debug_print() const;
void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins) const;
void debug_print(const std::vector<instruction_ref>& inss) const; void debug_print(const std::vector<instruction_ref>& inss) const;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/auto_register.hpp>
#include <cstring> #include <cstring>
#include <vector> #include <vector>
...@@ -14,40 +15,24 @@ operation load_op(const std::string& name); ...@@ -14,40 +15,24 @@ operation load_op(const std::string& name);
std::vector<std::string> get_operators(); std::vector<std::string> get_operators();
template <class T> template <class T>
int register_op() void register_op()
{ {
register_op(T{}); register_op(T{});
return 0;
} }
template <class T> struct register_op_action
struct auto_register_op
{ {
static int static_register; template <class T>
// This typedef ensures that the static member will be instantiated if static void apply()
// the class itself is instantiated {
using static_register_type = register_op<T>();
std::integral_constant<decltype(&static_register), &static_register>; }
}; };
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
template <class T> template <class T>
int auto_register_op<T>::static_register = register_op<T>(); // NOLINT using auto_register_op = auto_register<register_op_action, T>;
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#define MIGRAPHX_REGISTER_OP_NAME_DETAIL(x) migraphx_auto_register_##x #define MIGRAPHX_REGISTER_OP(...) MIGRAPHX_AUTO_REGISTER(register_op_action, __VA_ARGS__)
#define MIGRAPHX_REGISTER_OP_NAME(x) MIGRAPHX_REGISTER_OP_NAME_DETAIL(x)
#define MIGRAPHX_REGISTER_OP(...) \
void MIGRAPHX_REGISTER_OP_NAME(__LINE__)(migraphx::auto_register_op<__VA_ARGS__> x = \
migraphx::auto_register_op<__VA_ARGS__>{}) \
__attribute__((unused));
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REGISTER_TARGET_HPP
#define MIGRAPHX_GUARD_RTGLIB_REGISTER_TARGET_HPP
#include <migraphx/config.hpp>
#include <migraphx/target.hpp>
#include <migraphx/auto_register.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void register_target(const target& t);
target make_target(const std::string& name);
std::vector<std::string> get_targets();
template <class T>
void register_target()
{
register_target(T{});
}
struct register_target_action
{
template <class T>
static void apply()
{
register_target<T>();
}
};
template <class T>
using auto_register_target = auto_register<register_target_action, T>;
#define MIGRAPHX_REGISTER_TARGET(...) MIGRAPHX_AUTO_REGISTER(register_target_action, __VA_ARGS__)
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -141,7 +141,7 @@ auto from_value_impl(rank<2>, const value& v, T& x) -> decltype(x.insert(*x.begi ...@@ -141,7 +141,7 @@ auto from_value_impl(rank<2>, const value& v, T& x) -> decltype(x.insert(*x.begi
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})> template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
void from_value_impl(rank<3>, const value& v, T& x) void from_value_impl(rank<3>, const value& v, T& x)
{ {
reflect_each(x, [&](auto&& y, const std::string& name) { reflect_each(x, [&](auto& y, const std::string& name) {
using type = std::decay_t<decltype(y)>; using type = std::decay_t<decltype(y)>;
y = from_value<type>(v.at(name).without_key()); y = from_value<type>(v.at(name).without_key());
}); });
......
...@@ -120,6 +120,11 @@ std::string to_json_string(const value& val) ...@@ -120,6 +120,11 @@ std::string to_json_string(const value& val)
return j.dump(); return j.dump();
} }
migraphx::value from_json_string(const char* str, std::size_t size)
{
json j = json::parse(str, str + size);
return j.get<value>();
}
migraphx::value from_json_string(const std::string& str) migraphx::value from_json_string(const std::string& str)
{ {
json j = json::parse(str); json j = json::parse(str);
......
#include <migraphx/load_save.hpp>
#include <migraphx/json.hpp>
#include <migraphx/msgpack.hpp>
#include <fstream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> read_buffer(const std::string& filename)
{
std::ifstream is(filename, std::ios::binary | std::ios::ate);
std::streamsize size = is.tellg();
is.seekg(0, std::ios::beg);
std::vector<char> buffer(size);
if(!is.read(buffer.data(), size))
{
MIGRAPHX_THROW("Error reading file: " + filename);
}
return buffer;
}
void write_buffer(const std::string& filename, const char* buffer, std::size_t size)
{
std::ofstream os(filename);
os.write(buffer, size);
}
void write_buffer(const std::string& filename, const std::vector<char>& buffer)
{
write_buffer(filename, buffer.data(), buffer.size());
}
program load(const std::string& filename, const file_options& options)
{
return load_buffer(read_buffer(filename), options);
}
program load_buffer(const std::vector<char>& buffer, const file_options& options)
{
return load_buffer(buffer.data(), buffer.size(), options);
}
program load_buffer(const char* buffer, std::size_t size, const file_options& options)
{
program p;
if(options.format == "msgpack")
{
p.from_value(from_msgpack(buffer, size));
}
else if(options.format == "json")
{
p.from_value(from_json_string(buffer, size));
}
else
{
MIGRAPHX_THROW("Unknown format: " + options.format);
}
return p;
}
void save(const program& p, const std::string& filename, const file_options& options)
{
write_buffer(filename, save_buffer(p, options));
}
std::vector<char> save_buffer(const program& p, const file_options& options)
{
value v = p.to_value();
std::vector<char> buffer;
if(options.format == "msgpack")
{
buffer = to_msgpack(v);
}
else if(options.format == "json")
{
std::string s = to_json_string(v);
buffer = std::vector<char>(s.begin(), s.end());
}
else
{
MIGRAPHX_THROW("Unknown format: " + options.format);
}
return buffer;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment