"examples/vscode:/vscode.git/clone" did not exist on "26b631805fd00a41bf352f3a01cf47bc4ec9db45"
Commit aaeb52ba authored by Davis King's avatar Davis King
Browse files

Updated the interface to allow the user to set different loss values for

false alarming vs getting a correct detection.
parent 1de36ea2
...@@ -27,11 +27,15 @@ namespace dlib ...@@ -27,11 +27,15 @@ namespace dlib
const feature_extractor& fe_ const feature_extractor& fe_
) : trainer(impl_ss::feature_extractor<feature_extractor>(fe_)) ) : trainer(impl_ss::feature_extractor<feature_extractor>(fe_))
{ {
loss_per_missed_segment = 1;
loss_per_false_alarm = 1;
} }
structural_sequence_segmentation_trainer ( structural_sequence_segmentation_trainer (
) )
{ {
loss_per_missed_segment = 1;
loss_per_false_alarm = 1;
} }
const feature_extractor& get_feature_extractor ( const feature_extractor& get_feature_extractor (
...@@ -127,6 +131,63 @@ namespace dlib ...@@ -127,6 +131,63 @@ namespace dlib
return trainer.get_c(); return trainer.get_c();
} }
void set_loss_per_missed_segment (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t void structural_sequence_segmentation_trainer::set_loss_per_missed_segment(loss)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_missed_segment = loss;
if (feature_extractor::use_BIO_model)
{
trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment);
trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment);
}
else
{
trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment);
trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment);
trainer.set_loss(impl_ss::LAST, loss_per_missed_segment);
trainer.set_loss(impl_ss::UNIT, loss_per_missed_segment);
}
}
double get_loss_per_missed_segment (
) const
{
return loss_per_missed_segment;
}
void set_loss_per_false_alarm (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t void structural_sequence_segmentation_trainer::set_loss_per_false_alarm(loss)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_false_alarm = loss;
trainer.set_loss(impl_ss::OUTSIDE, loss_per_false_alarm);
}
double get_loss_per_false_alarm (
) const
{
return loss_per_false_alarm;
}
const sequence_segmenter<feature_extractor> train( const sequence_segmenter<feature_extractor> train(
const std::vector<sample_sequence_type>& x, const std::vector<sample_sequence_type>& x,
const std::vector<segmented_sequence_type>& y const std::vector<segmented_sequence_type>& y
...@@ -198,6 +259,8 @@ namespace dlib ...@@ -198,6 +259,8 @@ namespace dlib
private: private:
structural_sequence_labeling_trainer<impl_ss::feature_extractor<feature_extractor> > trainer; structural_sequence_labeling_trainer<impl_ss::feature_extractor<feature_extractor> > trainer;
double loss_per_missed_segment;
double loss_per_false_alarm;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -47,6 +47,8 @@ namespace dlib ...@@ -47,6 +47,8 @@ namespace dlib
- #get_num_threads() == 2 - #get_num_threads() == 2
- #get_max_cache_size() == 40 - #get_max_cache_size() == 40
- #get_feature_extractor() == a default initialized feature_extractor - #get_feature_extractor() == a default initialized feature_extractor
- #get_loss_per_missed_segment() == 1
- #get_loss_per_false_alarm() == 1
!*/ !*/
explicit structural_sequence_segmentation_trainer ( explicit structural_sequence_segmentation_trainer (
...@@ -60,6 +62,8 @@ namespace dlib ...@@ -60,6 +62,8 @@ namespace dlib
- #get_num_threads() == 2 - #get_num_threads() == 2
- #get_max_cache_size() == 40 - #get_max_cache_size() == 40
- #get_feature_extractor() == fe - #get_feature_extractor() == fe
- #get_loss_per_missed_segment() == 1
- #get_loss_per_false_alarm() == 1
!*/ !*/
const feature_extractor& get_feature_extractor ( const feature_extractor& get_feature_extractor (
...@@ -178,6 +182,44 @@ namespace dlib ...@@ -178,6 +182,44 @@ namespace dlib
generalization. generalization.
!*/ !*/
void set_loss_per_missed_segment (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_per_missed_segment() == loss
!*/
double get_loss_per_missed_segment (
) const;
/*!
ensures
- returns the amount of loss incurred for failing to detect a segment. The
larger the loss the more important it is to detect all the segments.
!*/
void set_loss_per_false_alarm (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_per_false_alarm() == loss
!*/
double get_loss_per_false_alarm (
) const;
/*!
ensures
- returns the amount of loss incurred for outputting a false detection. The
larger the loss the more important it is to avoid outputting false
detections.
!*/
const sequence_segmenter<feature_extractor> train( const sequence_segmenter<feature_extractor> train(
const std::vector<sample_sequence_type>& x, const std::vector<sample_sequence_type>& x,
const std::vector<segmented_sequence_type>& y const std::vector<segmented_sequence_type>& y
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment