Commit df4b1e15 authored by wsttiger's avatar wsttiger
Browse files

Merge branch 'master' into remove_concat

parents 86dfc8b9 4debaf07
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraph/stringutils.hpp> #include <migraph/stringutils.hpp>
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
#include <migraph/env.hpp> #include <migraph/env.hpp>
#include <migraph/ranges.hpp>
#include <migraph/time.hpp> #include <migraph/time.hpp>
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <iostream> #include <iostream>
...@@ -329,8 +330,11 @@ argument generic_eval(const program& p, ...@@ -329,8 +330,11 @@ argument generic_eval(const program& p,
else if(ins->name() == "@param") else if(ins->name() == "@param")
{ {
results.emplace(ins, trace(ins, [&] { results.emplace(ins, trace(ins, [&] {
return params.at( auto param_name =
any_cast<builtin::param>(ins->get_operator()).parameter); any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name))
MIGRAPH_THROW("Parameter not found: " + param_name);
return params.at(param_name);
})); }));
} }
else if(ins->name() == "@outline") else if(ins->name() == "@outline")
......
...@@ -104,6 +104,21 @@ void param_test() ...@@ -104,6 +104,21 @@ void param_test()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void param_error_test()
{
migraph::program p;
auto x = p.add_parameter("x", {migraph::shape::int64_type});
auto y = p.add_parameter("y", {migraph::shape::int64_type});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraph::exception>(
[&] {
p.eval({{"x", migraph::literal{1}.get_argument()}});
},
"Parameter not found: y"));
}
void replace_test() void replace_test()
{ {
migraph::program p; migraph::program p;
...@@ -215,6 +230,7 @@ int main() ...@@ -215,6 +230,7 @@ int main()
literal_test2(); literal_test2();
print_test(); print_test();
param_test(); param_test();
param_error_test();
replace_test(); replace_test();
replace_ins_test(); replace_ins_test();
replace_ins_test2(); replace_ins_test2();
......
...@@ -140,7 +140,7 @@ bool throws(F f) ...@@ -140,7 +140,7 @@ bool throws(F f)
} }
} }
template <class F, class Exception> template <class Exception, class F>
bool throws(F f, const std::string& msg = "") bool throws(F f, const std::string& msg = "")
{ {
try try
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment