Commit 24cc6ea8 authored by Paul's avatar Paul
Browse files

Fix bn tests

parent e01c70e6
......@@ -2,7 +2,7 @@
namespace migraph {
argument generate_argument(shape s, std::mt19937::result_type seed)
argument generate_argument(shape s, unsigned long seed)
{
argument result;
s.visit_type([&](auto as) {
......@@ -13,7 +13,7 @@ argument generate_argument(shape s, std::mt19937::result_type seed)
return result;
}
literal generate_literal(shape s, std::mt19937::result_type seed)
literal generate_literal(shape s, unsigned long seed)
{
literal result;
s.visit_type([&](auto as) {
......@@ -24,4 +24,12 @@ literal generate_literal(shape s, std::mt19937::result_type seed)
return result;
}
// TODO: Move to literal.cpp
literal abs(literal l)
{
return transform(l, [](auto x) {
return std::fabs(x);
});
}
} // namespace migraph
......@@ -8,7 +8,7 @@
namespace migraph {
template <class T, MIGRAPH_REQUIRES(std::is_floating_point<T>{})>
T normalize(unsigned long z)
constexpr T normalize(unsigned long z)
{
if(z == 0)
return 0;
......@@ -16,7 +16,7 @@ T normalize(unsigned long z)
}
template <class T, MIGRAPH_REQUIRES(std::is_signed<T>{} and not std::is_floating_point<T>{})>
T normalize(unsigned long z)
constexpr T normalize(unsigned long z)
{
const auto max = std::numeric_limits<T>::max();
const auto half_max = max / 2;
......@@ -24,7 +24,7 @@ T normalize(unsigned long z)
}
template <class T, MIGRAPH_REQUIRES(not std::is_signed<T>{} and std::is_integral<T>{})>
T normalize(unsigned long z)
constexpr T normalize(unsigned long z)
{
const auto max = std::numeric_limits<T>::max();
return z % max;
......@@ -33,9 +33,10 @@ T normalize(unsigned long z)
template <class T>
struct xorshf96_generator
{
unsigned long seed = 0;
unsigned long x = 123456789;
unsigned long y = 362436069;
unsigned long z = 521288629;
unsigned long z = 521288629 ^ seed;
constexpr T operator()() noexcept
{
......@@ -53,16 +54,18 @@ struct xorshf96_generator
};
template <class T>
std::vector<T> generate_tensor_data(const migraph::shape& s, std::mt19937::result_type)
std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed = 0)
{
std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), xorshf96_generator<T>{});
std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
return result;
}
argument generate_argument(shape s, std::mt19937::result_type seed = 0);
argument generate_argument(shape s, unsigned long seed = 0);
literal generate_literal(shape s, std::mt19937::result_type seed = 0);
literal generate_literal(shape s, unsigned long seed = 0);
literal abs(literal l);
} // namespace migraph
......
......@@ -94,6 +94,19 @@ struct literal : raw_data<literal>
}
};
template<class F>
literal transform(literal l, F f)
{
literal result;
l.visit([&](auto x) {
using type = std::remove_cv_t<typename decltype(x)::value_type>;
std::vector<type> output(x.size(), 0.0);
std::transform(x.begin(), x.end(), output.begin(), f);
result = literal{l.get_shape(), output};
});
return result;
}
} // namespace migraph
#endif
......@@ -332,10 +332,10 @@ struct test_batchnorm_inference_2
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}};
auto x = p.add_parameter("x", s);
auto mean = p.add_parameter("mean", vars);
auto variance = p.add_parameter("variance", vars);
auto scale = p.add_parameter("scale", vars);
auto bias = p.add_parameter("bias", vars);
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
return p;
}
......@@ -355,10 +355,10 @@ struct test_batchnorm_inference
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}};
auto x = p.add_parameter("x", s);
auto mean = p.add_parameter("mean", vars);
auto variance = p.add_parameter("variance", vars);
auto scale = p.add_parameter("scale", vars);
auto bias = p.add_parameter("bias", vars);
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment