"vscode:/vscode.git/clone" did not exist on "9d3f634a3cc37599db9f3ebfeb1f9a3c45d9a673"
Commit b47d3a85 authored by Paul

Merge branch 'develop' into includes

parents 40d66cf8 f67fc560
@@ -1133,7 +1133,15 @@ struct onnx_parser
         case onnx::TensorProto::BOOL:
             return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
         case onnx::TensorProto::FLOAT16:
-            return literal{{shape::half_type, dims}, t.float_data().begin(), t.float_data().end()};
+        {
+            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
+            std::vector<half> data_half;
+            std::transform(data_uint16.begin(),
+                           data_uint16.end(),
+                           std::back_inserter(data_half),
+                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
+            return literal{{shape::half_type, dims}, data_half.begin(), data_half.end()};
+        }
         case onnx::TensorProto::DOUBLE:
             return literal{
                 {shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
...
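The FLOAT16 branch added above relies on ONNX packing each fp16 element into the low 16 bits of an int32_data entry, so the parser copies the raw words out and reinterprets their bits as half values. Below is a minimal sketch of the same unpacking, shown with a memcpy bit-copy rather than reinterpret_cast (the cast in the hunk formally violates strict aliasing); unpack_fp16 and raw_fp16 are illustrative names, not MIGraphX API, and the sketch assumes <migraphx/half.hpp> is on the include path.

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <vector>
#include <migraphx/half.hpp>

// Reinterpret ONNX-style packed fp16 words (one per int32) as migraphx::half.
std::vector<migraphx::half> unpack_fp16(const std::vector<int32_t>& raw_fp16)
{
    std::vector<migraphx::half> out;
    out.reserve(raw_fp16.size());
    std::transform(raw_fp16.begin(), raw_fp16.end(), std::back_inserter(out), [](int32_t v) {
        auto bits = static_cast<uint16_t>(v); // fp16 payload lives in the low 16 bits
        migraphx::half h;
        static_assert(sizeof(h) == sizeof(bits), "half is expected to be 16 bits wide");
        std::memcpy(&h, &bits, sizeof(h)); // bit copy avoids the aliasing cast
        return h;
    });
    return out;
}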
@@ -349,13 +349,17 @@ argument generic_eval(const program& p,
         }
         else if(ins->name() == "@param")
         {
-            results.emplace(ins, trace(ins, [&] {
-                auto param_name =
-                    any_cast<builtin::param>(ins->get_operator()).parameter;
-                if(not contains(params, param_name))
-                    MIGRAPHX_THROW("Parameter not found: " + param_name);
-                return params.at(param_name);
-            }));
+            results.emplace(
+                ins, trace(ins, [&] {
+                    auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
+                    if(not contains(params, param_name))
+                        MIGRAPHX_THROW("Parameter not found: " + param_name);
+                    auto param = params.at(param_name);
+                    if(param.get_shape() != ins->get_shape())
+                        MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
+                                       "} for parameter: " + param_name);
+                    return param;
+                }));
         }
         else if(ins->name() == "@outline")
         {
...
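The new check compares migraphx::shape objects directly, so both the element type and the dimensions of the supplied argument must match the parameter's declared shape before evaluation proceeds. A small sketch of the comparison that the "@param" branch now performs; the shapes are illustrative, and it assumes only <migraphx/shape.hpp>.

#include <cassert>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::shape declared{migraphx::shape::int32_type, {1, 2}};
    migraphx::shape supplied{migraphx::shape::int32_type, {1}};

    // Mismatched dimensions now trigger
    // "Incorrect shape {...} for parameter: ..." in generic_eval.
    assert(declared != supplied);

    migraphx::shape matching{migraphx::shape::int32_type, {1, 2}};
    assert(declared == matching); // same type and lengths compare equal
    return 0;
}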
@@ -7,6 +7,7 @@
 #include <migraphx/verify.hpp>
 #include <migraphx/onnx.hpp>
 #include "test.hpp"
+#include <migraphx/half.hpp>
 
 float sigmoid(float x) { return 1 / (1 + expf(-x)); }
@@ -1375,4 +1376,22 @@ TEST_CASE(pad_test)
     EXPECT(migraphx::verify_range(results_vector, gold));
 }
 
+TEST_CASE(fp16_test)
+{
+    migraphx::program p;
+    migraphx::shape s{migraphx::shape::half_type, {1}};
+    migraphx::half a{1.5};
+    migraphx::half b{2.5};
+    migraphx::half c{4.0};
+    auto l0 = p.add_literal(migraphx::literal{s, {a}});
+    auto l1 = p.add_literal(migraphx::literal{s, {b}});
+    p.add_instruction(migraphx::op::add{}, l0, l1);
+    p.compile(migraphx::cpu::target{});
+    auto result = p.eval({});
+    std::vector<migraphx::half> results_vector(1);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<migraphx::half> gold{c};
+    EXPECT(migraphx::verify_range(results_vector, gold));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
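One reason fp16_test can compare against an exact gold value: 1.5, 2.5, and 4.0 are all exactly representable in half precision, so the addition introduces no rounding error. A tiny sketch of that fact, assuming <migraphx/half.hpp> is available.

#include <cassert>
#include <migraphx/half.hpp>

int main()
{
    migraphx::half a{1.5};
    migraphx::half b{2.5};
    migraphx::half c{4.0};
    // All three values are exact in fp16, so an exact comparison is safe here.
    assert(a + b == c);
    return 0;
}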
@@ -128,7 +128,7 @@ TEST_CASE(print_test)
 {
     migraphx::program p;
-    auto x = p.add_parameter("x", {migraphx::shape::int64_type});
+    auto x = p.add_parameter("x", {migraphx::shape::int32_type});
     auto two = p.add_literal(2);
     p.add_instruction(sum_op{}, x, two);
@@ -142,8 +142,8 @@ TEST_CASE(param_test)
 {
     migraphx::program p;
-    auto x = p.add_parameter("x", {migraphx::shape::int64_type});
-    auto y = p.add_parameter("y", {migraphx::shape::int64_type});
+    auto x = p.add_parameter("x", {migraphx::shape::int32_type});
+    auto y = p.add_parameter("y", {migraphx::shape::int32_type});
     p.add_instruction(sum_op{}, x, y);
     auto result = p.eval(
@@ -156,8 +156,8 @@ TEST_CASE(param_error_test)
 {
     migraphx::program p;
-    auto x = p.add_parameter("x", {migraphx::shape::int64_type});
-    auto y = p.add_parameter("y", {migraphx::shape::int64_type});
+    auto x = p.add_parameter("x", {migraphx::shape::int32_type});
+    auto y = p.add_parameter("y", {migraphx::shape::int32_type});
     p.add_instruction(sum_op{}, x, y);
     EXPECT(test::throws<migraphx::exception>(
@@ -167,6 +167,22 @@ TEST_CASE(param_error_test)
         "Parameter not found: y"));
 }
 
+TEST_CASE(param_shape_error_test)
+{
+    migraphx::program p;
+    auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 2}});
+    auto y = p.add_parameter("y", {migraphx::shape::int32_type, {1, 2}});
+    p.add_instruction(sum_op{}, x, y);
+    EXPECT(test::throws<migraphx::exception>(
+        [&] {
+            p.eval({{"x", migraphx::literal{1}.get_argument()},
+                    {"y", migraphx::literal{2}.get_argument()}});
+        },
+        "Incorrect shape"));
+}
+
 TEST_CASE(replace_test)
 {
     migraphx::program p;
...
add_fp16_test.onnx (new binary file, raw protobuf not shown): ONNX model "add-fp16-example" with graph "test-add-fp16", consisting of a single Add node over fp16 inputs "0" and "1" producing output "2".
@@ -1186,7 +1186,9 @@ TEST_CASE(group_conv_test)
     migraphx::op::convolution op;
     op.group = 4;
     p.add_instruction(op, l0, l1);
-    migraphx::parse_onnx("group_conv_test.onnx");
+    auto prog = migraphx::parse_onnx("group_conv_test.onnx");
+    EXPECT(p == prog);
 }
 
 TEST_CASE(pad_test)
@@ -1194,13 +1196,14 @@ TEST_CASE(pad_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
     p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0);
-    migraphx::parse_onnx("pad_test.onnx");
+    auto prog = migraphx::parse_onnx("pad_test.onnx");
+    EXPECT(p == prog);
 }
 
 TEST_CASE(lrn_test)
 {
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 28, 24, 24}});
     migraphx::op::lrn op;
     op.size = 5;
@@ -1208,7 +1211,22 @@ TEST_CASE(lrn_test)
     op.beta = 0.75;
     op.bias = 1.0;
     p.add_instruction(op, l0);
-    migraphx::parse_onnx("lrn_test.onnx");
+    auto prog = migraphx::parse_onnx("lrn_test.onnx");
+    EXPECT(p == prog);
 }
 
+TEST_CASE(add_fp16_test)
+{
+    migraphx::program p;
+    auto l0 =
+        p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {1.5}});
+    auto l1 =
+        p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {2.5}});
+    p.add_instruction(migraphx::op::add{}, l0, l1);
+    auto prog = migraphx::parse_onnx("add_fp16_test.onnx");
+    EXPECT(p == prog);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }