#include #include #include #include #include struct eliminate_allocation_target { std::size_t align = 32; std::string name() const { return "eliminate_allocation"; } std::vector get_passes(migraph::context&) const { return {migraph::eliminate_allocation{"allocate", align}, migraph::dead_code_elimination{}}; } migraph::context get_context() const { return {}; } }; struct allocate { migraph::shape s{}; std::string name() const { return "allocate"; } migraph::shape compute_shape(const std::vector& inputs) const { migraph::check_shapes{inputs}.has(0); return s; } migraph::argument compute(migraph::context&, const migraph::shape& output_shape, const std::vector&) const { return {output_shape}; } }; void basic() { migraph::program p; auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {8}}}); auto p1 = p.add_instruction(pass_op{}, a1); auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {40}}}); auto p2 = p.add_instruction(pass_op{}, a2, p1); auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {200}}}); p.add_instruction(pass_op{}, a3, p2); p.compile(eliminate_allocation_target{}); EXPECT(p.get_shape() == migraph::shape{migraph::shape::float_type, {200}}); EXPECT(p.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4)); } void aligned() { migraph::program p; auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}}); auto p1 = p.add_instruction(pass_op{}, a1); auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2}}}); auto p2 = p.add_instruction(pass_op{}, a2, p1); auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {200}}}); p.add_instruction(pass_op{}, a3, p2); p.compile(eliminate_allocation_target{}); EXPECT(p.get_shape() == migraph::shape{migraph::shape::float_type, {200}}); EXPECT(p.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4)); } void unaligned() { migraph::program p; auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}}); auto p1 = p.add_instruction(pass_op{}, a1); auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2}}}); auto p2 = p.add_instruction(pass_op{}, a2, p1); auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {200}}}); p.add_instruction(pass_op{}, a3, p2); p.compile(eliminate_allocation_target{1}); EXPECT(p.get_shape() == migraph::shape{migraph::shape::float_type, {200}}); EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); } void float_aligned() { migraph::program p; auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}}); auto p1 = p.add_instruction(pass_op{}, a1); auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2}}}); auto p2 = p.add_instruction(pass_op{}, a2, p1); auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {200}}}); p.add_instruction(pass_op{}, a3, p2); p.compile(eliminate_allocation_target{4}); EXPECT(p.get_shape() == migraph::shape{migraph::shape::float_type, {200}}); EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); } int main() { setenv("MIGRAPH_DISABLE_MEMORY_COLORING", "1", 1); basic(); aligned(); unaligned(); float_aligned(); }