Commit 4d449dcb authored by Paul's avatar Paul
Browse files

Check parameter shape

parent c50e8004
......@@ -353,7 +353,10 @@ argument generic_eval(const program& p,
any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name))
MIGRAPHX_THROW("Parameter not found: " + param_name);
return params.at(param_name);
auto param = params.at(param_name);
if (param.get_shape() != ins->get_shape())
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) + "} for parameter: " + param_name);
return param;
}));
}
else if(ins->name() == "@outline")
......
......@@ -128,7 +128,7 @@ TEST_CASE(print_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, x, two);
......@@ -142,8 +142,8 @@ TEST_CASE(param_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type});
auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y);
auto result = p.eval(
......@@ -156,8 +156,8 @@ TEST_CASE(param_error_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type});
auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>(
......@@ -167,6 +167,22 @@ TEST_CASE(param_error_test)
"Parameter not found: y"));
}
TEST_CASE(param_shape_error_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 2}});
auto y = p.add_parameter("y", {migraphx::shape::int32_type, {1, 2}});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>(
[&] {
p.eval(
{{"x", migraphx::literal{1}.get_argument()}, {"y", migraphx::literal{2}.get_argument()}});
},
"Incorrect shape"));
}
TEST_CASE(replace_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment