Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
3eb0d973
Commit
3eb0d973
authored
Sep 15, 2011
by
Davis King
Browse files
Added the cross_validate_object_detection_trainer() and test_object_detection_function()
routines.
parent
0aa89e07
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
332 additions
and
0 deletions
+332
-0
dlib/svm.h
dlib/svm.h
+1
-0
dlib/svm/cross_validate_object_detection_trainer.h
dlib/svm/cross_validate_object_detection_trainer.h
+242
-0
dlib/svm/cross_validate_object_detection_trainer_abstract.h
dlib/svm/cross_validate_object_detection_trainer_abstract.h
+89
-0
No files found.
dlib/svm.h
View file @
3eb0d973
...
...
@@ -33,6 +33,7 @@
#include "svm/multiclass_tools.h"
#include "svm/cross_validate_multiclass_trainer.h"
#include "svm/cross_validate_regression_trainer.h"
#include "svm/cross_validate_object_detection_trainer.h"
#include "svm/one_vs_all_decision_function.h"
#include "svm/one_vs_all_trainer.h"
...
...
dlib/svm/cross_validate_object_detection_trainer.h
0 → 100644
View file @
3eb0d973
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_H__
#define DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_H__
#include "cross_validate_object_detection_trainer_abstract.h"
#include <vector>
#include "../matrix.h"
#include "svm.h"
#include "../geometry.h"
namespace
dlib
{
// ----------------------------------------------------------------------------------------
namespace
impl
{
unsigned
long
number_of_truth_hits
(
const
std
::
vector
<
rectangle
>&
truth_boxes
,
const
std
::
vector
<
rectangle
>&
boxes
,
const
double
overlap_eps
)
/*!
requires
- 0 < overlap_eps <= 1
ensures
- returns the number of elements in truth_boxes which are overlapped by an
element of boxes. In this context, two boxes, A and B, overlap if and only if
the following quantity is greater than overlap_eps:
A.intersect(B).area()/(A+B).area()
- No element of boxes is allowed to account for more than one element of truth_boxes.
- The returned number is in the range [0,truth_boxes.size()]
!*/
{
if
(
boxes
.
size
()
==
0
)
return
0
;
unsigned
long
count
=
0
;
std
::
vector
<
bool
>
used
(
boxes
.
size
(),
false
);
for
(
unsigned
long
i
=
0
;
i
<
truth_boxes
.
size
();
++
i
)
{
unsigned
long
best_idx
=
0
;
double
best_overlap
=
0
;
for
(
unsigned
long
j
=
0
;
j
<
boxes
.
size
();
++
j
)
{
if
(
used
[
j
])
continue
;
const
double
overlap
=
truth_boxes
[
i
].
intersect
(
boxes
[
j
]).
area
()
/
(
double
)(
truth_boxes
[
i
]
+
boxes
[
j
]).
area
();
if
(
overlap
>
best_overlap
)
{
best_overlap
=
overlap
;
best_idx
=
j
;
}
}
if
(
best_overlap
>
overlap_eps
&&
used
[
best_idx
]
==
false
)
{
used
[
best_idx
]
=
true
;
++
count
;
}
}
return
count
;
}
}
// ----------------------------------------------------------------------------------------
template
<
typename
object_detector_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
test_object_detection_function
(
const
object_detector_type
&
detector
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_rects
,
const
double
overlap_eps
=
0.5
)
{
// make sure requires clause is not broken
DLIB_ASSERT
(
is_learning_problem
(
images
,
truth_rects
)
==
true
&&
0
<
overlap_eps
&&
overlap_eps
<=
1
,
"
\t
matrix test_object_detection_function()"
<<
"
\n\t
invalid inputs were given to this function"
<<
"
\n\t
is_learning_problem(images,truth_rects): "
<<
is_learning_problem
(
images
,
truth_rects
)
<<
"
\n\t
overlap_eps: "
<<
overlap_eps
);
double
correct_hits
=
0
;
double
total_hits
=
0
;
double
total_true_targets
=
0
;
for
(
unsigned
long
i
=
0
;
i
<
images
.
size
();
++
i
)
{
const
std
::
vector
<
rectangle
>&
hits
=
detector
(
images
[
i
]);
total_hits
+=
hits
.
size
();
correct_hits
+=
impl
::
number_of_truth_hits
(
truth_rects
[
i
],
hits
,
overlap_eps
);
total_true_targets
+=
truth_rects
[
i
].
size
();
}
double
precision
,
recall
;
if
(
total_hits
==
0
)
precision
=
1
;
else
precision
=
correct_hits
/
total_hits
;
if
(
total_true_targets
==
0
)
recall
=
1
;
else
recall
=
correct_hits
/
total_true_targets
;
matrix
<
double
,
1
,
2
>
res
;
res
=
precision
,
recall
;
return
res
;
}
// ----------------------------------------------------------------------------------------
namespace
impl
{
template
<
typename
array_type
>
struct
array_subset_helper
{
array_subset_helper
(
const
array_type
&
array_
,
const
std
::
vector
<
unsigned
long
>&
idx_set_
)
:
array
(
array_
),
idx_set
(
idx_set_
)
{
}
unsigned
long
size
()
const
{
return
idx_set
.
size
();
}
typedef
typename
array_type
::
type
type
;
const
type
&
operator
[]
(
unsigned
long
idx
)
const
{
return
array
[
idx_set
[
idx
]];
}
private:
const
array_type
&
array
;
const
std
::
vector
<
unsigned
long
>&
idx_set
;
};
}
// ----------------------------------------------------------------------------------------
template
<
typename
trainer_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
cross_validate_object_detection_trainer
(
const
trainer_type
&
trainer
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_rects
,
const
long
folds
,
const
double
overlap_eps
=
0.5
)
{
// make sure requires clause is not broken
DLIB_ASSERT
(
is_learning_problem
(
images
,
truth_rects
)
==
true
&&
0
<
overlap_eps
&&
overlap_eps
<=
1
&&
1
<
folds
&&
folds
<=
images
.
size
(),
"
\t
matrix cross_validate_object_detection_trainer()"
<<
"
\n\t
invalid inputs were given to this function"
<<
"
\n\t
is_learning_problem(images,truth_rects): "
<<
is_learning_problem
(
images
,
truth_rects
)
<<
"
\n\t
overlap_eps: "
<<
overlap_eps
<<
"
\n\t
folds: "
<<
folds
);
double
correct_hits
=
0
;
double
total_hits
=
0
;
double
total_true_targets
=
0
;
const
long
test_size
=
images
.
size
()
/
folds
;
unsigned
long
test_idx
=
0
;
for
(
long
iter
=
0
;
iter
<
folds
;
++
iter
)
{
std
::
vector
<
unsigned
long
>
train_idx_set
;
std
::
vector
<
unsigned
long
>
test_idx_set
;
for
(
unsigned
long
i
=
0
;
i
<
test_size
;
++
i
)
test_idx_set
.
push_back
(
test_idx
++
);
unsigned
long
train_idx
=
test_idx
%
images
.
size
();
std
::
vector
<
std
::
vector
<
rectangle
>
>
training_rects
;
for
(
unsigned
long
i
=
0
;
i
<
images
.
size
()
-
test_size
;
++
i
)
{
training_rects
.
push_back
(
truth_rects
[
train_idx
]);
train_idx_set
.
push_back
(
train_idx
);
train_idx
=
(
train_idx
+
1
)
%
images
.
size
();
}
impl
::
array_subset_helper
<
image_array_type
>
array_subset
(
images
,
train_idx_set
);
const
typename
trainer_type
::
trained_function_type
&
detector
=
trainer
.
train
(
array_subset
,
training_rects
);
for
(
unsigned
long
i
=
0
;
i
<
test_idx_set
.
size
();
++
i
)
{
const
std
::
vector
<
rectangle
>&
hits
=
detector
(
images
[
test_idx_set
[
i
]]);
total_hits
+=
hits
.
size
();
correct_hits
+=
impl
::
number_of_truth_hits
(
truth_rects
[
test_idx_set
[
i
]],
hits
,
overlap_eps
);
total_true_targets
+=
truth_rects
[
test_idx_set
[
i
]].
size
();
}
}
double
precision
,
recall
;
if
(
total_hits
==
0
)
precision
=
1
;
else
precision
=
correct_hits
/
total_hits
;
if
(
total_true_targets
==
0
)
recall
=
1
;
else
recall
=
correct_hits
/
total_true_targets
;
matrix
<
double
,
1
,
2
>
res
;
res
=
precision
,
recall
;
return
res
;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_H__
dlib/svm/cross_validate_object_detection_trainer_abstract.h
0 → 100644
View file @
3eb0d973
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_H__
#ifdef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_H__
#include <vector>
#include "../matrix.h"
#include "../geometry.h"
namespace
dlib
{
// ----------------------------------------------------------------------------------------
template
<
typename
object_detector_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
test_object_detection_function
(
const
object_detector_type
&
detector
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_rects
,
const
double
overlap_eps
=
0.5
);
/*!
requires
- is_learning_problem(images,truth_rects)
- 0 < overlap_eps <= 1
- object_detector_type == some kind of object detector function object
(e.g. object_detector)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector().
ensures
- Tests the given detector against the supplied object detection problem
and returns the precision and recall. Note that the task is to predict,
for each images[i], the set of object locations given by truth_rects[i].
- In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs
which correspond to a real target. A value of 1 means the detector
never produces any false alarms while a value of 0 means it only
produces false alarms.
- M(1) == the recall of the detector object. This is a number in the
range [0,1] which measure the fraction of targets found by the
detector. A value of 1 means the detector found all the targets
in truth_rects while a value of 0 means the detector didn't locate
any of the targets.
- The rule for deciding if a detector output, D, matches a truth rectangle,
T, is the following:
T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps
!*/
// ----------------------------------------------------------------------------------------
template
<
typename
trainer_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
cross_validate_object_detection_trainer
(
const
trainer_type
&
trainer
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_rects
,
const
long
folds
,
const
double
overlap_eps
=
0.5
);
/*!
requires
- is_learning_problem(images,truth_rects)
- 0 < overlap_eps <= 1
- 1 < folds <= images.size()
- trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector().
ensures
- Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested
using the output of the trainer and a matrix summarizing the results is
returned. The matrix contains the precision and recall of the trained
detectors and is defined identically to the test_object_detection_function()
routine defined at the top of this file.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_H__
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment