Commit f76e2490 authored by Paul's avatar Paul
Browse files

Add tests for fp16

parent 96e74d6e
......@@ -437,7 +437,7 @@ struct relu_op
std::string name() const { return "cpu::relu"; }
auto fcn() const
{
return [](auto x) { return x > 0 ? x : 0; };
return [](auto x) { return std::max(decltype(x){0}, x); };
}
};
......
......@@ -41,6 +41,8 @@ inline tensor_descriptor make_tensor(const migraph::shape& s)
miopenDataType_t d;
if(s.type() == shape::float_type)
d = miopenFloat;
else if(s.type() == shape::half_type)
d = miopenHalf;
else
MIGRAPH_THROW("Unsupported type");
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
......
......@@ -134,7 +134,8 @@ void verify_program()
migraph::program gpu_prog;
auto cpu_arg_f = detach_async([&] { return run_cpu<V>(cpu_prog); });
auto gpu_arg = run_gpu<V>(gpu_prog);
bool passed = verify_args(migraph::get_type_name<V>(), cpu_arg_f.get(), gpu_arg);
auto cpu_arg = cpu_arg_f.get();
bool passed = verify_args(migraph::get_type_name<V>(), cpu_arg, gpu_arg);
if(not passed)
{
V v;
......@@ -175,6 +176,19 @@ struct test_add
}
};
struct test_add_half
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::half_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
p.add_instruction(migraph::op::add{}, x, y);
return p;
}
};
struct test_mul
{
migraph::program create_program() const
......@@ -383,6 +397,20 @@ struct test_conv_relu
}
};
struct test_conv_relu_half
{
migraph::program create_program() const
{
migraph::program p;
auto input = p.add_parameter("x", migraph::shape{migraph::shape::half_type, {4, 3, 3, 3}});
auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::half_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(migraph::op::convolution{}, input, weights);
p.add_instruction(migraph::op::activation{"relu"}, conv);
return p;
}
};
struct test_add_relu
{
migraph::program create_program() const
......@@ -680,6 +708,7 @@ int main()
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_add>();
verify_program<test_add_half>();
verify_program<test_mul>();
verify_program<test_scale>();
verify_program<test_triadd>();
......@@ -695,6 +724,7 @@ int main()
verify_program<test_conv>();
verify_program<test_conv2>();
verify_program<test_conv_relu>();
verify_program<test_conv_relu_half>();
verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_conv_pooling>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment