Unverified Commit 7f8731a2 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Add get_random_beta to dlib::rand (#2499)



* Add get_random_beta

* simplify beta distribution generation

* add tests

* fix condition and use full names for shape parameters

* remove unneeded include

* match test types to prevent Windows from failing tests

* Revert "remove unneeded include"

This reverts commit 35f55a96e91f4c4340a0bda344c494f1ce68d521.

* Revert "match test types to prevent Windows from failing tests"

This reverts commit a990307066e99cb56e15ae5de63e67e29cc72911.

* Revert "Revert "remove unneeded include""

This reverts commit 59be002c9e8b8a6ac395d87e6bf9d4b57af19d28.

* fix outdated message in DLIB_CASSERT

* relax mpc condition again

* Revert "relax mpc condition again"

This reverts commit 1d208c5dcf1ddc15e09e68e1e2e11145506729ae.

* Loop while both alpha and beta are zero

* Update dlib/rand/rand_kernel_abstract.h
Co-authored-by: default avatarDavis E. King <davis685@gmail.com>
parent 48f11679
...@@ -289,7 +289,24 @@ namespace dlib ...@@ -289,7 +289,24 @@ namespace dlib
u = get_random_double(); u = get_random_double();
return gamma + lambda*std::pow(-std::log(u), 1.0 / k); return gamma + lambda*std::pow(-std::log(u), 1.0 / k);
} }
double get_random_beta (
double alpha,
double beta
)
{
DLIB_CASSERT(alpha > 0, "alpha must be greater than zero")
DLIB_CASSERT(beta > 0, "beta must be greater than zero");
auto u = std::pow(get_random_double(), 1 / alpha);
auto v = std::pow(get_random_double(), 1 / beta);
while ((u + v) > 1 || (u == 0 && v == 0))
{
u = std::pow(get_random_double(), 1 / alpha);
v = std::pow(get_random_double(), 1 / beta);
}
return u / (u + v);
}
void swap ( void swap (
rand& item rand& item
) )
......
...@@ -209,7 +209,20 @@ namespace dlib ...@@ -209,7 +209,20 @@ namespace dlib
with shape parameter k, scale parameter lambda and with shape parameter k, scale parameter lambda and
threshold parameter gamma. threshold parameter gamma.
!*/ !*/
double get_random_beta (
double alpha,
double beta,
)
/*!
requires
- alpha > 0
- beta > 0
ensures
- returns a random number sampled from a Beta distribution
with shape parameters alpha and beta.
!*/
void swap ( void swap (
rand& item rand& item
); );
......
...@@ -412,7 +412,7 @@ namespace ...@@ -412,7 +412,7 @@ namespace
{ {
print_spinner(); print_spinner();
dlib::rand rnd(0); dlib::rand rnd(0);
const size_t N = 1024*1024*4; const size_t N = 1024*1024*4;
const double tol = 0.01; const double tol = 0.01;
double k=1.0, lambda=2.0, g=6.0; double k=1.0, lambda=2.0, g=6.0;
...@@ -426,26 +426,46 @@ namespace ...@@ -426,26 +426,46 @@ namespace
DLIB_TEST(std::abs(stats.mean() - expected_mean) < tol); DLIB_TEST(std::abs(stats.mean() - expected_mean) < tol);
DLIB_TEST(std::abs(stats.variance() - expected_var) < tol); DLIB_TEST(std::abs(stats.variance() - expected_var) < tol);
} }
void test_exponential_distribution() void test_exponential_distribution()
{ {
print_spinner(); print_spinner();
dlib::rand rnd(0); dlib::rand rnd(0);
const size_t N = 1024*1024*5; const size_t N = 1024*1024*5;
const double lambda = 1.5; const double lambda = 1.5;
print_spinner(); print_spinner();
dlib::running_stats<double> stats; dlib::running_stats<double> stats;
for (size_t i = 0; i < N; i++) for (size_t i = 0; i < N; i++)
stats.add(rnd.get_random_exponential(lambda)); stats.add(rnd.get_random_exponential(lambda));
DLIB_TEST(std::abs(stats.mean() - 1.0 / lambda) < 0.001); DLIB_TEST(std::abs(stats.mean() - 1.0 / lambda) < 0.001);
DLIB_TEST(std::abs(stats.variance() - 1.0 / (lambda*lambda)) < 0.001); DLIB_TEST(std::abs(stats.variance() - 1.0 / (lambda*lambda)) < 0.001);
DLIB_TEST(std::abs(stats.skewness() - 2.0) < 0.01); DLIB_TEST(std::abs(stats.skewness() - 2.0) < 0.01);
DLIB_TEST(std::abs(stats.ex_kurtosis() - 6.0) < 0.1); DLIB_TEST(std::abs(stats.ex_kurtosis() - 6.0) < 0.1);
} }
void test_beta_distribution()
{
print_spinner();
dlib::rand rnd(0);
const size_t N = 1024*1024*5;
const double a = 0.2;
const double b = 1.5;
running_stats<double> stats;
for (size_t i = 0; i < N; i++)
stats.add(rnd.get_random_beta(a, b));
const double expected_mean = a / (a + b);
const double expected_var = a * b / (std::pow(a + b, 2) * (a + b + 1));
DLIB_TEST(std::abs(stats.mean() - expected_mean) < 1e-5);
DLIB_TEST(std::abs(stats.variance() - expected_var) < 1e-5);
}
void outputs_are_not_changed() void outputs_are_not_changed()
{ {
// dlib::rand has been around a really long time and it is a near certainty that there is // dlib::rand has been around a really long time and it is a near certainty that there is
...@@ -501,7 +521,7 @@ namespace ...@@ -501,7 +521,7 @@ namespace
} }
} }
class rand_tester : public tester class rand_tester : public tester
{ {
public: public:
...@@ -527,6 +547,7 @@ namespace ...@@ -527,6 +547,7 @@ namespace
test_get_integer(); test_get_integer();
test_weibull_distribution(); test_weibull_distribution();
test_exponential_distribution(); test_exponential_distribution();
test_beta_distribution();
} }
} a; } a;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment