Unverified Commit 95271cfe authored by Robert Underwood, committed by GitHub
Browse files

Early termination for find_{min,max}_global (#2281)



* Early termination for find_{min,max}_global

This patch adds a callback to allow the user to request cancellation of a
search using find_{min,max}_global.  This enables users to cancel
searches when they are no longer relevant, or when the user has some
special knowledge of the solution that they can use to stop the search
early.

closes  #2250

* Moved default stopping condition into find_max_global.h since that's the code it relates to and did some minor cleanup.
Co-authored-by: Davis King <davis@dlib.net>
parent 600e0365
......@@ -10,6 +10,7 @@
#include <chrono>
#include <memory>
#include <thread>
#include <functional>
#include "../threads/thread_pool_extension.h"
#include "../statistics/statistics.h"
#include "../enable_if.h"
......@@ -114,6 +115,8 @@ template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> de
// ----------------------------------------------------------------------------------------
// Default value of the max_runtime parameter: a time budget so large that, in practice,
// the search is never stopped because of elapsed time.
const auto FOREVER = std::chrono::hours(24*365*290); // 290 years
// Callback type used to request early termination of a search.  It is invoked with the
// value produced by a function evaluation and returns true if the search should stop.
using stop_condition = std::function<bool(double)>;
// Default stop_condition: always returns false, so it never requests early termination.
const stop_condition never_stop_early = [](double) { return false; };
// ----------------------------------------------------------------------------------------
......@@ -130,7 +133,8 @@ template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> de
const max_function_calls num,
const std::chrono::nanoseconds max_runtime = FOREVER,
double solver_epsilon = 0,
std::vector<std::vector<function_evaluation>> initial_function_evals = {}
std::vector<std::vector<function_evaluation>> initial_function_evals = {},
stop_condition should_stop = never_stop_early
)
{
// Decide which parameters should be searched on a log scale. Basically, it's
......@@ -176,17 +180,19 @@ template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> de
using namespace std::chrono;
const auto time_to_stop = steady_clock::now() + max_runtime;
//atomic<bool> doesn't support .fetch_or, use std::atomic<int> instead
std::atomic<int> this_should_stop{false};
double max_solver_overhead_time = 0;
// Now run the main solver loop.
for (size_t i = 0; i < num.max_calls && steady_clock::now() < time_to_stop; ++i)
for (size_t i = 0; i < num.max_calls && steady_clock::now() < time_to_stop && !this_should_stop.load(); ++i)
{
const auto get_next_x_start_time = steady_clock::now();
auto next = std::make_shared<function_evaluation_request>(opt.get_next_x());
const auto get_next_x_runtime = steady_clock::now() - get_next_x_start_time;
auto execute_call = [&functions,&ymult,&log_scale,&eval_time_mutex,&objective_funct_eval_time,next]() {
auto execute_call = [&functions,&ymult,&log_scale,&eval_time_mutex,&objective_funct_eval_time,next,&should_stop,&this_should_stop]() {
matrix<double,0,1> x = next->x();
// Undo any log-scaling that was applied to the variables before we pass them
// to the functions being optimized.
......@@ -198,6 +204,7 @@ template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> de
const auto funct_eval_start = steady_clock::now();
double y = ymult*call_function_and_expand_args(functions[next->function_idx()], x);
const double funct_eval_runtime = duration_cast<nanoseconds>(steady_clock::now() - funct_eval_start).count();
this_should_stop.fetch_or(should_stop(y*ymult));
next->set(y);
std::lock_guard<std::mutex> lock(eval_time_mutex);
......@@ -364,7 +371,8 @@ template <typename T> static auto go(T&& f, const matrix<double, 0, 1>& a) -> de
// find_max_global() instances below and turn them into the argument types expected by
// find_max_global() above.
template <typename T>
const T& normalize(const T& item) {
const T& normalize(const T& item)
{
return item;
}
......
......@@ -10,6 +10,7 @@
#include "../threads/thread_pool_extension_abstract.h"
#include <utility>
#include <chrono>
#include <functional>
namespace dlib
{
......@@ -69,6 +70,19 @@ namespace dlib
// Default value for the max_runtime arguments of find_max_global()/find_min_global():
// 24 hours * 365 days * 290 years, i.e. a time limit that is never reached in practice.
const auto FOREVER = std::chrono::hours(24*365*290); // 290 years, basically forever
/*!
WHAT THIS OBJECT REPRESENTS
A callback that find_max_global() and find_min_global() invoke with the value
returned by an evaluation of the function being optimized.  It returns true when
the search should stop early and false when the search should continue.
It is useful when the user wants to terminate the search based on special knowledge
of the function, on the user's preferences regarding what is a "good enough" solution, or
on the results of some external process which may already have found a solution, so that
this search is no longer required.
!*/
using stop_condition = std::function<bool(double)>;
// The default stop_condition: it ignores the function value and always returns false, so
// this callback never ends the search early.
const stop_condition never_stop_early = [](double) { return false; };
// ----------------------------------------------------------------------------------------
template <
......@@ -81,7 +95,8 @@ namespace dlib
const max_function_calls num,
const std::chrono::nanoseconds max_runtime = FOREVER,
double solver_epsilon = 0,
const std::vector<std::vector<function_evaluation>>& initial_function_evals = {}
const std::vector<std::vector<function_evaluation>>& initial_function_evals = {},
stop_condition should_stop = never_stop_early
);
/*!
requires
......@@ -134,6 +149,7 @@ namespace dlib
- find_max_global() runs until one of the following is true:
- The total number of calls to the provided functions is == num.max_calls
- More than max_runtime time has elapsed since the start of this function.
- should_stop(f(x)) returns true
- Any variables that satisfy the following conditions are optimized on a log-scale:
- The lower bound on the variable is > 0
- The ratio of the upper bound to lower bound is >= 1000
......@@ -166,7 +182,8 @@ namespace dlib
const max_function_calls num,
const std::chrono::nanoseconds max_runtime = FOREVER,
double solver_epsilon = 0,
const std::vector<std::vector<function_evaluation>>& initial_function_evals = {}
const std::vector<std::vector<function_evaluation>>& initial_function_evals = {},
stop_condition should_stop = never_stop_early
);
/*!
this function is identical to the find_max_global() defined immediately above,
......@@ -182,7 +199,8 @@ namespace dlib
const max_function_calls num,
const std::chrono::nanoseconds max_runtime = FOREVER,
double solver_epsilon = 0,
const std::vector<std::vector<function_evaluation>>& initial_function_evals = {}
const std::vector<std::vector<function_evaluation>>& initial_function_evals = {},
stop_condition should_stop = never_stop_early
);
/*!
This function is identical to the find_max_global() defined immediately above,
......@@ -199,7 +217,8 @@ namespace dlib
const max_function_calls num,
const std::chrono::nanoseconds max_runtime = FOREVER,
double solver_epsilon = 0,
const std::vector<std::vector<function_evaluation>>& initial_function_evals = {}
const std::vector<std::vector<function_evaluation>>& initial_function_evals = {},
stop_condition should_stop = never_stop_early
);
/*!
This function is identical to the find_max_global() defined immediately above,
......@@ -222,7 +241,8 @@ namespace dlib
const max_function_calls num,
const std::chrono::nanoseconds max_runtime = FOREVER,
double solver_epsilon = 0,
const std::vector<function_evaluation>& initial_function_evals = {}
const std::vector<function_evaluation>& initial_function_evals = {},
stop_condition should_stop = never_stop_early
);
/*!
requires
......@@ -270,6 +290,7 @@ namespace dlib
- find_max_global() runs until one of the following is true:
- The total number of calls to f() is == num.max_calls
- More than max_runtime time has elapsed since the start of this function.
- should_stop(f(x)) returns true
- Any variables that satisfy the following conditions are optimized on a log-scale:
- The lower bound on the variable is > 0
- The ratio of the upper bound to lower bound is >= 1000
......@@ -305,7 +326,8 @@ namespace dlib
const max_function_calls num,
const std::chrono::nanoseconds max_runtime = FOREVER,
double solver_epsilon = 0,
const std::vector<function_evaluation>& initial_function_evals = {}
const std::vector<function_evaluation>& initial_function_evals = {},
stop_condition should_stop = never_stop_early
);
/*!
This function is identical to the find_max_global() defined immediately above,
......@@ -323,7 +345,8 @@ namespace dlib
const max_function_calls num,
const std::chrono::nanoseconds max_runtime = FOREVER,
double solver_epsilon = 0,
const std::vector<function_evaluation>& initial_function_evals = {}
const std::vector<function_evaluation>& initial_function_evals = {},
stop_condition should_stop = never_stop_early
);
/*!
This function is identical to the find_max_global() defined immediately above,
......@@ -342,7 +365,8 @@ namespace dlib
const max_function_calls num,
const std::chrono::nanoseconds max_runtime = FOREVER,
double solver_epsilon = 0,
const std::vector<function_evaluation>& initial_function_evals = {}
const std::vector<function_evaluation>& initial_function_evals = {},
stop_condition should_stop = never_stop_early
);
/*!
This function is identical to the find_min_global() defined immediately above,
......
......@@ -220,10 +220,11 @@ namespace
DLIB_TEST(std::abs(result.x - 2) < 1e-9);
print_spinner();
result = find_max_global([](double x){ return -std::pow(x-2,2.0); }, -10, 1, max_function_calls(10));
dlog << LINFO << "(x-2)^2, bound at 1: " << trans(result.x);
DLIB_TEST(result.x.size()==1);
DLIB_TEST(std::abs(result.x - 1) < 1e-9);
unsigned int normal_evals=0, early_evals=0;
auto normal_result = find_max_global([&normal_evals](double x){ normal_evals++; return -std::pow(x-2,2.0); }, -10, 1, max_function_calls(10));
dlog << LINFO << "(x-2)^2, bound at 1: " << trans(normal_result.x);
DLIB_TEST(normal_result.x.size()==1);
DLIB_TEST(std::abs(normal_result.x - 1) < 1e-9);
print_spinner();
result = find_max_global([](double x){ return -std::pow(x-2,2.0); }, -10, 1, std::chrono::seconds(2));
......@@ -232,6 +233,15 @@ namespace
DLIB_TEST(std::abs(result.x - 1) < 1e-9);
print_spinner();
constexpr auto close_enough = -16.0;
auto early_result = find_max_global([&early_evals](double x){ early_evals++; return -std::pow(x-2,2.0); }, -10, 1, max_function_calls(10), 0.0, std::vector<function_evaluation>{}, [](double y){ return (y >= close_enough);});
dlog << LINFO << "(x-2)^2, bound at 1: " << trans(early_result.x);
DLIB_TEST(early_result.x.size()==1);
DLIB_TEST(std::abs(early_result.y) <= std::abs(close_enough));
DLIB_TEST(std::abs(early_result.x - 1) <= 4);
DLIB_TEST(normal_evals >= early_evals);
print_spinner();
result = find_max_global([](double a, double b){ return -complex_holder_table(a,b);},
{-10, -10}, {10, 10}, max_function_calls(400), 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment