Unverified Commit b5633c27 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Driver dim override (#687)



* add driver the option to specify param dims

* clang format

* simplify the command line option

* clang format

* fix cppcheck error

* clang format

* refine unit test to have more code coverage

* clang format

* support the variable number of arguments

* clang format

* remove unnecessary code
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 4ab38dde
......@@ -42,6 +42,7 @@ struct loader
bool brief = false;
std::string output_type;
std::string output;
std::vector<std::string> param_dims;
void parse(argument_parser& ap)
{
......@@ -59,6 +60,11 @@ struct loader
ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
ap(param_dims,
{"--input-dim"},
ap.help("Dim of a parameter (format: \"@name d1 d2 dn\")"),
ap.append(),
ap.nargs(2));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
ap(output_type,
{"--graphviz", "-g"},
......@@ -81,11 +87,31 @@ struct loader
ap(output, {"--output", "-o"}, ap.help("Output to file."));
}
static auto parse_param_dims(const std::vector<std::string>& param_dims_info)
{
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::string name = "";
for(auto&& x : param_dims_info)
{
if(x[0] == '@')
{
name = x.substr(1);
}
else
{
map_input_dims[name].push_back(value_parser<std::size_t>::apply(x));
}
}
return map_input_dims;
}
program load()
{
program p;
if(model.empty())
{
auto map_input_dims = parse_param_dims(param_dims);
if(file_type.empty())
{
if(ends_with(file, ".onnx"))
......@@ -104,11 +130,12 @@ struct loader
options.default_dim_value = batch;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = true;
options.map_input_dims = map_input_dims;
p = parse_onnx(file, options);
}
else if(file_type == "tf")
{
p = parse_tf(file, tf_options{is_nhwc, batch});
p = parse_tf(file, tf_options{is_nhwc, batch, map_input_dims});
}
else if(file_type == "json")
{
......@@ -201,8 +228,8 @@ struct program_params
std::vector<std::string> fill1{};
void parse(argument_parser& ap)
{
ap(fill0, {"--fill0"}, ap.help("Fill parameter with 0s"), ap.append());
ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append());
ap(fill0, {"--fill0"}, ap.help("Fill parameter with 0s"), ap.append(), ap.nargs(2));
ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append(), ap.nargs(2));
}
auto generate(const program& p, const target& t, bool offload)
......@@ -472,8 +499,13 @@ using namespace migraphx::driver; // NOLINT
int main(int argc, const char* argv[])
{
std::vector<std::string> args(argv + 1, argv + argc);
// no argument, print the help infomration by default
if(args.empty())
return 0;
{
args.push_back("-h");
}
auto&& m = get_commands();
auto cmd = args.front();
if(m.count(cmd) > 0)
......@@ -484,5 +516,6 @@ int main(int argc, const char* argv[])
{
run_command<main_command>(args);
}
return 0;
}
......@@ -12,10 +12,12 @@ struct tf_options
{
bool is_nhwc = false;
unsigned int batch_size = 1;
/// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
};
/// Create a program from a tf pb file (default is nhwc format)
program parse_tf(const std::string& name, tf_options = tf_options{});
program parse_tf(const std::string& name, const tf_options& options = tf_options{});
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -36,6 +36,8 @@ struct tf_parser
module* mm = prog.get_main_module();
bool is_nhwc = true;
unsigned int batch_size = 1;
// Specified dims of inputs
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, op_func> ops;
......@@ -1033,13 +1035,22 @@ struct tf_parser
attribute_map input_attrs = get_attributes(input);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
if(is_nhwc and dims.size() >= 4)
if(contains(map_input_dims, name))
{
reorder_data(dims);
dims = map_input_dims.at(name);
}
std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
return static_cast<int>(dim) <= 0 ? batch_size : dim;
});
else
{
if(is_nhwc and dims.size() >= 4)
{
reorder_data(dims);
}
std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
return static_cast<int>(dim) <= 0 ? batch_size : dim;
});
}
shape s = shape{shape_type, dims};
instructions[name] = to_nhwc(mm->add_parameter(name, s));
}
......@@ -1396,12 +1407,13 @@ struct tf_parser
}
};
program parse_tf(const std::string& name, tf_options options)
program parse_tf(const std::string& name, const tf_options& options)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
tf_parser parser;
parser.is_nhwc = options.is_nhwc;
parser.batch_size = options.batch_size;
parser.is_nhwc = options.is_nhwc;
parser.batch_size = options.batch_size;
parser.map_input_dims = options.map_input_dims;
#ifndef NDEBUG
// Log the program when it can't be parsed
......
#include <iostream>
#include <vector>
#include <unordered_map>
#include <migraphx/literal.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
......@@ -11,9 +12,12 @@
#include <migraphx/tf.hpp>
#include "test.hpp"
migraphx::program parse_tf(const std::string& name, bool is_nhwc)
migraphx::program
parse_tf(const std::string& name,
bool is_nhwc,
const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {})
{
return migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1});
return migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1, dim_params});
}
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
......@@ -74,11 +78,11 @@ TEST_CASE(argmax_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = mm->add_instruction(migraphx::op::argmax{2}, l0);
mm->add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = parse_tf("argmax_test.pb", false);
auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}});
EXPECT(p == prog);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment