Commit dbcbaeeb authored by Khalique's avatar Khalique
Browse files

fix tests, check relu+lrn on gpu

parent 6b1e4b67
......@@ -65,7 +65,7 @@ struct lrn
float alpha = 0.0001;
float beta = 0.75;
float bias = 1.0;
int size;
int size = 1;
std::string name() const { return "lrn"; }
shape compute_shape(std::vector<shape> inputs) const
......
......@@ -733,16 +733,16 @@ TEST_CASE(leaky_relu_test)
TEST_CASE(lrn_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {1, 5, 1, 1}};
auto l = p.add_literal(migraph::literal{s, {-2.0f, 1.0f, 0.f, 1.0f, 2.0f}});
p.add_instruction(migraph::op::lrn{0.0001, 0.75, 1, 5}, l);
p.compile(migraph::cpu::target{});
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 5, 1, 1}};
auto l = p.add_literal(migraphx::literal{s, {-2.0f, 1.0f, 0.f, 1.0f, 2.0f}});
p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1, 5}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(5);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075};
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(imagescaler_test)
......
......@@ -637,13 +637,14 @@ struct test_elu
}
};
struct test_lrn
struct test_relu_lrn
{
migraph::program create_program() const
migraphx::program create_program() const
{
migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 5, 2, 2}});
p.add_instruction(migraph::op::lrn{0.0001, 0.75, 1.0, 5}, x);
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}});
auto y = p.add_instruction(migraphx::op::relu{}, x);
p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, y);
return p;
}
};
......@@ -1097,7 +1098,7 @@ struct test_conv_bn_relu_pooling2
int main()
{
verify_program<test_lrn>();
verify_program<test_relu_lrn>();
verify_program<test_pooling_autopad>();
verify_program<test_abs>();
verify_program<test_concat>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment