Commit 34e90169 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix the issue of converting standard shape to non-standard shape after inserting contiguous

parent db439b30
...@@ -16,7 +16,7 @@ void auto_contiguous::apply(module& p) const ...@@ -16,7 +16,7 @@ void auto_contiguous::apply(module& p) const
if(not s.standard() and s.elements() != 0) if(not s.standard() and s.elements() != 0)
{ {
auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins); auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins);
p.replace_instruction(ins, c); p.replace_instruction(ins, c, true);
} }
} }
} }
......
...@@ -36,7 +36,7 @@ struct instruction ...@@ -36,7 +36,7 @@ struct instruction
void replace(operation o); void replace(operation o);
void recompute_shape(); void recompute_shape(bool non_std_stop = false);
void clear_arguments(); void clear_arguments();
...@@ -83,7 +83,7 @@ struct instruction ...@@ -83,7 +83,7 @@ struct instruction
static void backreference(instruction_ref ref); static void backreference(instruction_ref ref);
static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins); static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins, bool stop = false);
static void replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod); static void replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod);
...@@ -139,7 +139,7 @@ struct instruction ...@@ -139,7 +139,7 @@ struct instruction
// internal // internal
void replace_mod_argument(module_ref old, module_ref new_mod); void replace_mod_argument(module_ref old, module_ref new_mod);
void replace(const shape& r); void replace(const shape& r, bool stop = false);
operation op; operation op;
shape result{}; shape result{};
......
...@@ -88,7 +88,7 @@ struct module ...@@ -88,7 +88,7 @@ struct module
std::vector<instruction_ref> args, std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST; std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST;
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep); instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep, bool stop = false);
instruction_ref remove_instruction(instruction_ref ins); instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last); instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
......
...@@ -34,18 +34,20 @@ instruction::instruction(literal l) ...@@ -34,18 +34,20 @@ instruction::instruction(literal l)
{ {
} }
void instruction::replace(const shape& r) void instruction::replace(const shape& r, bool stop)
{ {
if(r != result) if(r != result)
{ {
result = r; result = r;
if(stop and not r.standard()) return;
for(auto&& ins : output) for(auto&& ins : output)
{ {
if(ins->name() == "@return") if(ins->name() == "@return")
continue; continue;
assert(ins->name().front() != '@'); assert(ins->name().front() != '@');
ins->recompute_shape(); ins->recompute_shape(stop);
} }
} }
} }
...@@ -57,7 +59,10 @@ void instruction::replace(operation o) ...@@ -57,7 +59,10 @@ void instruction::replace(operation o)
recompute_shape(); recompute_shape();
} }
void instruction::recompute_shape() { replace(compute_shape(op, arguments, module_args)); } void instruction::recompute_shape(bool non_std_stop)
{
replace(compute_shape(op, arguments, module_args), non_std_stop);
}
void instruction::clear_arguments() void instruction::clear_arguments()
{ {
...@@ -174,11 +179,12 @@ void instruction::backreference(instruction_ref ref) ...@@ -174,11 +179,12 @@ void instruction::backreference(instruction_ref ref)
void instruction::replace_argument(instruction_ref ins, void instruction::replace_argument(instruction_ref ins,
instruction_ref old, instruction_ref old,
instruction_ref new_ins) instruction_ref new_ins,
bool stop)
{ {
ins->replace_argument(old, new_ins); ins->replace_argument(old, new_ins);
backreference(ins); backreference(ins);
ins->recompute_shape(); ins->recompute_shape(stop);
} }
void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod) void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod)
......
...@@ -232,7 +232,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, ...@@ -232,7 +232,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
return ins; return ins;
} }
instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep) instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep, bool stop)
{ {
assert(has_instruction(ins)); assert(has_instruction(ins));
assert(has_instruction(rep)); assert(has_instruction(rep));
...@@ -255,7 +255,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref ...@@ -255,7 +255,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
// TODO: Check for possible cycles // TODO: Check for possible cycles
if(out != rep) if(out != rep)
{ {
instruction::replace_argument(out, ins, rep); instruction::replace_argument(out, ins, rep, stop);
} }
assert(out->valid(begin())); assert(out->valid(begin()));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment