"doc/vscode:/vscode.git/clone" did not exist on "9c71bcb0bb825cba5cfb29f1a49871d8e4cb9117"
Unverified Commit bc52a8a8 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Inline subgraph (#802)



* Add definitions for all pointwise operators

* Formatting

* Add cpp generator class

* Formatting

* Move compilation to core

* Formatting

* Add clock to tmp name

* Add dynamic loader

* Formatting

* Add tests for code gen

* Formatting

* Add test for literals

* Formatting

* Use with_char

* Add missing header

* Fix mismerge

* Ignore tidy warning

* Fix gcc 5 errors

* Apply fixits

* Skip signed bitwise of status

* Remove unused parameters

* Explicitly add c++14 flag

* Fix tidy warning

* unify the compute function signature

* clang format

* make another change

* unify the compute function

* clang format

* remove unnecessary code

* more refinement about the operator compute function

* clang format

* add an overload function

* clang format

* add support for axes inputs for squeeze/unsqueeze/reduce_sum

* clang format

* fix build problems

* backup code changes

* clang format

* Add tuple type to shape class

* Formatting

* fix a bug in parsing quantizelinear operator

* clang format

* fix a cppcheck error

* disable different versions of unit tests for different onnx version

* clang format

* upgrade onnx to 1.8

* update onnx to 1.8.1

* disable two more real models

* clang format

* Make data member private

* Formatting

* Add sub arguments

* Formatting

* Turn clang format off

* Disable clang-format

* fix review comments

* fix the function of assign axes in parsing the squeeze operator

* add unit tests and fix a bug

* clang format

* fix review comments

* clang format

* fix a build error

* backup code changes

* clang format

* add more unit tests and add parsing opset version

* clang format

* Improve visiting tuples

* Formatting

* fix cppcheck error

* adding installing the onnx package

* resolve no protobuf compiler

* add an inline subgraph pass

* clang format

* Add more argument tests

* Formatting

* Handle tuple in load

* Formatting

* code backup

* clang format

* Remove .o files

* Add tuple type to api

* Formatting

* fix build errors

* clang format

* code backup

* code backup

* add unit tests for the inline subgraph

* clang format

* refine the inline subgraph and parse if operator

* clang format

* fix cppcheck issue

* clang format

* add unit test for inline subgraph pass

* clang format

* fix format issue

* remove the context from the if operator

* clang format

* simplify the compute functions

* Fix tidy warnings

* fix cppcheck error

* clang format

* fix cppcheck error

* Fix tidy warnings

* fix a cppcheck error

* clang format

* Add a test for share method

* Formatting

* Add a test cpp_type

* add unit tests for more code coverage

* clang format

* add unit tests to have more code coverage

* clang format

* try a comment in jenkins build

* include the install onnx line

* code backup

* reorder the dependencies installed

* refine dockerfile

* fix review comments

* clang format

* remove unnecessary overload function

* fix cppcheck error

* change back the argument test

* Suppress tidy warning

* add the operator get_tuple_elem

* clang format

* add get_tuple_elem to operator include file

* change if to support multiple operation outputs

* clang format

* optimize inline subgraph

* clang format

* code backup

* clang format

* fix bug

* refine unit tests for tuple output of the if operator

* clang format

* refine an instruction replacement code

* add a unit test and sort all the unit tests alphabetically

* fix cppcheck error

* add more unit tests for multiple op outputs

* clang format

* fix cppcheck error

* Update pass manager to get modules after every pass

* more unit test to cover more scenarios

* clang format

* fixed a bug in a unit test

* add more tests

* clang format

* add more unit tests to have more code coverage

* fix a bug in a unit test

* Add program overload for module

* Formatting

* Hash modules for quicker lookup of modules

* Bump file version

* Add methods to remove modules

* Formatting

* add the tuple type to the support list

* Eliminate unused modules

* Formatting

* Fix test errors

* Formatting

* Fix tidy issues

* fix problem related to inline subgraph

* clang format

* fix review comments

* fix review comments

* fix review comments

* fix review comments

* clang format

* fix a unit test

* one more code change

* remove an optimization related to the if operator

* clang format

* fix review comments
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e00479af
......@@ -1668,7 +1668,8 @@ TEST_CASE(if_literal_test)
else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
};
......@@ -1730,7 +1731,8 @@ TEST_CASE(if_param_test)
else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond, x, y}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
};
......@@ -1796,7 +1798,8 @@ TEST_CASE(if_pl_test)
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto outline = mm->add_outline(s);
mm->add_return({outline, ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({outline, r});
return p;
};
......
......@@ -26,7 +26,8 @@ struct test_if_literal : verify_program<test_if_literal>
else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
}
......
......@@ -27,7 +27,9 @@ struct test_if_lp : verify_program<test_if_lp>
else_mod->add_return({s2, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
return p;
}
......
......@@ -29,7 +29,8 @@ struct test_if_param : verify_program<test_if_param>
else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
}
......
......@@ -178,7 +178,7 @@ shape normalize_compute_shape_op(const T& x,
}
template <class T>
auto compute_op(rank<2>,
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output_shape,
......@@ -188,14 +188,6 @@ auto compute_op(rank<2>,
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
......@@ -207,50 +199,106 @@ template <class T>
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, ctx, output_shape, input);
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<1>{}, x, output_shape, input);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
template <class T, class F>
argument compute_op(const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<2>{}, x, output_shape, input);
return compute_op(rank<1>{}, x, output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<1>,
auto compute_op(rank<3>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(inputs, module_args, f))
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(inputs, module_args, f);
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
argument
compute_op(rank<0>, const T& x, const std::vector<argument>&, const std::vector<module_ref>&, F)
auto compute_op(rank<2>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs))
{
return x.compute(output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
context&,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
......@@ -258,11 +306,13 @@ argument
template <class T, class F>
argument compute_op(const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<1>{}, x, inputs, module_args, f);
return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
}
template <class T>
......@@ -447,10 +497,22 @@ bool is_borrowed_op(const T&)
virtual(
'compute',
returns = 'argument',
output = 'const shape&',
input = 'const std::vector<argument>&',
module_args = 'const std::vector<module_ref>&',
run =
'std::function<std::vector<argument>(module_ref&, const std::unordered_map<std::string, argument>&)>',
const = True,
default = 'detail::compute_op'),
virtual(
'compute',
returns = 'argument',
ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<argument>&',
module_args = 'const std::vector<module_ref>&',
run =
'std::function<std::vector<argument>(module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)>',
'std::function<std::vector<argument>(module_ref&, const std::unordered_map<std::string, argument>&)>',
const = True,
default = 'detail::compute_op'),
virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment