Commit da6f9c3e authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add copy in reshape to make reshape by default perform a copy

parent 7809b341
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp> #include <migraphx/optional.hpp>
#include <migraphx/shape_for_each.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -148,7 +149,17 @@ struct reshape ...@@ -148,7 +149,17 @@ struct reshape
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(dyn_out.computed_shape); assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape};
auto resh = args[0].reshape_lazy(dyn_out.computed_shape);
visit_all(result, resh)([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
});
});
return result;
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment