Commit 5696ac5f authored by Brian Pickrell's avatar Brian Pickrell
Browse files

work in progress; code builds but incomplete

parent 263f1b71
......@@ -56,7 +56,7 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
auto lens = inputs[0].lens();
lens[axis] = 1;
......
......@@ -26,6 +26,7 @@
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/shape_for_each.hpp>
......@@ -107,7 +108,7 @@ struct reduce_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
auto tuned_axes = tune_axes(lens.size());
......@@ -126,6 +127,7 @@ struct reduce_op : op_name<Derived>
{
for(auto axis : tuned_axes)
{
// todo: how to change for dynamic shapes?
out_lens[axis] = in_lens[axis];
}
}
......@@ -151,17 +153,17 @@ struct reduce_op : op_name<Derived>
static_cast<const Derived&>(*this).output(batch_shape)(val);
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ // todo: what should be different about the computation, if anything?
argument result{dyn_out.computed_shape};
auto arg_lens = args.front().get_shape().lens();
auto tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
std::vector<std::size_t> batch_lens(dyn_out.computed_shape.lens().size(), 1);
tune_dims(tuned_axes, arg_lens, batch_lens);
shape batch_shape{output_shape.type(), batch_lens};
shape batch_shape{dyn_out.computed_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i);
par_for(dyn_out.computed_shape.elements(), [&](auto i) {
auto out_idx = dyn_out.computed_shape.multi(i);
this->reduce(input, batch_shape, tuned_axes, out_idx, output);
});
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment