Unverified Commit 4351f46c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Merge branch 'develop' into scatter-op

parents 27c0ae08 bc52a8a8
......@@ -1668,7 +1668,8 @@ TEST_CASE(if_literal_test)
else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
};
......@@ -1730,7 +1731,8 @@ TEST_CASE(if_param_test)
else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond, x, y}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
};
......@@ -1796,7 +1798,8 @@ TEST_CASE(if_pl_test)
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto outline = mm->add_outline(s);
mm->add_return({outline, ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({outline, r});
return p;
};
......
......@@ -26,7 +26,8 @@ struct test_if_literal : verify_program<test_if_literal>
else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
}
......
......@@ -27,7 +27,9 @@ struct test_if_lp : verify_program<test_if_lp>
else_mod->add_return({s2, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
return p;
}
......
......@@ -29,7 +29,8 @@ struct test_if_param : verify_program<test_if_param>
else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
}
......
......@@ -178,7 +178,7 @@ shape normalize_compute_shape_op(const T& x,
}
template <class T>
auto compute_op(rank<2>,
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output_shape,
......@@ -188,14 +188,6 @@ auto compute_op(rank<2>,
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
......@@ -207,50 +199,106 @@ template <class T>
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, ctx, output_shape, input);
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<1>{}, x, output_shape, input);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
template <class T, class F>
argument compute_op(const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<2>{}, x, output_shape, input);
return compute_op(rank<1>{}, x, output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<1>,
auto compute_op(rank<3>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(inputs, module_args, f))
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(inputs, module_args, f);
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
argument
compute_op(rank<0>, const T& x, const std::vector<argument>&, const std::vector<module_ref>&, F)
auto compute_op(rank<2>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs))
{
return x.compute(output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
context&,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
......@@ -258,11 +306,13 @@ argument
template <class T, class F>
argument compute_op(const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<1>{}, x, inputs, module_args, f);
return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
}
template <class T>
......@@ -447,10 +497,22 @@ bool is_borrowed_op(const T&)
virtual(
'compute',
returns = 'argument',
output = 'const shape&',
input = 'const std::vector<argument>&',
module_args = 'const std::vector<module_ref>&',
run =
'std::function<std::vector<argument>(module_ref&, const std::unordered_map<std::string, argument>&)>',
const = True,
default = 'detail::compute_op'),
virtual(
'compute',
returns = 'argument',
ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<argument>&',
module_args = 'const std::vector<module_ref>&',
run =
'std::function<std::vector<argument>(module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)>',
'std::function<std::vector<argument>(module_ref&, const std::unordered_map<std::string, argument>&)>',
const = True,
default = 'detail::compute_op'),
virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment