Commit 2cf7ae45 authored by charlie
Browse files

Fix eliminate_contiguous pass

parent 91e3efee
...@@ -140,17 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f) ...@@ -140,17 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
} }
} }
// Perform evaluations in parallel // Perform static contiguous evaluations in parallel
std::vector<argument> literals(const_instructions.size()); std::vector<argument> literals(const_instructions.size());
par_for(const_instructions.size(), 1, [&](const auto i) { par_for(const_instructions.size(), 1, [&](const auto i) {
auto c = op::contiguous{}; auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front(); auto prev = const_instructions[i]->inputs().front();
std::vector<shape> prev_shape = {prev->get_shape()}; // compute the output contiguous shape from the previous instruction shape
shape computed_shape = c.compute_shape({prev->get_shape()});
const std::vector<argument>& prev_eval = {prev->eval()}; const std::vector<argument>& prev_eval = {prev->eval()};
auto co_shape = make_compute_output_shape(pack(c, prev_shape, prev_eval)); // prev_eval should not be used in make_compute_output_shape() as computed_shape is static
literals[i] = c.compute(co_shape, {prev->eval()}); auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval));
literals[i] = c.compute(co_shape, prev_eval);
}); });
// Replace static contiguous operations with a literal
for(size_t i = 0; i < const_instructions.size(); i++) for(size_t i = 0; i < const_instructions.size(); i++)
{ {
auto l = m.add_literal(literals[i].get_shape(), literals[i].data()); auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment