// Copyright (C) 2017 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_FiND_GLOBAL_MAXIMUM_hH_ #define DLIB_FiND_GLOBAL_MAXIMUM_hH_ #include "global_function_search.h" // TODO, move ct_make_integer_range into some other file so we don't have to include the // dnn header. That thing is huge. #include #include namespace dlib { namespace gopt_impl { // ---------------------------------------------------------------------------------------- class disable_decay_to_scalar { const matrix& a; public: disable_decay_to_scalar(const matrix& a) : a(a){} operator const matrix&() const { return a;} }; template auto _cwv ( T&& f, const matrix& a, impl::ct_integers_list ) -> decltype(f(a(indices-1)...)) { DLIB_CASSERT(a.size() == sizeof...(indices), "You invoked dlib::call_with_vect(f,a) but the number of arguments expected by f() doesn't match the size of 'a'. " << "Expected " << sizeof...(indices) << " arguments but got " << a.size() << "." ); return f(a(indices-1)...); } template struct call_with_vect { template static auto go(T&& f, const matrix& a) -> decltype(_cwv(std::forward(f),a,typename impl::ct_make_integer_range::type())) { return _cwv(std::forward(f),a,typename impl::ct_make_integer_range::type()); } template static auto go(T&& f, const matrix& a) -> decltype(call_with_vect::template go(std::forward(f),a)) { return call_with_vect::go(std::forward(f),a); } }; template <> struct call_with_vect<0> { template static auto go(T&& f, const matrix& a) -> decltype(f(disable_decay_to_scalar(a))) { return f(disable_decay_to_scalar(a)); } }; } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template auto call_with_vect( T&& f, const matrix& a ) -> decltype(gopt_impl::call_with_vect<40>::go(f,a)) { // unpack up to 40 parameters when calling f() return gopt_impl::call_with_vect<40>::go(std::forward(f),a); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- struct max_function_calls { max_function_calls() = default; explicit max_function_calls(size_t max_calls) : max_calls(max_calls) {} size_t max_calls = std::numeric_limits::max(); }; // ---------------------------------------------------------------------------------------- template < typename funct > std::pair find_global_maximum ( std::vector& functions, const std::vector& specs, const max_function_calls num, const std::chrono::nanoseconds max_runtime, double solver_epsilon = 1e-11 ) { global_function_search opt(specs); opt.set_solver_epsilon(solver_epsilon); const auto time_to_stop = std::chrono::steady_clock::now() + max_runtime; for (size_t i = 0; i < num.max_calls && std::chrono::steady_clock::now() < time_to_stop; ++i) { auto next = opt.get_next_x(); double y = call_with_vect(functions[next.function_idx()], next.x()); next.set(y); // TODO, remove this funky test code matrix x; size_t function_idx; opt.get_best_function_eval(x,y,function_idx); using namespace std; cout << "\ni: "<< i << endl; cout << "best eval x: "<< trans(x); cout << "best eval y: "<< y << endl; cout << "best eval function index: "<< function_idx << endl; if (std::abs(y - 21.9210397) < 0.0001) { cout << "DONE!" << endl; //cin.get(); break; } } matrix x; double y; size_t function_idx; opt.get_best_function_eval(x,y,function_idx); return std::make_pair(function_idx, function_evaluation(x,std::move(y))); } // ---------------------------------------------------------------------------------------- template < typename funct > function_evaluation find_global_maximum ( funct f, const matrix& lower, const matrix& upper, const max_function_calls num, double solver_epsilon = 1e-11 ) { std::vector functions(1,f); std::vector specs(1, function_spec(lower, upper)); auto forever = std::chrono::hours(24*356*290); return find_global_maximum(functions, specs, num, forever, solver_epsilon).second; } template < typename funct > function_evaluation find_global_maximum ( funct f, const double lower, const double upper, const max_function_calls num, double solver_epsilon = 1e-11 ) { return find_global_maximum(f, matrix({lower}), matrix({upper}), num, solver_epsilon); } template < typename funct > function_evaluation find_global_maximum ( funct f, const matrix& lower, const matrix& upper, const std::vector& is_integer_variable, const max_function_calls num, double solver_epsilon = 1e-11 ) { std::vector functions(1, std::move(f)); std::vector specs(1, function_spec(lower, upper, is_integer_variable)); auto forever = std::chrono::hours(24*356*290); return find_global_maximum(functions, specs, num, forever, solver_epsilon).second; } // ---------------------------------------------------------------------------------------- template < typename funct > function_evaluation find_global_maximum ( funct f, const matrix& lower, const matrix& upper, const std::chrono::nanoseconds max_runtime, double solver_epsilon = 1e-11 ) { std::vector functions(1,f); std::vector specs(1, function_spec(lower, upper)); return find_global_maximum(functions, specs, max_function_calls(), max_runtime, solver_epsilon).second; } template < typename funct > function_evaluation find_global_maximum ( funct f, const double lower, const double upper, const std::chrono::nanoseconds max_runtime, double solver_epsilon = 1e-11 ) { return find_global_maximum(f, matrix({lower}), matrix({upper}), max_runtime, solver_epsilon); } template < typename funct > function_evaluation find_global_maximum ( funct f, const matrix& lower, const matrix& upper, const std::vector& is_integer_variable, const std::chrono::nanoseconds max_runtime, double solver_epsilon = 1e-11 ) { std::vector functions(1, std::move(f)); std::vector specs(1, function_spec(lower, upper, is_integer_variable)); return find_global_maximum(functions, specs, max_function_calls(), max_runtime, solver_epsilon).second; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_FiND_GLOBAL_MAXIMUM_hH_