Unverified Commit de10423f authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Add external data format support (#663)



* initial progress

* formatting

* change function def

* move read_buffer to header

* formatting

* add test files

* formatting

* fix tidy and deepcode errors

* deepcode check

* use file_buffer

* add const

* use newer fs calls

* formatting
Co-authored-by: default avatarShucai Xiao <shucai.xiao@amd.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent d39e51ed
......@@ -2,6 +2,7 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/json.hpp>
#include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp>
#include <fstream>
namespace migraphx {
......
......@@ -19,6 +19,8 @@
#include <migraphx/pad_calc.hpp>
#include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
......@@ -52,6 +54,8 @@ namespace onnx = onnx_for_migraphx;
struct onnx_parser
{
std::string filename;
std::string path = ".";
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
struct node_info
{
......@@ -192,7 +196,7 @@ struct onnx_parser
map_actv_funcs.insert(std::make_pair("elu", make_op("elu")));
}
static operation load(const std::string& name, const node_info& info)
operation load(const std::string& name, const node_info& info) const
{
auto op = make_op(name);
auto v = op.to_value();
......@@ -2750,8 +2754,13 @@ struct onnx_parser
return add_broadcastable_binary_op(cd, args[2], "add");
}
void parse_from(std::istream& is)
void parse_from(std::istream& is, std::string name = "")
{
this->filename = std::move(name);
auto parent_path = fs::path(this->filename).parent_path();
if(not parent_path.empty())
this->path = parent_path;
onnx::ModelProto model;
if(model.ParseFromIstream(&is))
{
......@@ -2785,7 +2794,9 @@ struct onnx_parser
void parse_graph(const onnx::GraphProto& graph)
{
for(auto&& f : graph.initializer())
{
instructions[f.name()] = prog.add_literal(parse_tensor(f));
}
for(auto&& input : graph.input())
{
......@@ -2916,7 +2927,7 @@ struct onnx_parser
return literal{{t, {size}}, r.begin(), r.end()};
}
static literal parse_value(const onnx::AttributeProto& attr)
literal parse_value(const onnx::AttributeProto& attr) const
{
switch(attr.type())
{
......@@ -2937,9 +2948,17 @@ struct onnx_parser
MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
}
static literal parse_tensor(const onnx::TensorProto& t)
literal parse_tensor(const onnx::TensorProto& t) const
{
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
if(not t.external_data().empty())
{
const std::string& data_file = t.external_data().at(0).value();
auto raw_buffer = read_buffer(path + "/" + data_file);
std::string s(raw_buffer.begin(), raw_buffer.end());
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data());
}
if(t.has_raw_data())
{
const std::string& s = t.raw_data();
......@@ -3004,7 +3023,7 @@ struct onnx_parser
return literal{{shape_type, dims}, data.begin(), data.end()};
}
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims)
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const
{
shape::type_t shape_type = get_type(t.tensor_type().elem_type());
if(!input_dims.empty())
......@@ -3078,7 +3097,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
program parse_onnx(const std::string& name, const onnx_options& options)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
return parse_onnx_from(options, input);
return parse_onnx_from(options, input, name);
}
program parse_onnx_buffer(const std::string& buffer, const onnx_options& options)
......
......@@ -930,6 +930,38 @@ TEST_CASE(expand_test)
EXPECT(p == prog);
}
migraphx::program create_external_data_prog()
{
migraphx::program p;
migraphx::shape s(migraphx::shape::float_type, {1, 1, 224, 224});
migraphx::shape s2(migraphx::shape::float_type, {10, 1, 11, 11});
std::vector<float> weight_data(1210, 1);
std::vector<float> bias_data(10, 1);
auto bias = p.add_literal(migraphx::literal({migraphx::shape::float_type, {10}}, bias_data));
auto weights = p.add_literal(migraphx::literal(s2, weight_data));
auto param = p.add_parameter("input", s);
auto conv = p.add_instruction(migraphx::op::convolution{}, param, weights);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, {1, 10, 214, 214}}, bias);
p.add_instruction(migraphx::op::add{}, conv, bias_bcast);
return p;
}
TEST_CASE(external_data_test)
{
migraphx::program p = create_external_data_prog();
auto prog = optimize_onnx("external_data_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(external_data_diff_path_test)
{
migraphx::program p = create_external_data_prog();
auto prog = optimize_onnx("ext_path/external_data_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(flatten_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment