Commit 26c0e6d2 authored by Paul's avatar Paul
Browse files

Add a test for broadcast

parent 88b081e5
......@@ -3,12 +3,12 @@
#include <basic_ops.hpp>
#include <test.hpp>
struct contigous_target
struct contiguous_target
{
std::string name() const { return "contigous"; }
std::string name() const { return "contiguous"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::auto_contigous{}};
return {migraph::auto_contiguous{}};
}
migraph::context get_context() const { return {}; }
};
......@@ -18,18 +18,43 @@ migraph::literal get_2x2()
return migraph::literal{{migraph::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
}
migraph::literal get_2()
{
return migraph::literal{{migraph::shape::float_type, {2}}, {1, 2}};
}
void after_literal_transpose()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.add_instruction(migraph::transpose{{1, 0}}, l);
auto t = p.add_instruction(migraph::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
p.compile(contigous_target{});
p.compile(contiguous_target{});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
}
int main() { after_literal_transpose(); }
void after_literal_broadcast()
{
migraph::program p;
auto l1 = p.add_literal(get_2x2());
auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraph::broadcast{}, l1, l2);
p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted());
p.compile(contiguous_target{});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted());
}
int main() {
after_literal_transpose();
after_literal_broadcast();
}
......@@ -61,3 +61,22 @@ struct minus_op
return inputs.front();
}
};
struct pass_op
{
std::string name() const { return "pass"; }
migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
{
if(args.empty())
return {};
return args.front();
}
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment