Commit 68c17b1b authored by charlie's avatar charlie
Browse files

Still broken, figuring things out

parent f02f5d98
......@@ -32,6 +32,7 @@
#include <utility>
#include <unordered_map>
#include <migraphx/reflect.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp>
......@@ -46,6 +47,48 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context;
struct dyn_output
{
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return unpack(
[](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
},
ins_inputs);
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape output_shape, const std::vector<argument>& inputs)
-> decltype(compute_output_shape{pack(x, output_shape, inputs)})
{
return compute_output_shape{pack(x, output_shape, inputs)};
}
#ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All
......@@ -199,9 +242,11 @@ auto compute_op(rank<1>,
context& ctx,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
-> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(x, output_shape, input),
input))
{
return x.compute(auto_any_cast(ctx), output_shape, input);
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output_shape, input), input);
}
template <class T>
......@@ -220,7 +265,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
-> decltype(x.compute(make_compute_output_shape(x, output_shape, input), input))
{
return x.compute(make_compute_output_shape(x, output_shape, input), input);
}
......@@ -244,9 +289,10 @@ auto compute_op(rank<1>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
}
template <class T, class F>
......@@ -278,9 +324,12 @@ auto compute_op(rank<4>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
F f)
-> decltype(x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{
return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
return x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f);
}
template <class T, class F>
......@@ -290,9 +339,10 @@ auto compute_op(rank<3>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
}
template <class T, class F>
......@@ -302,9 +352,9 @@ auto compute_op(rank<2>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs))
F) -> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs))
{
return x.compute(output, inputs);
return x.compute(make_compute_output_shape(x, output, inputs), inputs);
}
template <class T, class F>
......@@ -314,9 +364,10 @@ auto compute_op(rank<1>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
F)
-> decltype(x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs))
{
return x.compute(auto_any_cast(ctx), output, inputs);
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs);
}
template <class T, class F>
......@@ -1278,59 +1329,6 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return not(x == y); }
// used for dynamic operators
struct dyn_output
{
// original instruction output shape
shape ins_shape;
std::function<shape()> compute_shape;
shape get_output_shape()
{
if(output_shape.element_space() == 0)
{
output_shape = compute_shape();
}
return output_shape;
}
private:
// shape computed at eval time using input arguments
shape output_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape
* If converted to dyn_output type, will compute an output shape using the input arguments
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return unpack(
[](const auto& x, shape ins_shape, const std::vector<argument>& args) {
return dyn_output{ins_shape, [&]() { compute_shape(x, to_shapes(args)); }};
},
ins_inputs);
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape ins_shape, const std::vector<argument>& input)
{
return compute_output_shape{pack(x, ins_shape, input)};
}
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
......
......@@ -214,6 +214,9 @@ struct shape
/// Return true if the shape is dynamic
bool dynamic() const;
/// Returns true if the shape is empty
bool empty() const;
shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
......
......@@ -443,6 +443,8 @@ std::string shape::type_string() const { return name(this->type()); }
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
bool shape::empty() const { return max_lens().empty(); }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
std::vector<std::size_t> shape::min_lens() const
......
......@@ -79,8 +79,9 @@ struct scatternd_compiler : compiler<scatternd_compiler>
{
assert(starts_with(op.name(), "scatternd_"));
auto reduction = op.name().substr(10);
return insert(compile_op(ctx,
to_shapes({ins->inputs().begin() + 1, ins->inputs().end()}),
return insert(compile_op(
ctx,
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()}),
{{"reduction", reduction}}));
}
......
......@@ -252,7 +252,7 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
else
{
padding = op.padding;
output_shape = dyn_output.get_output_shape();
output_shape = dyn_output.computed_shape;
}
argument result{output_shape};
......
......@@ -32,6 +32,7 @@
#include <utility>
#include <unordered_map>
#include <migraphx/reflect.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp>
......@@ -46,6 +47,80 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context;
struct dyn_output
{
F ins_inputs;
dyn_output(F f) : ins_inputs(f){};
shape get_input_shape()
{
if(ins_shape.empty())
{
ins_shape = unpack(
[&](const auto&, shape s, const std::vector<argument>&) { return s; }, ins_inputs);
}
return ins_shape;
}
shape get_output_shape()
{
if(computed_shape.empty())
{
computed_shape = unpack(
[&](const auto& x, shape, const std::vector<argument>& inputs) {
return compute_shape(x, to_shapes(inputs));
},
ins_inputs);
}
return computed_shape;
}
private:
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output<F>() const
{
/*
return unpack([](const auto& x, shape ins_shape, const std::vector<argument>& inputs)
{
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
},
ins_inputs
);
*/
return dyn_output<F>{ins_inputs};
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape output_shape, const std::vector<argument>& inputs)
-> decltype(compute_output_shape{pack(x, output_shape, inputs)})
{
return compute_output_shape{pack(x, output_shape, inputs)};
}
#ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All
......@@ -199,9 +274,11 @@ auto compute_op(rank<1>,
context& ctx,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
-> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(x, output_shape, input),
input))
{
return x.compute(auto_any_cast(ctx), output_shape, input);
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output_shape, input), input);
}
template <class T>
......@@ -220,7 +297,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
-> decltype(x.compute(make_compute_output_shape(x, output_shape, input), input))
{
return x.compute(make_compute_output_shape(x, output_shape, input), input);
}
......@@ -244,9 +321,10 @@ auto compute_op(rank<1>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
}
template <class T, class F>
......@@ -278,9 +356,12 @@ auto compute_op(rank<4>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
F f)
-> decltype(x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{
return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
return x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f);
}
template <class T, class F>
......@@ -290,9 +371,10 @@ auto compute_op(rank<3>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
}
template <class T, class F>
......@@ -302,9 +384,9 @@ auto compute_op(rank<2>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs))
F) -> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs))
{
return x.compute(output, inputs);
return x.compute(make_compute_output_shape(x, output, inputs), inputs);
}
template <class T, class F>
......@@ -314,9 +396,10 @@ auto compute_op(rank<1>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
F)
-> decltype(x.compute(auto_any_cast(ctx), compute_output_shape<T>{x, output, inputs}, inputs))
{
return x.compute(auto_any_cast(ctx), output, inputs);
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs);
}
template <class T, class F>
......@@ -563,59 +646,6 @@ lifetime get_lifetime_op(const T&)
return not(x == y);
}
// used for dynamic operators
struct dyn_output
{
// original instruction output shape
shape ins_shape;
std::function<shape()> compute_shape;
shape get_output_shape()
{
if(output_shape.element_space() == 0)
{
output_shape = compute_shape();
}
return output_shape;
}
private:
// shape computed at eval time using input arguments
shape output_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape
* If converted to dyn_output type, will compute an output shape using the input arguments
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
return unpack(
[](const auto& x, shape ins_shape, const std::vector<argument>& args) {
return dyn_output{ins_shape, [&]() { compute_shape(x, to_shapes(args)); }};
},
ins_inputs);
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape ins_shape, const std::vector<argument>& input)
{
return compute_output_shape{pack(x, ins_shape, input)};
}
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment