Commit 71e1382d authored by Paul's avatar Paul
Browse files

Add test for replace op

parent 44194a24
......@@ -263,6 +263,31 @@ TEST_CASE(replace_ins_test2)
EXPECT(result != migraphx::literal{3});
}
TEST_CASE(replace_op_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, two, one);
sum->replace(minus_op{});
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({});
EXPECT(result == migraphx::literal{1});
EXPECT(result != migraphx::literal{3});
}
TEST_CASE(replace_op_recompute_shape_throw)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
EXPECT(test::throws([&] { sum->replace(unary_pass_op{}); }));
}
TEST_CASE(insert_replace_test)
{
migraphx::program p;
......
......@@ -82,6 +82,26 @@ struct pass_op
int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};
struct unary_pass_op
{
std::string name() const { return "unary_pass"; }
migraphx::argument
compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{
if(args.empty())
return {};
return args.front();
}
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.size() != 1)
MIGRAPHX_THROW("Wrong inputs");
return inputs.front();
}
int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};
struct pass_standard_op
{
std::string name() const { return "pass"; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment