"git@developer.sourcefind.cn:wqshmzh/ktransformers.git" did not exist on "9c71bcb0bb825cba5cfb29f1a49871d8e4cb9117"
Unverified Commit b5633c27 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Driver dim override (#687)



* add driver the option to specify param dims

* clang format

* simplify the command line option

* clang format

* fix cppcheck error

* clang format

* refine unit test to have more code coverage

* clang format

* support the variable number of arguments

* clang format

* remove unnecessary code
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 4ab38dde
...@@ -42,6 +42,7 @@ struct loader ...@@ -42,6 +42,7 @@ struct loader
bool brief = false; bool brief = false;
std::string output_type; std::string output_type;
std::string output; std::string output;
std::vector<std::string> param_dims;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
...@@ -59,6 +60,11 @@ struct loader ...@@ -59,6 +60,11 @@ struct loader
ap.set_value(true)); ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false)); ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end")); ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
ap(param_dims,
{"--input-dim"},
ap.help("Dim of a parameter (format: \"@name d1 d2 dn\")"),
ap.append(),
ap.nargs(2));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true)); ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
ap(output_type, ap(output_type,
{"--graphviz", "-g"}, {"--graphviz", "-g"},
...@@ -81,11 +87,31 @@ struct loader ...@@ -81,11 +87,31 @@ struct loader
ap(output, {"--output", "-o"}, ap.help("Output to file.")); ap(output, {"--output", "-o"}, ap.help("Output to file."));
} }
static auto parse_param_dims(const std::vector<std::string>& param_dims_info)
{
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::string name = "";
for(auto&& x : param_dims_info)
{
if(x[0] == '@')
{
name = x.substr(1);
}
else
{
map_input_dims[name].push_back(value_parser<std::size_t>::apply(x));
}
}
return map_input_dims;
}
program load() program load()
{ {
program p; program p;
if(model.empty()) if(model.empty())
{ {
auto map_input_dims = parse_param_dims(param_dims);
if(file_type.empty()) if(file_type.empty())
{ {
if(ends_with(file, ".onnx")) if(ends_with(file, ".onnx"))
...@@ -104,11 +130,12 @@ struct loader ...@@ -104,11 +130,12 @@ struct loader
options.default_dim_value = batch; options.default_dim_value = batch;
options.skip_unknown_operators = skip_unknown_operators; options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = true; options.print_program_on_error = true;
options.map_input_dims = map_input_dims;
p = parse_onnx(file, options); p = parse_onnx(file, options);
} }
else if(file_type == "tf") else if(file_type == "tf")
{ {
p = parse_tf(file, tf_options{is_nhwc, batch}); p = parse_tf(file, tf_options{is_nhwc, batch, map_input_dims});
} }
else if(file_type == "json") else if(file_type == "json")
{ {
...@@ -201,8 +228,8 @@ struct program_params ...@@ -201,8 +228,8 @@ struct program_params
std::vector<std::string> fill1{}; std::vector<std::string> fill1{};
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
ap(fill0, {"--fill0"}, ap.help("Fill parameter with 0s"), ap.append()); ap(fill0, {"--fill0"}, ap.help("Fill parameter with 0s"), ap.append(), ap.nargs(2));
ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append()); ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append(), ap.nargs(2));
} }
auto generate(const program& p, const target& t, bool offload) auto generate(const program& p, const target& t, bool offload)
...@@ -472,8 +499,13 @@ using namespace migraphx::driver; // NOLINT ...@@ -472,8 +499,13 @@ using namespace migraphx::driver; // NOLINT
int main(int argc, const char* argv[]) int main(int argc, const char* argv[])
{ {
std::vector<std::string> args(argv + 1, argv + argc); std::vector<std::string> args(argv + 1, argv + argc);
// no argument, print the help infomration by default
if(args.empty()) if(args.empty())
return 0; {
args.push_back("-h");
}
auto&& m = get_commands(); auto&& m = get_commands();
auto cmd = args.front(); auto cmd = args.front();
if(m.count(cmd) > 0) if(m.count(cmd) > 0)
...@@ -484,5 +516,6 @@ int main(int argc, const char* argv[]) ...@@ -484,5 +516,6 @@ int main(int argc, const char* argv[])
{ {
run_command<main_command>(args); run_command<main_command>(args);
} }
return 0; return 0;
} }
...@@ -12,10 +12,12 @@ struct tf_options ...@@ -12,10 +12,12 @@ struct tf_options
{ {
bool is_nhwc = false; bool is_nhwc = false;
unsigned int batch_size = 1; unsigned int batch_size = 1;
/// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
}; };
/// Create a program from a tf pb file (default is nhwc format) /// Create a program from a tf pb file (default is nhwc format)
program parse_tf(const std::string& name, tf_options = tf_options{}); program parse_tf(const std::string& name, const tf_options& options = tf_options{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -36,6 +36,8 @@ struct tf_parser ...@@ -36,6 +36,8 @@ struct tf_parser
module* mm = prog.get_main_module(); module* mm = prog.get_main_module();
bool is_nhwc = true; bool is_nhwc = true;
unsigned int batch_size = 1; unsigned int batch_size = 1;
// Specified dims of inputs
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
...@@ -1033,13 +1035,22 @@ struct tf_parser ...@@ -1033,13 +1035,22 @@ struct tf_parser
attribute_map input_attrs = get_attributes(input); attribute_map input_attrs = get_attributes(input);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type()); shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape()); std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
if(is_nhwc and dims.size() >= 4)
if(contains(map_input_dims, name))
{ {
reorder_data(dims); dims = map_input_dims.at(name);
} }
std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) { else
return static_cast<int>(dim) <= 0 ? batch_size : dim; {
}); if(is_nhwc and dims.size() >= 4)
{
reorder_data(dims);
}
std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
return static_cast<int>(dim) <= 0 ? batch_size : dim;
});
}
shape s = shape{shape_type, dims}; shape s = shape{shape_type, dims};
instructions[name] = to_nhwc(mm->add_parameter(name, s)); instructions[name] = to_nhwc(mm->add_parameter(name, s));
} }
...@@ -1396,12 +1407,13 @@ struct tf_parser ...@@ -1396,12 +1407,13 @@ struct tf_parser
} }
}; };
program parse_tf(const std::string& name, tf_options options) program parse_tf(const std::string& name, const tf_options& options)
{ {
std::fstream input(name.c_str(), std::ios::in | std::ios::binary); std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
tf_parser parser; tf_parser parser;
parser.is_nhwc = options.is_nhwc; parser.is_nhwc = options.is_nhwc;
parser.batch_size = options.batch_size; parser.batch_size = options.batch_size;
parser.map_input_dims = options.map_input_dims;
#ifndef NDEBUG #ifndef NDEBUG
// Log the program when it can't be parsed // Log the program when it can't be parsed
......
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <unordered_map>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
...@@ -11,9 +12,12 @@ ...@@ -11,9 +12,12 @@
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include "test.hpp" #include "test.hpp"
migraphx::program parse_tf(const std::string& name, bool is_nhwc) migraphx::program
parse_tf(const std::string& name,
bool is_nhwc,
const std::unordered_map<std::string, std::vector<std::size_t>>& dim_params = {})
{ {
return migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1}); return migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1, dim_params});
} }
migraphx::program optimize_tf(const std::string& name, bool is_nhwc) migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
...@@ -74,11 +78,11 @@ TEST_CASE(argmax_test) ...@@ -74,11 +78,11 @@ TEST_CASE(argmax_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}}); mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = mm->add_instruction(migraphx::op::argmax{2}, l0); auto ins = mm->add_instruction(migraphx::op::argmax{2}, l0);
mm->add_instruction(migraphx::op::squeeze{{2}}, ins); mm->add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = parse_tf("argmax_test.pb", false); auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}});
EXPECT(p == prog); EXPECT(p == prog);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment