"vscode:/vscode.git/clone" did not exist on "ad0ab35784cb456548d589a77449df3b89a999cf"
Unverified Commit 1af75182 authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #255 from ROCmSoftwarePlatform/eval-check

Fix slowness in eval
parents 24d68767 f006b0a9
......@@ -72,7 +72,9 @@ struct instruction
static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
argument eval() const;
bool can_eval() const;
argument eval(bool check_eval = true) const;
void finalize(context& ctx);
......
......@@ -162,7 +162,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this);
}
argument instruction::eval() const
bool instruction::can_eval() const
{
if(op.name() == "@literal")
{
return true;
}
else if(is_context_free(op))
{
return std::all_of(
this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
}
else
{
return false;
}
}
argument instruction::eval(bool check_eval) const
{
if(op.name() == "@literal")
{
......@@ -170,14 +187,13 @@ argument instruction::eval() const
}
if(is_context_free(op))
{
std::vector<argument> args;
for(auto&& arg : this->inputs())
{
argument a = arg->eval();
if(a.empty())
if(check_eval and not this->can_eval())
return {};
args.push_back(a);
}
std::vector<argument> args;
std::transform(this->inputs().begin(),
this->inputs().end(),
std::back_inserter(args),
[](auto arg) { return arg->eval(false); });
return op.compute(result, args);
}
return {};
......
......@@ -22,22 +22,32 @@ bool skip_propogate(instruction_ref ins)
void propagate_constant::apply(program& p) const
{
for(auto i : iterator_for(p))
{
if(i->name() != "@literal")
continue;
if(i->outputs().empty())
continue;
fix([&](auto self, auto ins) {
if(not skip_propogate(ins))
std::unordered_set<instruction_ref> children(ins->outputs().begin(),
ins->outputs().end());
for(auto child : children)
{
auto r = ins->eval();
if(skip_propogate(child))
{
self(child);
continue;
}
auto r = child->eval();
if(not r.empty())
{
assert(r.get_shape() == ins->get_shape());
assert(r.get_shape() == child->get_shape());
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(ins, l);
return;
self(p.replace_instruction(child, l));
}
}
std::unordered_set<instruction_ref> children(ins->inputs().begin(), ins->inputs().end());
for(auto child : children)
self(child);
})(std::prev(p.end()));
})(i);
}
}
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment