Commit f02f5d98 authored by charlie's avatar charlie
Browse files

Initial

parent e0cb7b9a
...@@ -222,7 +222,7 @@ template <class T> ...@@ -222,7 +222,7 @@ template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input) auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input)) -> decltype(x.compute(output_shape, input))
{ {
return x.compute(output_shape, input); return x.compute(make_compute_output_shape(x, output_shape, input), input);
} }
template <class T> template <class T>
...@@ -1278,6 +1278,59 @@ inline const ValueType& any_cast(const operation& x) ...@@ -1278,6 +1278,59 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return not(x == y); } inline bool operator!=(const operation& x, const operation& y) { return not(x == y); }
// used for dynamic operators
struct dyn_output
{
// original instruction output shape
shape ins_shape;
std::function<shape()> compute_shape;
shape get_output_shape()
{
if(output_shape.element_space() == 0)
{
output_shape = compute_shape();
}
return output_shape;
}
private:
// shape computed at eval time using input arguments
shape output_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape
* If converted to dyn_output type, will compute an output shape using the input arguments
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return unpack(
[](const auto& x, shape ins_shape, const std::vector<argument>& args) {
return dyn_output{ins_shape, [&]() { compute_shape(x, to_shapes(args)); }};
},
ins_inputs);
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape ins_shape, const std::vector<argument>& input)
{
return compute_output_shape{pack(x, ins_shape, input)};
}
inline value inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input) compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{ {
......
...@@ -234,8 +234,9 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>> ...@@ -234,8 +234,9 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
return op.normalize_compute_shape(inputs); return op.normalize_compute_shape(inputs);
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, migraphx::dyn_output dyn_output, std::vector<argument> args) const
{ {
shape output_shape;
std::vector<std::size_t> padding; std::vector<std::size_t> padding;
if(op.padding_mode != op::padding_mode_t::default_) if(op.padding_mode != op::padding_mode_t::default_)
{ {
...@@ -250,12 +251,8 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>> ...@@ -250,12 +251,8 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
} }
else else
{ {
padding = op.padding; padding = op.padding;
if(output_shape.dynamic()) output_shape = dyn_output.get_output_shape();
{
output_shape =
op.normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
}
} }
argument result{output_shape}; argument result{output_shape};
......
...@@ -222,7 +222,7 @@ template <class T> ...@@ -222,7 +222,7 @@ template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input) auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input)) -> decltype(x.compute(output_shape, input))
{ {
return x.compute(output_shape, input); return x.compute(make_compute_output_shape(x, output_shape, input), input);
} }
template <class T> template <class T>
...@@ -563,6 +563,59 @@ lifetime get_lifetime_op(const T&) ...@@ -563,6 +563,59 @@ lifetime get_lifetime_op(const T&)
return not(x == y); return not(x == y);
} }
// used for dynamic operators
struct dyn_output
{
// original instruction output shape
shape ins_shape;
std::function<shape()> compute_shape;
shape get_output_shape()
{
if(output_shape.element_space() == 0)
{
output_shape = compute_shape();
}
return output_shape;
}
private:
// shape computed at eval time using input arguments
shape output_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape
* If converted to dyn_output type, will compute an output shape using the input arguments
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return unpack(
[](const auto& x, shape ins_shape, const std::vector<argument>& args) {
return dyn_output{ins_shape, [&]() { compute_shape(x, to_shapes(args)); }};
},
ins_inputs);
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape ins_shape, const std::vector<argument>& input)
{
return compute_output_shape{pack(x, ins_shape, input)};
}
inline value inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input) compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment