Commit 2cf7ae45 authored by charlie's avatar charlie
Browse files

Fix eliminate_contiguous pass

parent 91e3efee
......@@ -140,17 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
}
}
// Perform evaluations in parallel
// Perform static contiguous evaluations in parallel
std::vector<argument> literals(const_instructions.size());
par_for(const_instructions.size(), 1, [&](const auto i) {
auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front();
std::vector<shape> prev_shape = {prev->get_shape()};
auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front();
// compute the output contiguous shape from the previous instruction shape
shape computed_shape = c.compute_shape({prev->get_shape()});
const std::vector<argument>& prev_eval = {prev->eval()};
auto co_shape = make_compute_output_shape(pack(c, prev_shape, prev_eval));
literals[i] = c.compute(co_shape, {prev->eval()});
// prev_eval should not be used in make_compute_output_shape() as computed_shape is static
auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval));
literals[i] = c.compute(co_shape, prev_eval);
});
// Replace static contiguous operations with a literal
for(size_t i = 0; i < const_instructions.size(); i++)
{
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment