Commit fa4c8e15 authored by Paul's avatar Paul
Browse files

Update dce to no remove ops with no output

parent 52f44f87
...@@ -15,8 +15,12 @@ void dead_code_elimination::apply(program& p) const ...@@ -15,8 +15,12 @@ void dead_code_elimination::apply(program& p) const
// instruction // instruction
if(ins == p.begin()) if(ins == p.begin())
continue; continue;
const auto i = std::prev(ins);
// Skip instruction with empty shape as output
if(i->result.elements() == 0)
continue;
// Skip the last instruction // Skip the last instruction
if(std::prev(ins) == last) if(i == last)
break; break;
fix([&](auto self, auto leaf) { fix([&](auto self, auto leaf) {
assert(p.has_instruction(leaf)); assert(p.has_instruction(leaf));
...@@ -28,7 +32,7 @@ void dead_code_elimination::apply(program& p) const ...@@ -28,7 +32,7 @@ void dead_code_elimination::apply(program& p) const
for(auto arg : args) for(auto arg : args)
self(arg); self(arg);
} }
})(std::prev(ins)); })(i);
} }
p.remove_instructions(std::next(last), p.end()); p.remove_instructions(std::next(last), p.end());
} }
......
...@@ -43,6 +43,8 @@ const std::vector<std::size_t>& shape::strides() const { return this->m_strides; ...@@ -43,6 +43,8 @@ const std::vector<std::size_t>& shape::strides() const { return this->m_strides;
std::size_t shape::elements() const std::size_t shape::elements() const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->lens().empty())
return 0;
return std::accumulate( return std::accumulate(
this->lens().begin(), this->lens().end(), std::size_t{1}, std::multiplies<std::size_t>()); this->lens().begin(), this->lens().end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
...@@ -101,6 +103,8 @@ bool shape::standard() const { return this->m_standard; } ...@@ -101,6 +103,8 @@ bool shape::standard() const { return this->m_standard; }
std::size_t shape::element_space() const std::size_t shape::element_space() const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->lens().empty())
return 0;
return std::inner_product(this->lens().begin(), return std::inner_product(this->lens().begin(),
this->lens().end(), this->lens().end(),
this->strides().begin(), this->strides().begin(),
......
...@@ -10,10 +10,14 @@ namespace migraph { ...@@ -10,10 +10,14 @@ namespace migraph {
bool is_reshaper(const std::string& name) bool is_reshaper(const std::string& name)
{ {
static const std::unordered_set<std::string> names = {"reshape", // clang-format off
"transpose", static const std::unordered_set<std::string> names = {
// "broadcast", "reshape",
"contiguous"}; "transpose",
// "broadcast",
"contiguous"
};
// clang-format on
return contains(names, name); return contains(names, name);
} }
......
...@@ -19,8 +19,8 @@ std::vector<pass> target::get_passes(migraph::context&) const ...@@ -19,8 +19,8 @@ std::vector<pass> target::get_passes(migraph::context&) const
simplify_reshapes{}, simplify_reshapes{},
lowering{}, lowering{},
write_literals{}, write_literals{},
dead_code_elimination{}, check_context<context>{},
check_context<context>{} dead_code_elimination{}
}; };
// clang-format on // clang-format on
} }
......
...@@ -27,6 +27,22 @@ void simple_test() ...@@ -27,6 +27,22 @@ void simple_test()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void simple_test_nop()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(nop{});
p.add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
auto result = p.eval({});
EXPECT(result == migraph::literal{3});
EXPECT(result != migraph::literal{4});
}
void duplicate_test1() void duplicate_test1()
{ {
migraph::program p; migraph::program p;
...@@ -82,6 +98,7 @@ void depth_test() ...@@ -82,6 +98,7 @@ void depth_test()
int main() int main()
{ {
simple_test(); simple_test();
simple_test_nop();
duplicate_test1(); duplicate_test1();
duplicate_test2(); duplicate_test2();
depth_test(); depth_test();
......
...@@ -80,3 +80,18 @@ struct pass_op ...@@ -80,3 +80,18 @@ struct pass_op
return inputs.front(); return inputs.front();
} }
}; };
struct nop
{
std::string name() const { return "nop"; }
migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument>) const
{
return {};
}
migraph::shape compute_shape(std::vector<migraph::shape>) const
{
return {};
}
};
...@@ -5,6 +5,13 @@ ...@@ -5,6 +5,13 @@
#include <numeric> #include <numeric>
#include "test.hpp" #include "test.hpp"
void test_shape_default()
{
migraph::shape s{};
EXPECT(s.elements() == 0);
EXPECT(s.bytes() == 0);
}
void test_shape_assign() void test_shape_assign()
{ {
migraph::shape s1{migraph::shape::float_type, {100, 32, 8, 8}}; migraph::shape s1{migraph::shape::float_type, {100, 32, 8, 8}};
...@@ -49,7 +56,7 @@ void test_shape_broadcasted() ...@@ -49,7 +56,7 @@ void test_shape_broadcasted()
EXPECT(s.broadcasted()); EXPECT(s.broadcasted());
} }
void test_shape_default() void test_shape_default_copy()
{ {
migraph::shape s1{}; migraph::shape s1{};
migraph::shape s2{}; migraph::shape s2{};
...@@ -136,12 +143,13 @@ void test_shape4_nonpacked() ...@@ -136,12 +143,13 @@ void test_shape4_nonpacked()
int main() int main()
{ {
test_shape_default();
test_shape_assign(); test_shape_assign();
test_shape_packed_default(); test_shape_packed_default();
test_shape_packed(); test_shape_packed();
test_shape_transposed(); test_shape_transposed();
test_shape_broadcasted(); test_shape_broadcasted();
test_shape_default(); test_shape_default_copy();
test_shape4(); test_shape4();
test_shape4_nonpacked(); test_shape4_nonpacked();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment