Unverified Commit 2074d756 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Parse onnx using explicit input shape (#474)



* refine slice implementation

* clang format

* fix cppcheck error

* clang format

* change parse operator function signature

* clang format

* add parsing the split operator

* clang format

* add parsing split operator

* make squeeze/unsqueeze inputs to standard shape

* fix a bug in parsing slice

* change parsing pad to support opset 11 definition

* clang format

* add unit tests for the split operator

* clang format

* fix cppcheck error

* clang format

* update tests for multiple program outputs

* clang format

* fix cppcheck error

* clang format

* refine an error message

* add unit tests for the pad operator

* clang format

* add unit tests for slice operator

* clang format

* fix review comments

* fix cppcheck error

* add the numpy package in the Dockerfile

* add c api for onnx options

* clang format

* change c/c++ and python apis related to parse_onnx option change

* clang format

* fixed a bug

* fix a bug

* add fix bugs in cpp api

* fix bugs in c api for onnx_options

* clang format

* add unit tests

* clang format

* add missing onnx file

* add more unit test for dynamic input shape support

* clang format

* fix cppcheck error

* fix cppcheck error

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* code change to resolve the segmentation problem

* clang format

* change the api

* clang format

* fixed a unit test error

* fix cppcheck error

* clang format

* fixed a cppcheck error

* fix review comments

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 1039011a
......@@ -87,11 +87,16 @@ migraphx::compile_options to_compile_options(const migraphx_compile_options& opt
return result;
}
migraphx::onnx_options to_onnx_options(const migraphx_onnx_options& options)
void set_default_dim_value(onnx_options& options, size_t value)
{
migraphx::onnx_options result{};
result.batch_size = options.batch_size;
return result;
options.default_dim_value = value;
}
void set_input_parameter_shape(onnx_options& options,
const char* name,
const std::vector<std::size_t>& dims)
{
options.map_input_dims[std::string(name)] = dims;
}
template <class Value>
......@@ -223,6 +228,16 @@ struct migraphx_program
migraphx::program object;
};
extern "C" struct migraphx_onnx_options;
struct migraphx_onnx_options
{
template <class... Ts>
migraphx_onnx_options(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::onnx_options object;
};
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{
return migraphx::try_([&] { destroy((shape)); });
......@@ -574,25 +589,60 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap
});
}
extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options)
{
return migraphx::try_([&] { destroy((onnx_options)); });
}
extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options)
{
return migraphx::try_([&] {
*onnx_options = object_cast<migraphx_onnx_options_t>(allocate<migraphx::onnx_options>());
});
}
extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size)
{
return migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
migraphx::set_input_parameter_shape(
(onnx_options->object), (name), (std::vector<size_t>(dims, dims + dims_size)));
});
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{
return migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_dim_value((onnx_options->object), (value));
});
}
extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options* options)
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{
return migraphx::try_([&] {
*out = allocate<migraphx_program_t>(migraphx::parse_onnx(
(name),
(options == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*options))));
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(migraphx::parse_onnx((name), (options->object)));
});
}
extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
const void* data,
size_t size,
migraphx_onnx_options* options)
migraphx_onnx_options_t options)
{
return migraphx::try_([&] {
*out = allocate<migraphx_program_t>(migraphx::parse_onnx_buffer(
(data),
(size),
(options == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*options))));
if(options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
*out = allocate<migraphx_program_t>(
migraphx::parse_onnx_buffer((data), (size), (options->object)));
});
}
......@@ -43,11 +43,6 @@ typedef struct
bool offload_copy;
} migraphx_compile_options;
typedef struct
{
size_t batch_size;
} migraphx_onnx_options;
typedef struct migraphx_shape* migraphx_shape_t;
typedef const struct migraphx_shape* const_migraphx_shape_t;
......@@ -72,6 +67,9 @@ typedef const struct migraphx_shapes* const_migraphx_shapes_t;
typedef struct migraphx_program* migraphx_program_t;
typedef const struct migraphx_program* const_migraphx_program_t;
typedef struct migraphx_onnx_options* migraphx_onnx_options_t;
typedef const struct migraphx_onnx_options* const_migraphx_onnx_options_t;
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
......@@ -171,13 +169,23 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x);
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size);
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options,
size_t value);
migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options* options);
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options);
migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
const void* data,
size_t size,
migraphx_onnx_options* options);
migraphx_onnx_options_t options);
#ifdef __cplusplus
}
......
......@@ -485,38 +485,70 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
};
inline program parse_onnx(const char* filename, migraphx_onnx_options options)
struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
return program(make<migraphx_program>(&migraphx_parse_onnx, filename, &options), own{});
onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
onnx_options(migraphx_onnx_options* p, own) { this->set_handle(p, own{}); }
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
{
call(&migraphx_onnx_options_set_input_parameter_shape,
this->get_handle_ptr(),
name.c_str(),
dim.data(),
dim.size());
}
void set_default_dim_value(unsigned int value)
{
call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value);
}
};
inline program parse_onnx(const char* filename, const migraphx::onnx_options& options)
{
return program(make<migraphx_program>(&migraphx_parse_onnx, filename, options.get_handle_ptr()),
own{});
}
inline program parse_onnx(const char* filename)
{
return program(make<migraphx_program>(&migraphx_parse_onnx, filename, nullptr), own{});
migraphx::onnx_options options;
return program(make<migraphx_program>(&migraphx_parse_onnx, filename, options.get_handle_ptr()),
own{});
}
inline program parse_onnx_buffer(const void* data, size_t size, migraphx_onnx_options options)
inline program
parse_onnx_buffer(const void* data, size_t size, const migraphx::onnx_options& options)
{
return program(make<migraphx_program>(&migraphx_parse_onnx_buffer, data, size, &options),
own{});
return program(
make<migraphx_program>(&migraphx_parse_onnx_buffer, data, size, options.get_handle_ptr()),
own{});
}
inline program parse_onnx_buffer(const void* data, size_t size)
{
return program(make<migraphx_program>(&migraphx_parse_onnx_buffer, data, size, nullptr), own{});
migraphx::onnx_options options;
return program(
make<migraphx_program>(&migraphx_parse_onnx_buffer, data, size, options.get_handle_ptr()),
own{});
}
inline program parse_onnx_buffer(const std::string& buffer, migraphx_onnx_options options)
inline program parse_onnx_buffer(const std::string& buffer, const migraphx::onnx_options& options)
{
return program(
make<migraphx_program>(&migraphx_parse_onnx_buffer, buffer.data(), buffer.size(), &options),
make<migraphx_program>(
&migraphx_parse_onnx_buffer, buffer.data(), buffer.size(), options.get_handle_ptr()),
own{});
}
inline program parse_onnx_buffer(const std::string& buffer)
{
migraphx::onnx_options options;
return program(
make<migraphx_program>(&migraphx_parse_onnx_buffer, buffer.data(), buffer.size(), nullptr),
make<migraphx_program>(
&migraphx_parse_onnx_buffer, buffer.data(), buffer.size(), options.get_handle_ptr()),
own{});
}
......
......@@ -170,6 +170,21 @@ def program(h):
const=True)
@auto_handle
def onnx_options(h):
h.constructor('create')
h.method(
'set_input_parameter_shape',
api.params(name='const char*', dims='std::vector<size_t>'),
invoke='migraphx::set_input_parameter_shape($@)',
)
h.method(
'set_default_dim_value',
api.params(value='size_t'),
invoke='migraphx::set_default_dim_value($@)',
)
api.add_function('migraphx_parse_onnx',
api.params(name='const char*',
options='migraphx::onnx_options'),
......
......@@ -10,17 +10,18 @@ inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser
struct onnx_options
{
unsigned int batch_size = 1;
std::size_t default_dim_value = 1;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
};
/// Create a program from an onnx file
program parse_onnx(const std::string& name, onnx_options = onnx_options{});
program parse_onnx(const std::string& name, const onnx_options& = onnx_options{});
/// Create a program from an onnx buffer
program parse_onnx_buffer(const std::string& buffer, onnx_options options);
program parse_onnx_buffer(const std::string& buffer, const onnx_options& options);
/// Create a program from an onnx buffer
program parse_onnx_buffer(const void* data, std::size_t size, onnx_options options);
program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -34,9 +34,10 @@ struct onnx_parser
std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
bool is_pytorch = false;
unsigned int batch_size = 1;
program prog = program();
bool is_pytorch = false;
std::size_t default_dim_value = 1;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, op_func> ops;
std::unordered_map<std::string, operation> map_actv_funcs;
......@@ -1927,8 +1928,13 @@ struct onnx_parser
// input not in initializer_data, so it is a real input
if(!contains(instructions, name))
{
// TODO: Get shape of input parameter
shape s = parse_type(input.type(), batch_size);
std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0)
{
dims = map_input_dims.at(name);
}
shape s = parse_type(input.type(), dims);
instructions[name] = prog.add_parameter(name, s);
}
}
......@@ -2118,7 +2124,7 @@ struct onnx_parser
return literal{{shape_type, dims}, data.begin(), data.end()};
}
static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims)
{
shape::type_t shape_type{};
switch(t.tensor_type().elem_type())
......@@ -2141,6 +2147,12 @@ struct onnx_parser
case onnx::TensorProto::COMPLEX128:
break; // throw std::runtime_error("Unsupported type");
}
if(!input_dims.empty())
{
return {shape_type, input_dims};
}
std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(tensor_dims.begin(),
......@@ -2150,11 +2162,17 @@ struct onnx_parser
if(d.has_dim_value())
{
if(static_cast<int>(d.dim_value()) <= 0)
return batch_size;
{
return default_dim_value;
}
return d.dim_value();
}
return batch_size;
else
{
return default_dim_value;
}
});
if(dims.empty())
return {shape_type};
......@@ -2193,10 +2211,12 @@ struct onnx_parser
};
template <class... Ts>
program parse_onnx_from(onnx_options options, Ts&&... xs)
program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
onnx_parser parser;
parser.batch_size = options.batch_size;
parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value;
#ifndef NDEBUG
// Log the program when it can't be parsed
try
......@@ -2214,18 +2234,18 @@ program parse_onnx_from(onnx_options options, Ts&&... xs)
return std::move(parser.prog);
}
program parse_onnx(const std::string& name, onnx_options options)
program parse_onnx(const std::string& name, const onnx_options& options)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
return parse_onnx_from(options, input);
}
program parse_onnx_buffer(const std::string& buffer, onnx_options options)
program parse_onnx_buffer(const std::string& buffer, const onnx_options& options)
{
return parse_onnx_from(options, buffer.data(), buffer.size());
}
program parse_onnx_buffer(const void* data, std::size_t size, onnx_options options)
program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options)
{
return parse_onnx_from(options, data, size);
}
......
......@@ -180,13 +180,20 @@ PYBIND11_MODULE(migraphx, m)
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1);
m.def("parse_onnx",
[](const std::string& filename, unsigned int batch_size) {
return migraphx::parse_onnx(filename, migraphx::onnx_options{batch_size});
[](const std::string& filename,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::size_t value) {
migraphx::onnx_options options;
options.map_input_dims = map_input_dims;
options.default_dim_value = value;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("batch_size") = 1);
py::arg("map_input_dims") = std::map<std::string, std::vector<std::size_t>>(),
py::arg("value") = 1);
m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu")
......
......@@ -22,6 +22,28 @@ TEST_CASE(load_and_run)
CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
}
TEST_CASE(load_and_run_user_input_shape)
{
migraphx::onnx_options options;
options.set_input_parameter_shape("0", {2, 3, 64, 64});
auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx", options);
auto shapes_before = p.get_output_shapes();
p.compile(migraphx::target("cpu"));
auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 1);
CHECK(shapes_before.size() == shapes_after.size());
CHECK(bool{shapes_before.front() == shapes_after.front()});
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
for(auto&& name : param_shapes.names())
{
pp.add(name, migraphx::argument::generate(param_shapes[name]));
}
auto outputs = p.eval(pp);
CHECK(shapes_before.size() == outputs.size());
CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
}
TEST_CASE(zero_parameter)
{
auto p = migraphx::parse_onnx("constant_fill_test.onnx");
......
......@@ -814,6 +814,23 @@ TEST_CASE(implicit_add_bcast_test)
EXPECT(p == prog);
}
TEST_CASE(implicit_add_bcast_user_input_shape_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5, 1}});
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{3, 4, 5, 6}}, l1);
auto r = p.add_instruction(migraphx::op::add{}, l0, l3);
p.add_return({r});
migraphx::onnx_options options;
options.map_input_dims["0"] = {3, 4, 5, 6};
options.map_input_dims["1"] = {4, 5, 1};
auto prog = migraphx::parse_onnx("implicit_add_bcast_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(implicit_pow_bcast_test)
{
migraphx::program p;
......@@ -1684,6 +1701,21 @@ TEST_CASE(variable_batch_test)
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}});
auto r = p.add_instruction(migraphx::op::identity{}, l0);
p.add_return({r});
migraphx::onnx_options options;
options.default_dim_value = 2;
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_leq_zero_test)
{
migraphx::program p;
......
......@@ -87,11 +87,16 @@ migraphx::compile_options to_compile_options(const migraphx_compile_options& opt
return result;
}
migraphx::onnx_options to_onnx_options(const migraphx_onnx_options& options)
void set_default_dim_value(onnx_options& options, size_t value)
{
migraphx::onnx_options result{};
result.batch_size = options.batch_size;
return result;
options.default_dim_value = value;
}
void set_input_parameter_shape(onnx_options& options,
const char* name,
const std::vector<std::size_t>& dims)
{
options.map_input_dims[std::string(name)] = dims;
}
template <class Value>
......
......@@ -43,11 +43,6 @@ typedef struct
bool offload_copy;
} migraphx_compile_options;
typedef struct
{
size_t batch_size;
} migraphx_onnx_options;
<% generate_c_header() %>
#ifdef __cplusplus
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment