Commit f1c18355 authored by charlie's avatar charlie
Browse files

Fixed using pack() correctly

parent b76a9043
......@@ -47,6 +47,45 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context;
/*
template <class F>
struct dyn_output
{
F ins_inputs;
dyn_output(F f) : ins_inputs(f){};
shape get_input_shape()
{
if(ins_shape.empty())
{
ins_shape = unpack(
[&](const auto&, shape s, const std::vector<argument>&) { return s; }, ins_inputs);
}
return ins_shape;
}
shape get_output_shape()
{
if(computed_shape.empty())
{
computed_shape = unpack(
[&](const auto& x, shape, const std::vector<argument>& inputs) {
return compute_shape(x, to_shapes(inputs));
},
ins_inputs);
}
return computed_shape;
}
private:
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
*/
struct dyn_output
{
// original shape from the instruction
......@@ -67,26 +106,22 @@ struct compute_output_shape
operator dyn_output() const
{
return unpack(
[](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
},
ins_inputs);
return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
});
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape output_shape, const std::vector<argument>& inputs)
-> decltype(compute_output_shape{pack(x, output_shape, inputs)})
template <class F>
compute_output_shape<F> make_compute_output_shape(F f)
{
return compute_output_shape{pack(x, output_shape, inputs)};
return {f};
}
#ifdef DOXYGEN
......@@ -243,10 +278,11 @@ auto compute_op(rank<1>,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(x, output_shape, input),
make_compute_output_shape(pack(x, output_shape, input)),
input))
{
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output_shape, input), input);
return x.compute(
auto_any_cast(ctx), make_compute_output_shape(pack(x, output_shape, input)), input);
}
template <class T>
......@@ -265,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(make_compute_output_shape(x, output_shape, input), input))
-> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input))
{
return x.compute(make_compute_output_shape(x, output_shape, input), input);
return x.compute(make_compute_output_shape(pack(x, output_shape, input)), input);
}
template <class T>
......@@ -290,9 +326,10 @@ auto compute_op(rank<1>,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
-> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
}
template <class T, class F>
......@@ -324,12 +361,17 @@ auto compute_op(rank<4>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
-> decltype(x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f))
F f) -> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(pack(x, output, inputs)),
inputs,
module_args,
f))
{
return x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f);
return x.compute(auto_any_cast(ctx),
make_compute_output_shape(pack(x, output, inputs)),
inputs,
module_args,
f);
}
template <class T, class F>
......@@ -340,9 +382,10 @@ auto compute_op(rank<3>,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
-> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
}
template <class T, class F>
......@@ -352,9 +395,10 @@ auto compute_op(rank<2>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs))
F)
-> decltype(x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs))
{
return x.compute(make_compute_output_shape(x, output, inputs), inputs);
return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs);
}
template <class T, class F>
......@@ -364,10 +408,12 @@ auto compute_op(rank<1>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F)
-> decltype(x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs))
F) -> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(pack(x, output, inputs)),
inputs))
{
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs);
return x.compute(
auto_any_cast(ctx), make_compute_output_shape(pack(x, output, inputs)), inputs);
}
template <class T, class F>
......
......@@ -47,6 +47,8 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context;
/*
template <class F>
struct dyn_output
{
F ins_inputs;
......@@ -82,6 +84,15 @@ struct dyn_output
// shape computed at eval time using input arguments
shape computed_shape;
};
*/
struct dyn_output
{
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
......@@ -93,32 +104,24 @@ struct compute_output_shape
{
F ins_inputs;
operator dyn_output<F>() const
operator dyn_output() const
{
/*
return unpack([](const auto& x, shape ins_shape, const std::vector<argument>& inputs)
{
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
},
ins_inputs
);
*/
return dyn_output<F>{ins_inputs};
return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
});
}
operator shape() const
{
return unpack(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; },
ins_inputs);
return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
}
};
template <class T>
auto make_compute_output_shape(const T& x, shape output_shape, const std::vector<argument>& inputs)
-> decltype(compute_output_shape{pack(x, output_shape, inputs)})
template <class F>
compute_output_shape<F> make_compute_output_shape(F f)
{
return compute_output_shape{pack(x, output_shape, inputs)};
return {f};
}
#ifdef DOXYGEN
......@@ -275,10 +278,11 @@ auto compute_op(rank<1>,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(x, output_shape, input),
make_compute_output_shape(pack(x, output_shape, input)),
input))
{
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output_shape, input), input);
return x.compute(
auto_any_cast(ctx), make_compute_output_shape(pack(x, output_shape, input)), input);
}
template <class T>
......@@ -297,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(make_compute_output_shape(x, output_shape, input), input))
-> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input))
{
return x.compute(make_compute_output_shape(x, output_shape, input), input);
return x.compute(make_compute_output_shape(pack(x, output_shape, input)), input);
}
template <class T>
......@@ -322,9 +326,10 @@ auto compute_op(rank<1>,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
-> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
}
template <class T, class F>
......@@ -356,12 +361,17 @@ auto compute_op(rank<4>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
-> decltype(x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f))
F f) -> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(pack(x, output, inputs)),
inputs,
module_args,
f))
{
return x.compute(
auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs, module_args, f);
return x.compute(auto_any_cast(ctx),
make_compute_output_shape(pack(x, output, inputs)),
inputs,
module_args,
f);
}
template <class T, class F>
......@@ -372,9 +382,10 @@ auto compute_op(rank<3>,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
-> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f))
-> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{
return x.compute(make_compute_output_shape(x, output, inputs), inputs, module_args, f);
return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
}
template <class T, class F>
......@@ -384,9 +395,10 @@ auto compute_op(rank<2>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(make_compute_output_shape(x, output, inputs), inputs))
F)
-> decltype(x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs))
{
return x.compute(make_compute_output_shape(x, output, inputs), inputs);
return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs);
}
template <class T, class F>
......@@ -396,10 +408,12 @@ auto compute_op(rank<1>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F)
-> decltype(x.compute(auto_any_cast(ctx), compute_output_shape<T>{x, output, inputs}, inputs))
F) -> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(pack(x, output, inputs)),
inputs))
{
return x.compute(auto_any_cast(ctx), make_compute_output_shape(x, output, inputs), inputs);
return x.compute(
auto_any_cast(ctx), make_compute_output_shape(pack(x, output, inputs)), inputs);
}
template <class T, class F>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment