Commit b377012f authored by charlie's avatar charlie
Browse files

Move rest of stuff to dyn_output.hpp

parent 558afddb
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -37,6 +38,38 @@ struct dyn_output ...@@ -37,6 +38,38 @@ struct dyn_output
shape computed_shape; shape computed_shape;
}; };
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
if(ins_shape.dynamic())
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
return dyn_output{ins_shape, ins_shape};
});
}
operator shape() const
{
return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
}
};
template <class F>
compute_output_shape<F> make_compute_output_shape(F f)
{
return {f};
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -96,38 +96,6 @@ bool has_finalize(const operation& x); ...@@ -96,38 +96,6 @@ bool has_finalize(const operation& x);
#else #else
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
if(ins_shape.dynamic())
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
return dyn_output{ins_shape, ins_shape};
});
}
operator shape() const
{
return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
}
};
template <class F>
compute_output_shape<F> make_compute_output_shape(F f)
{
return {f};
}
namespace detail { namespace detail {
namespace operation_operators { namespace operation_operators {
......
...@@ -96,38 +96,6 @@ bool has_finalize(const operation& x); ...@@ -96,38 +96,6 @@ bool has_finalize(const operation& x);
#else #else
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
if(ins_shape.dynamic())
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
return dyn_output{ins_shape, ins_shape};
});
}
operator shape() const
{
return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
}
};
template <class F>
compute_output_shape<F> make_compute_output_shape(F f)
{
return {f};
}
namespace detail { namespace detail {
namespace operation_operators { namespace operation_operators {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment