"...git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "024009675ec1334791ec733eff85172fcd666ed4"
Commit fe70bd12 authored by Davis King

Fixed spelling error in method name. Also optimized and cleaned up the automatic step size reduction code a little.
parent 9485da18
@@ -328,12 +328,14 @@ namespace dlib
            rs.clear();
        }

-       void set_setep_size (
+       void set_step_size (
            double ss
        )
        {
            DLIB_CASSERT(ss > 0,"");
            wait_for_thread_to_pause();
+           if (step_size != ss)
+               previous_loss_values.clear();
            step_size = ss;
        }
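The guard added here clears the loss history only when the step size actually changes, since losses recorded under the old step size should not drive adjustment decisions at the new one. A minimal standalone sketch of the pattern (a hypothetical struct, not dlib's actual class):

```cpp
#include <deque>

// Minimal sketch, not dlib's class: clear the loss history only when the
// step size really changes, so a redundant set_step_size(current_value)
// call doesn't discard history that the adjustment logic below relies on.
struct trainer_state
{
    double step_size = 1;
    std::deque<double> previous_loss_values;

    void set_step_size(double ss)
    {
        if (step_size != ss)
            previous_loss_values.clear();  // old-step-size losses no longer apply
        step_size = ss;
    }
};
```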
@@ -391,24 +393,33 @@ namespace dlib
            resizable_tensor t;
        };

-       template <typename T>
-       void run_update(job_t& next_job, const T&)
+       void record_loss(double loss)
        {
-           double loss = net.update(next_job.t, next_job.labels.begin(), make_sstack(solvers),step_size);
+           // Say that we will check if the gradient is bad 200 times during each
+           // iter_between_step_size_adjust interval of network updates. This kind of
+           // budgeting causes our gradient checking to use a fixed amount of
+           // computational resources, regardless of the size of
+           // iter_between_step_size_adjust.
+           gradient_check_budget += 200;
+
            rs.add(loss);
            previous_loss_values.push_back(loss);
            if (previous_loss_values.size() > iter_between_step_size_adjust)
                previous_loss_values.pop_front();
        }

+       template <typename T>
+       void run_update(job_t& next_job, const T&)
+       {
+           double loss = net.update(next_job.t, next_job.labels.begin(), make_sstack(solvers),step_size);
+           record_loss(loss);
+       }
+
        void run_update(job_t& next_job, const no_label_type&)
        {
            no_label_type pick_wich_run_update;
            double loss = net.update(next_job.t, make_sstack(solvers), step_size);
-           rs.add(loss);
-           previous_loss_values.push_back(loss);
-           if (previous_loss_values.size() > iter_between_step_size_adjust)
-               previous_loss_values.pop_front();
+           record_loss(loss);
        }

        void thread() try
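Since record_loss() deposits 200 credits per update and the check in the next hunk spends previous_loss_values.size() credits at a time, the expensive gradient test runs a roughly fixed number of times per adjustment interval regardless of history length. A standalone sketch of the accounting (hypothetical names, not dlib code):

```cpp
#include <cstddef>
#include <iostream>

// Standalone sketch of the budgeting scheme (hypothetical names, not dlib
// code). Every update deposits 200 credits; the expensive check, whose cost
// grows with the loss-history length, only runs once the accumulated credits
// exceed that length. Total checking work therefore stays a fixed fraction
// of total update work no matter how large the history gets.
int main()
{
    const std::size_t history_size = 2000;  // like previous_loss_values.size()
    unsigned long gradient_check_budget = 0;
    unsigned long checks = 0;

    for (int update = 0; update < 2000; ++update)
    {
        gradient_check_budget += 200;
        if (gradient_check_budget > history_size)
        {
            gradient_check_budget = 0;
            ++checks;  // the O(history_size) gradient test would run here
        }
    }
    std::cout << checks << "\n";  // prints 181: close to 200 checks per interval
}
```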
@@ -425,9 +436,14 @@ namespace dlib
                run_update(next_job, pick_wich_run_update);

                // If we have been running for a while then check if the loss is still
-               // dropping. If it isn't then we will reduce the step size.
-               if (previous_loss_values.size() >= iter_between_step_size_adjust)
+               // dropping. If it isn't then we will reduce the step size. Note that we
+               // have a "budget" that prevents us from calling
+               // probability_gradient_greater_than() every iteration. We do this because
+               // it can be expensive to compute when previous_loss_values is large.
+               if (previous_loss_values.size() >= iter_between_step_size_adjust &&
+                   gradient_check_budget > previous_loss_values.size())
                {
+                   gradient_check_budget = 0;
                    if (probability_gradient_greater_than(previous_loss_values, 0) > 0.49)
                    {
                        step_size = step_size_shrink*step_size;
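The shrink decision asks whether the slope of the recent loss curve is statistically distinguishable from zero. probability_gradient_greater_than() itself is not shown in this diff; the sketch below is one way such a test can be computed (an ordinary least squares fit with a normal approximation, an assumption rather than dlib's exact implementation). A result near 0.5 means the loss series is flat, which is exactly the condition the loop above treats as "stop and shrink".

```cpp
#include <cmath>
#include <cstddef>
#include <deque>

// Hypothetical OLS-based version of a test like
// probability_gradient_greater_than() (not dlib's exact implementation).
// Fit a line to the loss series, then use a normal approximation to
// estimate the probability that the true slope exceeds `thresh`.
// Requires at least 3 points.
double probability_slope_greater_than(const std::deque<double>& y, double thresh)
{
    const double n = static_cast<double>(y.size());
    double xm = 0, ym = 0;
    for (std::size_t i = 0; i < y.size(); ++i) { xm += i; ym += y[i]; }
    xm /= n;  ym /= n;

    double sxx = 0, sxy = 0;
    for (std::size_t i = 0; i < y.size(); ++i)
    {
        sxx += (i - xm) * (i - xm);
        sxy += (i - xm) * (y[i] - ym);
    }
    const double slope = sxy / sxx;

    double rss = 0;  // residual sum of squares around the fitted line
    for (std::size_t i = 0; i < y.size(); ++i)
    {
        const double pred = ym + slope * (i - xm);
        rss += (y[i] - pred) * (y[i] - pred);
    }
    const double se = std::sqrt(rss / (n - 2)) / std::sqrt(sxx);

    // P(slope > thresh) for slope ~ N(fitted slope, se^2).
    return 0.5 * std::erfc((thresh - slope) / (se * std::sqrt(2.0)));
}
```

Under this reading, the trainer shrinks step_size whenever the probability that the loss trend is non-decreasing exceeds 0.49, i.e. whenever there is no longer clear evidence that the loss is still dropping.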
@@ -458,12 +474,13 @@ namespace dlib
            verbose = false;
            cuda_device_id = dlib::cuda::get_device();
            step_size = 1;
-           min_step_size = 1e-4;
+           min_step_size = 1e-3;
            iter_between_step_size_adjust = 2000;
            step_size_shrink = 0.1;
            epoch_iteration = 0;
            epoch_pos = 0;
            train_one_step_calls = 0;
+           gradient_check_budget = 0;

            start();
        }
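A quick way to see these defaults from user code, using the getters documented in the abstract file further down; the tiny net_type here is only a placeholder and the exact API of this dlib revision is an assumption:

```cpp
#include <cassert>
#include <dlib/dnn.h>

using namespace dlib;

// Placeholder network; any valid net works for checking trainer defaults.
using net_type = loss_multiclass_log<fc<10, input<matrix<float>>>>;

int main()
{
    net_type net;
    dnn_trainer<net_type> trainer(net);
    assert(trainer.get_step_size() == 1);
    assert(trainer.get_min_step_size() == 1e-3);  // raised from 1e-4 by this commit
    assert(trainer.get_iterations_between_step_size_adjust() == 2000);
    assert(trainer.get_step_size_shrink() == 0.1);
}
```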
@@ -575,7 +592,7 @@ namespace dlib
        std::vector<solver_type> solvers;
        std::atomic<double> step_size;
        double min_step_size;
-       std::atomic<long> iter_between_step_size_adjust;
+       std::atomic<unsigned long> iter_between_step_size_adjust;
        std::atomic<double> step_size_shrink;
        std::chrono::time_point<std::chrono::system_clock> last_sync_time;
        std::string sync_filename;
@@ -584,6 +601,7 @@ namespace dlib
        unsigned long epoch_pos;
        std::chrono::time_point<std::chrono::system_clock> last_time;
        unsigned long long train_one_step_calls;
+       unsigned long gradient_check_budget;

        // The job object is not logically part of the state of this object. It is here
        // only to avoid reallocating it over and over.
...
@@ -60,7 +60,7 @@ namespace dlib
            - #get_max_num_epochs() == 10000
            - #get_mini_batch_size() == 128
            - #get_step_size() == 1
-           - #get_min_step_size() == 1e-4
+           - #get_min_step_size() == 1e-3
            - #get_iterations_between_step_size_adjust() == 2000
            - #get_step_size_shrink() == 0.1
    !*/
@@ -149,7 +149,7 @@ namespace dlib
            - #get_max_num_epochs() == num
    !*/

-   void set_setep_size (
+   void set_step_size (
        double ss
    );
    /*!
...