Commit f76e2490 authored by Paul's avatar Paul
Browse files

Add tests for fp16

parent 96e74d6e
...@@ -437,7 +437,7 @@ struct relu_op ...@@ -437,7 +437,7 @@ struct relu_op
std::string name() const { return "cpu::relu"; } std::string name() const { return "cpu::relu"; }
auto fcn() const auto fcn() const
{ {
return [](auto x) { return x > 0 ? x : 0; }; return [](auto x) { return std::max(decltype(x){0}, x); };
} }
}; };
......
...@@ -41,6 +41,8 @@ inline tensor_descriptor make_tensor(const migraph::shape& s) ...@@ -41,6 +41,8 @@ inline tensor_descriptor make_tensor(const migraph::shape& s)
miopenDataType_t d; miopenDataType_t d;
if(s.type() == shape::float_type) if(s.type() == shape::float_type)
d = miopenFloat; d = miopenFloat;
else if(s.type() == shape::half_type)
d = miopenHalf;
else else
MIGRAPH_THROW("Unsupported type"); MIGRAPH_THROW("Unsupported type");
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data()); miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
......
...@@ -134,7 +134,8 @@ void verify_program() ...@@ -134,7 +134,8 @@ void verify_program()
migraph::program gpu_prog; migraph::program gpu_prog;
auto cpu_arg_f = detach_async([&] { return run_cpu<V>(cpu_prog); }); auto cpu_arg_f = detach_async([&] { return run_cpu<V>(cpu_prog); });
auto gpu_arg = run_gpu<V>(gpu_prog); auto gpu_arg = run_gpu<V>(gpu_prog);
bool passed = verify_args(migraph::get_type_name<V>(), cpu_arg_f.get(), gpu_arg); auto cpu_arg = cpu_arg_f.get();
bool passed = verify_args(migraph::get_type_name<V>(), cpu_arg, gpu_arg);
if(not passed) if(not passed)
{ {
V v; V v;
...@@ -175,6 +176,19 @@ struct test_add ...@@ -175,6 +176,19 @@ struct test_add
} }
}; };
struct test_add_half
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::half_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
p.add_instruction(migraph::op::add{}, x, y);
return p;
}
};
struct test_mul struct test_mul
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -383,6 +397,20 @@ struct test_conv_relu ...@@ -383,6 +397,20 @@ struct test_conv_relu
} }
}; };
struct test_conv_relu_half
{
migraph::program create_program() const
{
migraph::program p;
auto input = p.add_parameter("x", migraph::shape{migraph::shape::half_type, {4, 3, 3, 3}});
auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::half_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(migraph::op::convolution{}, input, weights);
p.add_instruction(migraph::op::activation{"relu"}, conv);
return p;
}
};
struct test_add_relu struct test_add_relu
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -680,6 +708,7 @@ int main() ...@@ -680,6 +708,7 @@ int main()
verify_program<test_concat>(); verify_program<test_concat>();
verify_program<test_concat2>(); verify_program<test_concat2>();
verify_program<test_add>(); verify_program<test_add>();
verify_program<test_add_half>();
verify_program<test_mul>(); verify_program<test_mul>();
verify_program<test_scale>(); verify_program<test_scale>();
verify_program<test_triadd>(); verify_program<test_triadd>();
...@@ -695,6 +724,7 @@ int main() ...@@ -695,6 +724,7 @@ int main()
verify_program<test_conv>(); verify_program<test_conv>();
verify_program<test_conv2>(); verify_program<test_conv2>();
verify_program<test_conv_relu>(); verify_program<test_conv_relu>();
verify_program<test_conv_relu_half>();
verify_program<test_add_relu>(); verify_program<test_add_relu>();
verify_program<test_leaky_relu>(); verify_program<test_leaky_relu>();
verify_program<test_conv_pooling>(); verify_program<test_conv_pooling>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment