"vscode:/vscode.git/clone" did not exist on "330e1e18ae53ae17a826a7f6e71095386931e8cb"
Commit f1c18355 authored by charlie's avatar charlie
Browse files

Fixed using pack() correctly

parent b76a9043
...@@ -47,6 +47,45 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -47,6 +47,45 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context; struct context;
/*
template <class F>
struct dyn_output
{
F ins_inputs;
dyn_output(F f) : ins_inputs(f){};
shape get_input_shape()
{
if(ins_shape.empty())
{
ins_shape = unpack(
[&](const auto&, shape s, const std::vector<argument>&) { return s; }, ins_inputs);
}
return ins_shape;
}
shape get_output_shape()
{
if(computed_shape.empty())
{
computed_shape = unpack(
[&](const auto& x, shape, const std::vector<argument>& inputs) {
return compute_shape(x, to_shapes(inputs));
},
ins_inputs);
}
return computed_shape;
}
private:
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
*/
struct dyn_output struct dyn_output
{ {
// original shape from the instruction // original shape from the instruction
...@@ -67,26 +106,22 @@ struct compute_output_shape ...@@ -67,26 +106,22 @@ struct compute_output_shape
operator dyn_output() const operator dyn_output() const
{ {
return unpack( return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
[](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))}; return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
}, });
ins_inputs);
} }
operator shape() const operator shape() const
{ {
return unpack( return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; }, [](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
ins_inputs);
} }
}; };
template <class T> template <class F>
auto make_compute_output_shape(const T& x, shape output_shape, const std::vector<argument>& inputs) compute_output_shape<F> make_compute_output_shape(F f)
-> decltype(compute_output_shape{pack(x, output_shape, inputs)})
{ {
return compute_output_shape{pack(x, output_shape, inputs)}; return {f};
} }
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -243,10 +278,11 @@ auto compute_op(rank<1>, ...@@ -243,10 +278,11 @@ auto compute_op(rank<1>,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& input) const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), -> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(x, output_shape, input), make_compute_output_shape(pack(x, output_shape, input)),
input)) input))
{ {
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output_shape, input), input); return x.compute(
auto_any_cast(ctx), make_compute_output_shape(pack(x, output_shape, input)), input);
} }
template <class T> template <class T>
...@@ -265,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -265,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template <class T> template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input) auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(make_compute_output_shape(x, output_shape, input), input)) -> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input))
{ {
return x.compute(make_compute_output_shape(x, output_shape, input), input); return x.compute(make_compute_output_shape(pack(x, output_shape, input)), input);
} }
template <class T> template <class T>
...@@ -290,9 +326,10 @@ auto compute_op(rank<1>, ...@@ -290,9 +326,10 @@ auto compute_op(rank<1>,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f)) -> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{ {
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f); return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
...@@ -324,12 +361,17 @@ auto compute_op(rank<4>, ...@@ -324,12 +361,17 @@ auto compute_op(rank<4>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) F f) -> decltype(x.compute(auto_any_cast(ctx),
-> decltype(x.compute( make_compute_output_shape(pack(x, output, inputs)),
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f)) inputs,
module_args,
f))
{ {
return x.compute( return x.compute(auto_any_cast(ctx),
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f); make_compute_output_shape(pack(x, output, inputs)),
inputs,
module_args,
f);
} }
template <class T, class F> template <class T, class F>
...@@ -340,9 +382,10 @@ auto compute_op(rank<3>, ...@@ -340,9 +382,10 @@ auto compute_op(rank<3>,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f)) -> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{ {
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f); return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
...@@ -352,9 +395,10 @@ auto compute_op(rank<2>, ...@@ -352,9 +395,10 @@ auto compute_op(rank<2>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>&,
F) -> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs)) F)
-> decltype(x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs))
{ {
return x.compute(make_compute_output_shape(x, output, inputs), inputs); return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs);
} }
template <class T, class F> template <class T, class F>
...@@ -364,10 +408,12 @@ auto compute_op(rank<1>, ...@@ -364,10 +408,12 @@ auto compute_op(rank<1>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>&,
F) F) -> decltype(x.compute(auto_any_cast(ctx),
-> decltype(x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs)) make_compute_output_shape(pack(x, output, inputs)),
inputs))
{ {
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs); return x.compute(
auto_any_cast(ctx), make_compute_output_shape(pack(x, output, inputs)), inputs);
} }
template <class T, class F> template <class T, class F>
......
...@@ -47,6 +47,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -47,6 +47,8 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context; struct context;
/*
template <class F>
struct dyn_output struct dyn_output
{ {
F ins_inputs; F ins_inputs;
...@@ -82,6 +84,15 @@ struct dyn_output ...@@ -82,6 +84,15 @@ struct dyn_output
// shape computed at eval time using input arguments // shape computed at eval time using input arguments
shape computed_shape; shape computed_shape;
}; };
*/
struct dyn_output
{
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
/** /**
* Handle dynamic and static shape at evaluation time. * Handle dynamic and static shape at evaluation time.
...@@ -93,32 +104,24 @@ struct compute_output_shape ...@@ -93,32 +104,24 @@ struct compute_output_shape
{ {
F ins_inputs; F ins_inputs;
operator dyn_output<F>() const operator dyn_output() const
{
/*
return unpack([](const auto& x, shape ins_shape, const std::vector<argument>& inputs)
{ {
return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))}; return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
}, });
ins_inputs
);
*/
return dyn_output<F>{ins_inputs};
} }
operator shape() const operator shape() const
{ {
return unpack( return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; }, [](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
ins_inputs);
} }
}; };
template <class T> template <class F>
auto make_compute_output_shape(const T& x, shape output_shape, const std::vector<argument>& inputs) compute_output_shape<F> make_compute_output_shape(F f)
-> decltype(compute_output_shape{pack(x, output_shape, inputs)})
{ {
return compute_output_shape{pack(x, output_shape, inputs)}; return {f};
} }
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -275,10 +278,11 @@ auto compute_op(rank<1>, ...@@ -275,10 +278,11 @@ auto compute_op(rank<1>,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& input) const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), -> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(x, output_shape, input), make_compute_output_shape(pack(x, output_shape, input)),
input)) input))
{ {
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output_shape, input), input); return x.compute(
auto_any_cast(ctx), make_compute_output_shape(pack(x, output_shape, input)), input);
} }
template <class T> template <class T>
...@@ -297,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -297,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template <class T> template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input) auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(make_compute_output_shape(x, output_shape, input), input)) -> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input))
{ {
return x.compute(make_compute_output_shape(x, output_shape, input), input); return x.compute(make_compute_output_shape(pack(x, output_shape, input)), input);
} }
template <class T> template <class T>
...@@ -322,9 +326,10 @@ auto compute_op(rank<1>, ...@@ -322,9 +326,10 @@ auto compute_op(rank<1>,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f)) -> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{ {
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f); return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
...@@ -356,12 +361,17 @@ auto compute_op(rank<4>, ...@@ -356,12 +361,17 @@ auto compute_op(rank<4>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) F f) -> decltype(x.compute(auto_any_cast(ctx),
-> decltype(x.compute( make_compute_output_shape(pack(x, output, inputs)),
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f)) inputs,
module_args,
f))
{ {
return x.compute( return x.compute(auto_any_cast(ctx),
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f); make_compute_output_shape(pack(x, output, inputs)),
inputs,
module_args,
f);
} }
template <class T, class F> template <class T, class F>
...@@ -372,9 +382,10 @@ auto compute_op(rank<3>, ...@@ -372,9 +382,10 @@ auto compute_op(rank<3>,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f)) -> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{ {
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f); return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
...@@ -384,9 +395,10 @@ auto compute_op(rank<2>, ...@@ -384,9 +395,10 @@ auto compute_op(rank<2>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>&,
F) -> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs)) F)
-> decltype(x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs))
{ {
return x.compute(make_compute_output_shape(x, output, inputs), inputs); return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs);
} }
template <class T, class F> template <class T, class F>
...@@ -396,10 +408,12 @@ auto compute_op(rank<1>, ...@@ -396,10 +408,12 @@ auto compute_op(rank<1>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>&,
F) F) -> decltype(x.compute(auto_any_cast(ctx),
-> decltype(x.compute(auto_any_cast(ctx), compute_output_shape<T>{x, output, inputs}, inputs)) make_compute_output_shape(pack(x, output, inputs)),
inputs))
{ {
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs); return x.compute(
auto_any_cast(ctx), make_compute_output_shape(pack(x, output, inputs)), inputs);
} }
template <class T, class F> template <class T, class F>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment