Commit 6958e2bf authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Update lowering and reshape to work for reshape_copy

Make reshape work like contiguous to perform the copy and then add
proper aliasing in lowering if we're unable to perform a replace instruction
parent d0174a6c
...@@ -46,8 +46,6 @@ struct reshape ...@@ -46,8 +46,6 @@ struct reshape
return pack(f(self.dims, "dims")); return pack(f(self.dims, "dims"));
} }
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape dyn_compute_shape(shape s0) const shape dyn_compute_shape(shape s0) const
...@@ -133,18 +131,16 @@ struct reshape ...@@ -133,18 +131,16 @@ struct reshape
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); auto s0 = inputs.front();
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic()) if(s0.dynamic())
{ {
return dyn_compute_shape(s0); return s0;
} }
else else
{ {
return static_compute_shape(inputs, n_neg_dims); auto t = s0.type();
} return {t, dims};
}
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
...@@ -152,8 +148,6 @@ struct reshape ...@@ -152,8 +148,6 @@ struct reshape
assert(dyn_out.computed_shape.standard()); assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape}; argument result{dyn_out.computed_shape};
//auto resh = args[0](dyn_out.computed_shape);
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()); output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
...@@ -162,7 +156,6 @@ struct reshape ...@@ -162,7 +156,6 @@ struct reshape
return result; return result;
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
} // namespace op } // namespace op
......
...@@ -90,7 +90,6 @@ struct miopen_apply ...@@ -90,7 +90,6 @@ struct miopen_apply
offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false; offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
add_generic_op("contiguous"); add_generic_op("contiguous");
add_generic_op("reshape_lazy");
add_extend_op("argmax"); add_extend_op("argmax");
add_extend_op("argmin"); add_extend_op("argmin");
add_extend_op("logsoftmax"); add_extend_op("logsoftmax");
...@@ -116,7 +115,7 @@ struct miopen_apply ...@@ -116,7 +115,7 @@ struct miopen_apply
add_neg_op(); add_neg_op();
add_nms_op(); add_nms_op();
add_select_module_op(); add_select_module_op();
//add_reshape_lazy_op(); add_reshape_lazy_op();
} }
void copy_params() const void copy_params() const
...@@ -382,11 +381,11 @@ struct miopen_apply ...@@ -382,11 +381,11 @@ struct miopen_apply
/** /**
* Adds reshape lazy to reshape ops that can be aliased instead of copied * Adds reshape lazy to reshape ops that can be aliased instead of copied
*/ */
/*void add_reshape_lazy_op() void add_reshape_lazy_op()
{ {
apply_map.emplace("reshape", [=](instruction_ref ins) { apply_map.emplace("reshape", [=](instruction_ref ins) {
/* Attempt lazy reshape to allow for aliasing. Potentially throws in get_shape if unable /* Attempt lazy reshape to allow for aliasing. Potentially throws in get_shape if unable
* to alias * to alias */
try try
{ {
auto lazy_ins = mod->replace_instruction( auto lazy_ins = mod->replace_instruction(
...@@ -398,10 +397,14 @@ struct miopen_apply ...@@ -398,10 +397,14 @@ struct miopen_apply
} }
catch(...) catch(...)
{ {
return ins; auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
refs.push_back(output);
return mod->replace_instruction(ins, make_op("gpu::contiguous"), refs);
} }
}); });
}*/ }
}; };
void lowering::apply(module_pass_manager& mpm) const void lowering::apply(module_pass_manager& mpm) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment