Unverified Commit 4351f46c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Merge branch 'develop' into scatter-op

parents 27c0ae08 bc52a8a8
...@@ -1668,7 +1668,8 @@ TEST_CASE(if_literal_test) ...@@ -1668,7 +1668,8 @@ TEST_CASE(if_literal_test)
else_mod->add_return({l2}); else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p; return p;
}; };
...@@ -1730,7 +1731,8 @@ TEST_CASE(if_param_test) ...@@ -1730,7 +1731,8 @@ TEST_CASE(if_param_test)
else_mod->add_return({a2}); else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond, x, y}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond, x, y}, {then_mod, else_mod});
mm->add_return({ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p; return p;
}; };
...@@ -1796,7 +1798,8 @@ TEST_CASE(if_pl_test) ...@@ -1796,7 +1798,8 @@ TEST_CASE(if_pl_test)
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto outline = mm->add_outline(s); auto outline = mm->add_outline(s);
mm->add_return({outline, ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({outline, r});
return p; return p;
}; };
......
...@@ -26,7 +26,8 @@ struct test_if_literal : verify_program<test_if_literal> ...@@ -26,7 +26,8 @@ struct test_if_literal : verify_program<test_if_literal>
else_mod->add_return({l2}); else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p; return p;
} }
......
...@@ -27,7 +27,9 @@ struct test_if_lp : verify_program<test_if_lp> ...@@ -27,7 +27,9 @@ struct test_if_lp : verify_program<test_if_lp>
else_mod->add_return({s2, l2}); else_mod->add_return({s2, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
return p; return p;
} }
......
...@@ -29,7 +29,8 @@ struct test_if_param : verify_program<test_if_param> ...@@ -29,7 +29,8 @@ struct test_if_param : verify_program<test_if_param>
else_mod->add_return({a2}); else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p; return p;
} }
......
...@@ -178,7 +178,7 @@ shape normalize_compute_shape_op(const T& x, ...@@ -178,7 +178,7 @@ shape normalize_compute_shape_op(const T& x,
} }
template <class T> template <class T>
auto compute_op(rank<2>, auto compute_op(rank<1>,
const T& x, const T& x,
context& ctx, context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -188,14 +188,6 @@ auto compute_op(rank<2>, ...@@ -188,14 +188,6 @@ auto compute_op(rank<2>,
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), output_shape, input);
} }
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T> template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&) argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{ {
...@@ -207,50 +199,106 @@ template <class T> ...@@ -207,50 +199,106 @@ template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return compute_op(rank<2>{}, x, ctx, output_shape, input); return compute_op(rank<1>{}, x, ctx, output_shape, input);
} }
template <class T> template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input) auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input)) -> decltype(x.compute(output_shape, input))
{ {
return x.compute(output_shape, input); return x.compute(output_shape, input);
} }
template <class T> template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input) argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name); MIGRAPHX_THROW("Not computable: " + name);
} }
template <class T> template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&) argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<1>{}, x, output_shape, input);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name); MIGRAPHX_THROW("Not computable: " + name);
} }
template <class T> template <class T, class F>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input) argument compute_op(const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{ {
return compute_op(rank<2>{}, x, output_shape, input); return compute_op(rank<1>{}, x, output, inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
auto compute_op(rank<1>, auto compute_op(rank<3>,
const T& x, const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(inputs, module_args, f)) F f) -> decltype(x.compute(output, inputs, module_args, f))
{ {
return x.compute(inputs, module_args, f); return x.compute(output, inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
argument auto compute_op(rank<2>,
compute_op(rank<0>, const T& x, const std::vector<argument>&, const std::vector<module_ref>&, F) const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs))
{
return x.compute(output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
context&,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name); MIGRAPHX_THROW("Not computable: " + name);
...@@ -258,11 +306,13 @@ argument ...@@ -258,11 +306,13 @@ argument
template <class T, class F> template <class T, class F>
argument compute_op(const T& x, argument compute_op(const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) F f)
{ {
return compute_op(rank<1>{}, x, inputs, module_args, f); return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
} }
template <class T> template <class T>
...@@ -447,10 +497,22 @@ bool is_borrowed_op(const T&) ...@@ -447,10 +497,22 @@ bool is_borrowed_op(const T&)
virtual( virtual(
'compute', 'compute',
returns = 'argument', returns = 'argument',
output = 'const shape&',
input = 'const std::vector<argument>&',
module_args = 'const std::vector<module_ref>&',
run =
'std::function<std::vector<argument>(module_ref&, const std::unordered_map<std::string, argument>&)>',
const = True,
default = 'detail::compute_op'),
virtual(
'compute',
returns = 'argument',
ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<argument>&', input = 'const std::vector<argument>&',
module_args = 'const std::vector<module_ref>&', module_args = 'const std::vector<module_ref>&',
run = run =
'std::function<std::vector<argument>(module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)>', 'std::function<std::vector<argument>(module_ref&, const std::unordered_map<std::string, argument>&)>',
const = True, const = True,
default = 'detail::compute_op'), default = 'detail::compute_op'),
virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'), virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment