Unverified Commit 63c5582a authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add load/save function for program (#623)



* Add save/load functions

* Formatting

* Add loading and saving to the driver

* Formatting

* Add return

* Serialize the context with the program

* Formatting

* Add python API

* Formatting

* Add c/c++ apis

* Formatting

* Add tests

* Formatting

* Fix tidy error

* Fix python doc

* Restore python code

* Add function name to errors

* Formatting

* Use lvalue for writing

* Serialize context

* Fix convolution and pooling operator for miopen

* Formatting

* Add const ref

* Set target name to gpu

* Add target tests

* Formatting

* Move register target to cpp file

* Fix target test

* Use make_target in driver

* Formatting

* Use make_target for the API

* Formatting

* Add cpu include

* Increase timeout

* Add more tests

* Formatting
Co-authored-by: default avatarShucai Xiao <shucai.xiao@amd.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e67aa78c
#include <migraphx/program.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/op/add.hpp>
#include "test.hpp"
#include <cstdio>
migraphx::program create_program()
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto two = p.add_literal(2);
auto add = p.add_instruction(migraphx::op::add{}, x, two);
p.add_return({add});
return p;
}
TEST_CASE(as_value)
{
migraphx::program p1 = create_program();
migraphx::program p2;
p2.from_value(p1.to_value());
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(as_msgpack)
{
migraphx::file_options options;
options.format = "msgpack";
migraphx::program p1 = create_program();
std::vector<char> buffer = migraphx::save_buffer(p1, options);
migraphx::program p2 = migraphx::load_buffer(buffer, options);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(as_json)
{
migraphx::file_options options;
options.format = "json";
migraphx::program p1 = create_program();
std::vector<char> buffer = migraphx::save_buffer(p1, options);
migraphx::program p2 = migraphx::load_buffer(buffer, options);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(as_file)
{
std::string filename = "migraphx_program.dat";
migraphx::program p1 = create_program();
migraphx::save(p1, filename);
migraphx::program p2 = migraphx::load(filename);
std::remove(filename.c_str());
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(compiled)
{
migraphx::program p1 = create_program();
p1.compile(migraphx::cpu::target{});
std::vector<char> buffer = migraphx::save_buffer(p1);
migraphx::program p2 = migraphx::load_buffer(buffer);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(unknown_format)
{
migraphx::file_options options;
options.format = "???";
EXPECT(test::throws([&] { migraphx::save_buffer(create_program(), options); }));
EXPECT(test::throws([&] { migraphx::load_buffer(std::vector<char>{}, options); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/register_target.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/target.hpp>
#include "test.hpp"
TEST_CASE(make_target)
{
for(const auto& name : migraphx::get_targets())
{
auto t = migraphx::make_target(name);
CHECK(t.name() == name);
}
}
TEST_CASE(targets)
{
auto ts = migraphx::get_targets();
EXPECT(ts.size() > 0);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -3,14 +3,11 @@ ...@@ -3,14 +3,11 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/cpu/target.hpp>
#ifdef HAVE_GPU #include <migraphx/load_save.hpp>
#include <migraphx/gpu/target.hpp>
#endif
namespace migraphx { namespace migraphx {
...@@ -67,19 +64,7 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t) ...@@ -67,19 +64,7 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type"); MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
} }
target get_target(const std::string& name) target get_target(const std::string& name) { return make_target(name); }
{
migraphx::target t;
if(name == "cpu")
t = migraphx::cpu::target();
#ifdef HAVE_GPU
else if(name == "gpu")
t = migraphx::gpu::target();
#endif
else
MIGRAPHX_THROW(migraphx_status_unknown_target, "Unknown target: " + name);
return t;
}
migraphx::compile_options to_compile_options(const migraphx_compile_options& options) migraphx::compile_options to_compile_options(const migraphx_compile_options& options)
{ {
...@@ -88,6 +73,13 @@ migraphx::compile_options to_compile_options(const migraphx_compile_options& opt ...@@ -88,6 +73,13 @@ migraphx::compile_options to_compile_options(const migraphx_compile_options& opt
return result; return result;
} }
migraphx::file_options to_file_options(const migraphx_file_options& options)
{
migraphx::file_options result{};
result.format = options.format;
return result;
}
void set_default_dim_value(onnx_options& options, size_t value) void set_default_dim_value(onnx_options& options, size_t value)
{ {
options.default_dim_value = value; options.default_dim_value = value;
......
...@@ -44,6 +44,11 @@ typedef struct ...@@ -44,6 +44,11 @@ typedef struct
bool offload_copy; bool offload_copy;
} migraphx_compile_options; } migraphx_compile_options;
typedef struct
{
const char* format;
} migraphx_file_options;
<% generate_c_header() %> <% generate_c_header() %>
#ifdef __cplusplus #ifdef __cplusplus
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment