Commit 24cc6ea8 authored by Paul's avatar Paul
Browse files

Fix bn tests

parent e01c70e6
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
namespace migraph { namespace migraph {
argument generate_argument(shape s, std::mt19937::result_type seed) argument generate_argument(shape s, unsigned long seed)
{ {
argument result; argument result;
s.visit_type([&](auto as) { s.visit_type([&](auto as) {
...@@ -13,7 +13,7 @@ argument generate_argument(shape s, std::mt19937::result_type seed) ...@@ -13,7 +13,7 @@ argument generate_argument(shape s, std::mt19937::result_type seed)
return result; return result;
} }
literal generate_literal(shape s, std::mt19937::result_type seed) literal generate_literal(shape s, unsigned long seed)
{ {
literal result; literal result;
s.visit_type([&](auto as) { s.visit_type([&](auto as) {
...@@ -24,4 +24,12 @@ literal generate_literal(shape s, std::mt19937::result_type seed) ...@@ -24,4 +24,12 @@ literal generate_literal(shape s, std::mt19937::result_type seed)
return result; return result;
} }
// TODO: Move to literal.cpp
literal abs(literal l)
{
return transform(l, [](auto x) {
return std::fabs(x);
});
}
} // namespace migraph } // namespace migraph
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace migraph { namespace migraph {
template <class T, MIGRAPH_REQUIRES(std::is_floating_point<T>{})> template <class T, MIGRAPH_REQUIRES(std::is_floating_point<T>{})>
T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
if(z == 0) if(z == 0)
return 0; return 0;
...@@ -16,7 +16,7 @@ T normalize(unsigned long z) ...@@ -16,7 +16,7 @@ T normalize(unsigned long z)
} }
template <class T, MIGRAPH_REQUIRES(std::is_signed<T>{} and not std::is_floating_point<T>{})> template <class T, MIGRAPH_REQUIRES(std::is_signed<T>{} and not std::is_floating_point<T>{})>
T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
const auto max = std::numeric_limits<T>::max(); const auto max = std::numeric_limits<T>::max();
const auto half_max = max / 2; const auto half_max = max / 2;
...@@ -24,7 +24,7 @@ T normalize(unsigned long z) ...@@ -24,7 +24,7 @@ T normalize(unsigned long z)
} }
template <class T, MIGRAPH_REQUIRES(not std::is_signed<T>{} and std::is_integral<T>{})> template <class T, MIGRAPH_REQUIRES(not std::is_signed<T>{} and std::is_integral<T>{})>
T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
const auto max = std::numeric_limits<T>::max(); const auto max = std::numeric_limits<T>::max();
return z % max; return z % max;
...@@ -33,9 +33,10 @@ T normalize(unsigned long z) ...@@ -33,9 +33,10 @@ T normalize(unsigned long z)
template <class T> template <class T>
struct xorshf96_generator struct xorshf96_generator
{ {
unsigned long seed = 0;
unsigned long x = 123456789; unsigned long x = 123456789;
unsigned long y = 362436069; unsigned long y = 362436069;
unsigned long z = 521288629; unsigned long z = 521288629 ^ seed;
constexpr T operator()() noexcept constexpr T operator()() noexcept
{ {
...@@ -53,16 +54,18 @@ struct xorshf96_generator ...@@ -53,16 +54,18 @@ struct xorshf96_generator
}; };
template <class T> template <class T>
std::vector<T> generate_tensor_data(const migraph::shape& s, std::mt19937::result_type) std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed = 0)
{ {
std::vector<T> result(s.elements()); std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), xorshf96_generator<T>{}); std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
return result; return result;
} }
argument generate_argument(shape s, std::mt19937::result_type seed = 0); argument generate_argument(shape s, unsigned long seed = 0);
literal generate_literal(shape s, std::mt19937::result_type seed = 0); literal generate_literal(shape s, unsigned long seed = 0);
literal abs(literal l);
} // namespace migraph } // namespace migraph
......
...@@ -94,6 +94,19 @@ struct literal : raw_data<literal> ...@@ -94,6 +94,19 @@ struct literal : raw_data<literal>
} }
}; };
template<class F>
literal transform(literal l, F f)
{
literal result;
l.visit([&](auto x) {
using type = std::remove_cv_t<typename decltype(x)::value_type>;
std::vector<type> output(x.size(), 0.0);
std::transform(x.begin(), x.end(), output.begin(), f);
result = literal{l.get_shape(), output};
});
return result;
}
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -332,10 +332,10 @@ struct test_batchnorm_inference_2 ...@@ -332,10 +332,10 @@ struct test_batchnorm_inference_2
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}}; migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraph::shape vars{migraph::shape::float_type, {channels}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto mean = p.add_parameter("mean", vars); auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0)));
auto variance = p.add_parameter("variance", vars); auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto scale = p.add_parameter("scale", vars); auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto bias = p.add_parameter("bias", vars); auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias); p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
return p; return p;
} }
...@@ -355,10 +355,10 @@ struct test_batchnorm_inference ...@@ -355,10 +355,10 @@ struct test_batchnorm_inference
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}}; migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraph::shape vars{migraph::shape::float_type, {channels}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto mean = p.add_parameter("mean", vars); auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0)));
auto variance = p.add_parameter("variance", vars); auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto scale = p.add_parameter("scale", vars); auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto bias = p.add_parameter("bias", vars); auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias); p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment