Commit 68c17b1b authored by charlie's avatar charlie
Browse files

Still broken, figuring things out

parent f02f5d98
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp> #include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
...@@ -46,6 +47,48 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -46,6 +47,48 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context; struct context;
struct dyn_output
{
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return unpack(
[](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
},
ins_inputs);
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape output_shape, const std::vector<argument>& inputs)
-> decltype(compute_output_shape{pack(x, output_shape, inputs)})
{
return compute_output_shape{pack(x, output_shape, inputs)};
}
#ifdef DOXYGEN #ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All /// The operation interface represents an action an instruction will perform. All
...@@ -199,9 +242,11 @@ auto compute_op(rank<1>, ...@@ -199,9 +242,11 @@ auto compute_op(rank<1>,
context& ctx, context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& input) const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input)) -> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(x, output_shape, input),
input))
{ {
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output_shape, input), input);
} }
template <class T> template <class T>
...@@ -220,7 +265,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -220,7 +265,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template <class T> template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input) auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input)) -> decltype(x.compute(make_compute_output_shape(x, output_shape, input), input))
{ {
return x.compute(make_compute_output_shape(x, output_shape, input), input); return x.compute(make_compute_output_shape(x, output_shape, input), input);
} }
...@@ -244,9 +289,10 @@ auto compute_op(rank<1>, ...@@ -244,9 +289,10 @@ auto compute_op(rank<1>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f)) F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{ {
return x.compute(output, inputs, module_args, f); return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
...@@ -278,9 +324,12 @@ auto compute_op(rank<4>, ...@@ -278,9 +324,12 @@ auto compute_op(rank<4>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f)) F f)
-> decltype(x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{ {
return x.compute(auto_any_cast(ctx), output, inputs, module_args, f); return x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
...@@ -290,9 +339,10 @@ auto compute_op(rank<3>, ...@@ -290,9 +339,10 @@ auto compute_op(rank<3>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f)) F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{ {
return x.compute(output, inputs, module_args, f); return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
...@@ -302,9 +352,9 @@ auto compute_op(rank<2>, ...@@ -302,9 +352,9 @@ auto compute_op(rank<2>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs)) F) -> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs))
{ {
return x.compute(output, inputs); return x.compute(make_compute_output_shape(x, output, inputs), inputs);
} }
template <class T, class F> template <class T, class F>
...@@ -314,9 +364,10 @@ auto compute_op(rank<1>, ...@@ -314,9 +364,10 @@ auto compute_op(rank<1>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs)) F)
-> decltype(x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs))
{ {
return x.compute(auto_any_cast(ctx), output, inputs); return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs);
} }
template <class T, class F> template <class T, class F>
...@@ -1278,59 +1329,6 @@ inline const ValueType& any_cast(const operation& x) ...@@ -1278,59 +1329,6 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return not(x == y); } inline bool operator!=(const operation& x, const operation& y) { return not(x == y); }
// used for dynamic operators
struct dyn_output
{
// original instruction output shape
shape ins_shape;
std::function<shape()> compute_shape;
shape get_output_shape()
{
if(output_shape.element_space() == 0)
{
output_shape = compute_shape();
}
return output_shape;
}
private:
// shape computed at eval time using input arguments
shape output_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape
* If converted to dyn_output type, will compute an output shape using the input arguments
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return unpack(
[](const auto& x, shape ins_shape, const std::vector<argument>& args) {
return dyn_output{ins_shape, [&]() { compute_shape(x, to_shapes(args)); }};
},
ins_inputs);
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape ins_shape, const std::vector<argument>& input)
{
return compute_output_shape{pack(x, ins_shape, input)};
}
inline value inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input) compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{ {
......
...@@ -214,6 +214,9 @@ struct shape ...@@ -214,6 +214,9 @@ struct shape
/// Return true if the shape is dynamic /// Return true if the shape is dynamic
bool dynamic() const; bool dynamic() const;
/// Returns true if the shape is empty
bool empty() const;
shape normalize_standard() const; shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const; shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
......
...@@ -443,6 +443,8 @@ std::string shape::type_string() const { return name(this->type()); } ...@@ -443,6 +443,8 @@ std::string shape::type_string() const { return name(this->type()); }
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); } bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
bool shape::empty() const { return max_lens().empty(); }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; } const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
std::vector<std::size_t> shape::min_lens() const std::vector<std::size_t> shape::min_lens() const
......
...@@ -79,8 +79,9 @@ struct scatternd_compiler : compiler<scatternd_compiler> ...@@ -79,8 +79,9 @@ struct scatternd_compiler : compiler<scatternd_compiler>
{ {
assert(starts_with(op.name(), "scatternd_")); assert(starts_with(op.name(), "scatternd_"));
auto reduction = op.name().substr(10); auto reduction = op.name().substr(10);
return insert(compile_op(ctx, return insert(compile_op(
to_shapes({ins->inputs().begin() + 1, ins->inputs().end()}), ctx,
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()}),
{{"reduction", reduction}})); {{"reduction", reduction}}));
} }
......
...@@ -252,7 +252,7 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>> ...@@ -252,7 +252,7 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
else else
{ {
padding = op.padding; padding = op.padding;
output_shape = dyn_output.get_output_shape(); output_shape = dyn_output.computed_shape;
} }
argument result{output_shape}; argument result{output_shape};
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp> #include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
...@@ -46,6 +47,80 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -46,6 +47,80 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context; struct context;
struct dyn_output
{
F ins_inputs;
dyn_output(F f) : ins_inputs(f){};
shape get_input_shape()
{
if(ins_shape.empty())
{
ins_shape = unpack(
[&](const auto&, shape s, const std::vector<argument>&) { return s; }, ins_inputs);
}
return ins_shape;
}
shape get_output_shape()
{
if(computed_shape.empty())
{
computed_shape = unpack(
[&](const auto& x, shape, const std::vector<argument>& inputs) {
return compute_shape(x, to_shapes(inputs));
},
ins_inputs);
}
return computed_shape;
}
private:
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output<F>() const
{
/*
return unpack([](const auto& x, shape ins_shape, const std::vector<argument>& inputs)
{
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
},
ins_inputs
);
*/
return dyn_output<F>{ins_inputs};
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape output_shape, const std::vector<argument>& inputs)
-> decltype(compute_output_shape{pack(x, output_shape, inputs)})
{
return compute_output_shape{pack(x, output_shape, inputs)};
}
#ifdef DOXYGEN #ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All /// The operation interface represents an action an instruction will perform. All
...@@ -199,9 +274,11 @@ auto compute_op(rank<1>, ...@@ -199,9 +274,11 @@ auto compute_op(rank<1>,
context& ctx, context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& input) const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input)) -> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(x, output_shape, input),
input))
{ {
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output_shape, input), input);
} }
template <class T> template <class T>
...@@ -220,7 +297,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -220,7 +297,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template <class T> template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input) auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input)) -> decltype(x.compute(make_compute_output_shape(x, output_shape, input), input))
{ {
return x.compute(make_compute_output_shape(x, output_shape, input), input); return x.compute(make_compute_output_shape(x, output_shape, input), input);
} }
...@@ -244,9 +321,10 @@ auto compute_op(rank<1>, ...@@ -244,9 +321,10 @@ auto compute_op(rank<1>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f)) F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{ {
return x.compute(output, inputs, module_args, f); return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
...@@ -278,9 +356,12 @@ auto compute_op(rank<4>, ...@@ -278,9 +356,12 @@ auto compute_op(rank<4>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f)) F f)
-> decltype(x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{ {
return x.compute(auto_any_cast(ctx), output, inputs, module_args, f); return x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
...@@ -290,9 +371,10 @@ auto compute_op(rank<3>, ...@@ -290,9 +371,10 @@ auto compute_op(rank<3>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f)) F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{ {
return x.compute(output, inputs, module_args, f); return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
...@@ -302,9 +384,9 @@ auto compute_op(rank<2>, ...@@ -302,9 +384,9 @@ auto compute_op(rank<2>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs)) F) -> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs))
{ {
return x.compute(output, inputs); return x.compute(make_compute_output_shape(x, output, inputs), inputs);
} }
template <class T, class F> template <class T, class F>
...@@ -314,9 +396,10 @@ auto compute_op(rank<1>, ...@@ -314,9 +396,10 @@ auto compute_op(rank<1>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs)) F)
-> decltype(x.compute(auto_any_cast(ctx), compute_output_shape<T>{x, output, inputs}, inputs))
{ {
return x.compute(auto_any_cast(ctx), output, inputs); return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs);
} }
template <class T, class F> template <class T, class F>
...@@ -563,59 +646,6 @@ lifetime get_lifetime_op(const T&) ...@@ -563,59 +646,6 @@ lifetime get_lifetime_op(const T&)
return not(x == y); return not(x == y);
} }
// used for dynamic operators
struct dyn_output
{
// original instruction output shape
shape ins_shape;
std::function<shape()> compute_shape;
shape get_output_shape()
{
if(output_shape.element_space() == 0)
{
output_shape = compute_shape();
}
return output_shape;
}
private:
// shape computed at eval time using input arguments
shape output_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape
* If converted to dyn_output type, will compute an output shape using the input arguments
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return unpack(
[](const auto& x, shape ins_shape, const std::vector<argument>& args) {
return dyn_output{ins_shape, [&]() { compute_shape(x, to_shapes(args)); }};
},
ins_inputs);
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape ins_shape, const std::vector<argument>& input)
{
return compute_output_shape{pack(x, ins_shape, input)};
}
inline value inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input) compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment