Commit 34e90169 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix the issue of converting standard shape to non-standard shape after inserting contiguous

parent db439b30
......@@ -16,7 +16,7 @@ void auto_contiguous::apply(module& p) const
if(not s.standard() and s.elements() != 0)
{
auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins);
p.replace_instruction(ins, c);
p.replace_instruction(ins, c, true);
}
}
}
......
......@@ -36,7 +36,7 @@ struct instruction
void replace(operation o);
void recompute_shape();
void recompute_shape(bool non_std_stop = false);
void clear_arguments();
......@@ -83,7 +83,7 @@ struct instruction
static void backreference(instruction_ref ref);
static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins);
static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins, bool stop = false);
static void replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod);
......@@ -139,7 +139,7 @@ struct instruction
// internal
void replace_mod_argument(module_ref old, module_ref new_mod);
void replace(const shape& r);
void replace(const shape& r, bool stop = false);
operation op;
shape result{};
......
......@@ -88,7 +88,7 @@ struct module
std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST;
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep, bool stop = false);
instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
......
......@@ -34,18 +34,20 @@ instruction::instruction(literal l)
{
}
void instruction::replace(const shape& r)
void instruction::replace(const shape& r, bool stop)
{
if(r != result)
{
result = r;
if(stop and not r.standard()) return;
for(auto&& ins : output)
{
if(ins->name() == "@return")
continue;
assert(ins->name().front() != '@');
ins->recompute_shape();
ins->recompute_shape(stop);
}
}
}
......@@ -57,7 +59,10 @@ void instruction::replace(operation o)
recompute_shape();
}
void instruction::recompute_shape() { replace(compute_shape(op, arguments, module_args)); }
void instruction::recompute_shape(bool non_std_stop)
{
replace(compute_shape(op, arguments, module_args), non_std_stop);
}
void instruction::clear_arguments()
{
......@@ -174,11 +179,12 @@ void instruction::backreference(instruction_ref ref)
void instruction::replace_argument(instruction_ref ins,
instruction_ref old,
instruction_ref new_ins)
instruction_ref new_ins,
bool stop)
{
ins->replace_argument(old, new_ins);
backreference(ins);
ins->recompute_shape();
ins->recompute_shape(stop);
}
void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod)
......
......@@ -232,7 +232,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
return ins;
}
instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep)
instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep, bool stop)
{
assert(has_instruction(ins));
assert(has_instruction(rep));
......@@ -255,7 +255,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
// TODO: Check for possible cycles
if(out != rep)
{
instruction::replace_argument(out, ins, rep);
instruction::replace_argument(out, ins, rep, stop);
}
assert(out->valid(begin()));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment