Unverified Commit 7f8731a2 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Add get_random_beta to dlib::rand (#2499)



* Add get_random_beta

* simplify beta distribution generation

* add tests

* fix condition and use full names for shape parameters

* remove unneeded include

* match test types to prevent Windows from failing tests

* Revert "remove unneeded include"

This reverts commit 35f55a96e91f4c4340a0bda344c494f1ce68d521.

* Revert "match test types to prevent Windows from failing tests"

This reverts commit a990307066e99cb56e15ae5de63e67e29cc72911.

* Revert "Revert "remove unneeded include""

This reverts commit 59be002c9e8b8a6ac395d87e6bf9d4b57af19d28.

* fix outdated message in DLIB_CASSERT

* relax mpc condition again

* Revert "relax mpc condition again"

This reverts commit 1d208c5dcf1ddc15e09e68e1e2e11145506729ae.

* Loop while both alpha and beta are zero

* Update dlib/rand/rand_kernel_abstract.h
Co-authored-by: default avatarDavis E. King <davis685@gmail.com>
parent 48f11679
......@@ -290,6 +290,23 @@ namespace dlib
return gamma + lambda*std::pow(-std::log(u), 1.0 / k);
}
double get_random_beta (
double alpha,
double beta
)
{
DLIB_CASSERT(alpha > 0, "alpha must be greater than zero")
DLIB_CASSERT(beta > 0, "beta must be greater than zero");
auto u = std::pow(get_random_double(), 1 / alpha);
auto v = std::pow(get_random_double(), 1 / beta);
while ((u + v) > 1 || (u == 0 && v == 0))
{
u = std::pow(get_random_double(), 1 / alpha);
v = std::pow(get_random_double(), 1 / beta);
}
return u / (u + v);
}
void swap (
rand& item
)
......
......@@ -210,6 +210,19 @@ namespace dlib
threshold parameter gamma.
!*/
double get_random_beta (
double alpha,
double beta,
)
/*!
requires
- alpha > 0
- beta > 0
ensures
- returns a random number sampled from a Beta distribution
with shape parameters alpha and beta.
!*/
void swap (
rand& item
);
......
......@@ -446,6 +446,26 @@ namespace
DLIB_TEST(std::abs(stats.ex_kurtosis() - 6.0) < 0.1);
}
void test_beta_distribution()
{
print_spinner();
dlib::rand rnd(0);
const size_t N = 1024*1024*5;
const double a = 0.2;
const double b = 1.5;
running_stats<double> stats;
for (size_t i = 0; i < N; i++)
stats.add(rnd.get_random_beta(a, b));
const double expected_mean = a / (a + b);
const double expected_var = a * b / (std::pow(a + b, 2) * (a + b + 1));
DLIB_TEST(std::abs(stats.mean() - expected_mean) < 1e-5);
DLIB_TEST(std::abs(stats.variance() - expected_var) < 1e-5);
}
void outputs_are_not_changed()
{
// dlib::rand has been around a really long time and it is a near certainty that there is
......@@ -527,6 +547,7 @@ namespace
test_get_integer();
test_weibull_distribution();
test_exponential_distribution();
test_beta_distribution();
}
} a;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment