"driver/driver.cpp" did not exist on "39775d484c4d15a5b895edfc9d2323f05ab2d3d4"
Commit 5696ac5f authored by Brian Pickrell's avatar Brian Pickrell
Browse files

work in progress; code builds but incomplete

parent 263f1b71
...@@ -56,7 +56,7 @@ struct argmax ...@@ -56,7 +56,7 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
lens[axis] = 1; lens[axis] = 1;
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/op/name.hpp> #include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/tensor_view.hpp> #include <migraphx/tensor_view.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
...@@ -107,7 +108,7 @@ struct reduce_op : op_name<Derived> ...@@ -107,7 +108,7 @@ struct reduce_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
auto lens = s.lens(); auto lens = s.lens();
auto tuned_axes = tune_axes(lens.size()); auto tuned_axes = tune_axes(lens.size());
...@@ -126,6 +127,7 @@ struct reduce_op : op_name<Derived> ...@@ -126,6 +127,7 @@ struct reduce_op : op_name<Derived>
{ {
for(auto axis : tuned_axes) for(auto axis : tuned_axes)
{ {
// todo: how to change for dynamic shapes?
out_lens[axis] = in_lens[axis]; out_lens[axis] = in_lens[axis];
} }
} }
...@@ -151,17 +153,17 @@ struct reduce_op : op_name<Derived> ...@@ -151,17 +153,17 @@ struct reduce_op : op_name<Derived>
static_cast<const Derived&>(*this).output(batch_shape)(val); static_cast<const Derived&>(*this).output(batch_shape)(val);
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ { // todo: what should be different about the computation, if anything?
argument result{output_shape}; argument result{dyn_out.computed_shape};
auto arg_lens = args.front().get_shape().lens(); auto arg_lens = args.front().get_shape().lens();
auto tuned_axes = tune_axes(arg_lens.size()); auto tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1); std::vector<std::size_t> batch_lens(dyn_out.computed_shape.lens().size(), 1);
tune_dims(tuned_axes, arg_lens, batch_lens); tune_dims(tuned_axes, arg_lens, batch_lens);
shape batch_shape{output_shape.type(), batch_lens}; shape batch_shape{dyn_out.computed_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(), [&](auto i) { par_for(dyn_out.computed_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i); auto out_idx = dyn_out.computed_shape.multi(i);
this->reduce(input, batch_shape, tuned_axes, out_idx, output); this->reduce(input, batch_shape, tuned_axes, out_idx, output);
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment