Commit 26c0e6d2 authored by Paul's avatar Paul
Browse files

Add a test for broadcast

parent 88b081e5
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct contigous_target struct contiguous_target
{ {
std::string name() const { return "contigous"; } std::string name() const { return "contiguous"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraph::pass> get_passes(migraph::context&) const
{ {
return {migraph::auto_contigous{}}; return {migraph::auto_contiguous{}};
} }
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
...@@ -18,18 +18,43 @@ migraph::literal get_2x2() ...@@ -18,18 +18,43 @@ migraph::literal get_2x2()
return migraph::literal{{migraph::shape::float_type, {2, 2}}, {1, 2, 3, 4}}; return migraph::literal{{migraph::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
} }
migraph::literal get_2()
{
return migraph::literal{{migraph::shape::float_type, {2}}, {1, 2}};
}
void after_literal_transpose() void after_literal_transpose()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
p.add_instruction(migraph::transpose{{1, 0}}, l); auto t = p.add_instruction(migraph::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
p.compile(contigous_target{}); p.compile(contiguous_target{});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
} }
int main() { after_literal_transpose(); } void after_literal_broadcast()
{
migraph::program p;
auto l1 = p.add_literal(get_2x2());
auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraph::broadcast{}, l1, l2);
p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted());
p.compile(contiguous_target{});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted());
}
int main() {
after_literal_transpose();
after_literal_broadcast();
}
...@@ -61,3 +61,22 @@ struct minus_op ...@@ -61,3 +61,22 @@ struct minus_op
return inputs.front(); return inputs.front();
} }
}; };
struct pass_op
{
std::string name() const { return "pass"; }
migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
{
if(args.empty())
return {};
return args.front();
}
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment