Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
2a27b690
Commit
2a27b690
authored
Jun 06, 2018
by
Davis King
Browse files
Added auto_train_rbf_classifier()
parent
c14dca07
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
166 additions
and
0 deletions
+166
-0
dlib/CMakeLists.txt
dlib/CMakeLists.txt
+1
-0
dlib/all/source.cpp
dlib/all/source.cpp
+1
-0
dlib/svm.h
dlib/svm.h
+1
-0
dlib/svm/auto.cpp
dlib/svm/auto.cpp
+102
-0
dlib/svm/auto.h
dlib/svm/auto.h
+25
-0
dlib/svm/auto_abstract.h
dlib/svm/auto_abstract.h
+36
-0
No files found.
dlib/CMakeLists.txt
View file @
2a27b690
...
...
@@ -243,6 +243,7 @@ if (NOT TARGET dlib)
global_optimization/global_function_search.cpp
filtering/kalman_filter.cpp
test_for_odr_violations.cpp
svm/auto.cpp
)
...
...
dlib/all/source.cpp
View file @
2a27b690
...
...
@@ -90,6 +90,7 @@
#include "../data_io/mnist.cpp"
#include "../global_optimization/global_function_search.cpp"
#include "../filtering/kalman_filter.cpp"
#include "../svm/auto.cpp"
#define DLIB_ALL_SOURCE_END
...
...
dlib/svm.h
View file @
2a27b690
...
...
@@ -54,6 +54,7 @@
#include "svm/active_learning.h"
#include "svm/svr_linear_trainer.h"
#include "svm/sequence_segmenter.h"
#include "svm/auto.h"
#endif // DLIB_SVm_HEADER
...
...
dlib/svm/auto.cpp
0 → 100644
View file @
2a27b690
// Copyright (C) 2018 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_AUTO_LEARnING_CPP_
#define DLIB_AUTO_LEARnING_CPP_
#include "auto.h"
#include "../global_optimization.h"
#include "svm_c_trainer.h"
#include <iostream>
#include <thread>
namespace
dlib
{
normalized_function
<
decision_function
<
radial_basis_kernel
<
matrix
<
double
,
0
,
1
>>>>
auto_train_rbf_classifier
(
std
::
vector
<
matrix
<
double
,
0
,
1
>>
x
,
std
::
vector
<
double
>
y
,
const
std
::
chrono
::
nanoseconds
max_runtime
,
bool
be_verbose
)
{
const
auto
num_positive_training_samples
=
sum
(
mat
(
y
)
>
0
);
const
auto
num_negative_training_samples
=
sum
(
mat
(
y
)
<
0
);
DLIB_CASSERT
(
num_positive_training_samples
>=
6
&&
num_negative_training_samples
>=
6
,
"You must provide at least 6 examples of each class to this training routine."
);
// make sure requires clause is not broken
DLIB_CASSERT
(
is_binary_classification_problem
(
x
,
y
)
==
true
,
"
\t
decision_function svm_c_trainer::train(x,y)"
<<
"
\n\t
invalid inputs were given to this function"
<<
"
\n\t
x.size(): "
<<
x
.
size
()
<<
"
\n\t
y.size(): "
<<
y
.
size
()
<<
"
\n\t
is_binary_classification_problem(x,y): "
<<
is_binary_classification_problem
(
x
,
y
)
);
randomize_samples
(
x
,
y
);
vector_normalizer
<
matrix
<
double
,
0
,
1
>>
normalizer
;
// let the normalizer learn the mean and standard deviation of the samples
normalizer
.
train
(
x
);
for
(
auto
&
samp
:
x
)
samp
=
normalizer
(
samp
);
normalized_function
<
decision_function
<
radial_basis_kernel
<
matrix
<
double
,
0
,
1
>>>>
df
;
df
.
normalizer
=
normalizer
;
typedef
radial_basis_kernel
<
matrix
<
double
,
0
,
1
>>
kernel_type
;
std
::
mutex
m
;
auto
cross_validation_score
=
[
&
](
const
double
gamma
,
const
double
c1
,
const
double
c2
)
{
svm_c_trainer
<
kernel_type
>
trainer
;
trainer
.
set_kernel
(
kernel_type
(
gamma
));
trainer
.
set_c_class1
(
c1
);
trainer
.
set_c_class2
(
c2
);
// Finally, perform 6-fold cross validation and then print and return the results.
matrix
<
double
>
result
=
cross_validate_trainer
(
trainer
,
x
,
y
,
6
);
if
(
be_verbose
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
m
);
std
::
cout
<<
"gamma: "
<<
std
::
setw
(
11
)
<<
gamma
<<
" c1: "
<<
std
::
setw
(
11
)
<<
c1
<<
" c2: "
<<
std
::
setw
(
11
)
<<
c2
<<
" cross validation accuracy: "
<<
result
<<
std
::
flush
;
}
// return the f1 score plus a penalty for picking large parameter settings
// since those are, a priori less likely to generalize.
return
2
*
prod
(
result
)
/
sum
(
result
)
-
std
::
max
(
c1
,
c2
)
/
1e12
-
gamma
/
1e8
;
};
std
::
cout
<<
"Searching for best RBF-SVM training parameters..."
<<
std
::
endl
;
auto
result
=
find_max_global
(
default_thread_pool
(),
cross_validation_score
,
{
1e-5
,
1e-5
,
1e-5
},
// lower bound constraints on gamma, c1, and c2, respectively
{
100
,
1e6
,
1e6
},
// upper bound constraints on gamma, c1, and c2, respectively
max_runtime
);
double
best_gamma
=
result
.
x
(
0
);
double
best_c1
=
result
.
x
(
1
);
double
best_c2
=
result
.
x
(
2
);
std
::
cout
<<
" best cross-validation score: "
<<
result
.
y
<<
std
::
endl
;
std
::
cout
<<
" best gamma: "
<<
best_gamma
<<
" best c1: "
<<
best_c1
<<
" best c2: "
<<
best_c2
<<
std
::
endl
;
svm_c_trainer
<
kernel_type
>
trainer
;
trainer
.
set_kernel
(
kernel_type
(
best_gamma
));
trainer
.
set_c_class1
(
best_c1
);
trainer
.
set_c_class2
(
best_c2
);
std
::
cout
<<
"Training final classifier with best parameters..."
<<
std
::
endl
;
df
.
function
=
trainer
.
train
(
x
,
y
);
return
df
;
}
}
#endif // DLIB_AUTO_LEARnING_CPP_
dlib/svm/auto.h
0 → 100644
View file @
2a27b690
// Copyright (C) 2018 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_AUTO_LEARnING_Hh_
#define DLIB_AUTO_LEARnING_Hh_
#include "auto_abstract.h"
#include "../algs.h"
#include "function.h"
#include "kernel.h"
#include <chrono>
#include <vector>
namespace
dlib
{
normalized_function
<
decision_function
<
radial_basis_kernel
<
matrix
<
double
,
0
,
1
>>>>
auto_train_rbf_classifier
(
std
::
vector
<
matrix
<
double
,
0
,
1
>>
x
,
std
::
vector
<
double
>
y
,
const
std
::
chrono
::
nanoseconds
max_runtime
,
bool
be_verbose
=
true
);
}
#endif // DLIB_AUTO_LEARnING_Hh_
dlib/svm/auto_abstract.h
0 → 100644
View file @
2a27b690
// Copyright (C) 2018 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_AUTO_LEARnING_ABSTRACT_Hh_
#ifdef DLIB_AUTO_LEARnING_ABSTRACT_Hh_
#include "kernel_abstract.h"
#include "function_abstract.h"
#include <chrono>
#include <vector>
namespace
dlib
{
normalized_function
<
decision_function
<
radial_basis_kernel
<
matrix
<
double
,
0
,
1
>>>>
auto_train_rbf_classifier
(
std
::
vector
<
matrix
<
double
,
0
,
1
>>
x
,
std
::
vector
<
double
>
y
,
const
std
::
chrono
::
nanoseconds
max_runtime
,
bool
be_verbose
=
true
);
/*!
requires
- is_binary_classification_problem(x,y) == true
- y contains at least 6 examples of each class.
ensures
- This routine trains a radial basis function SVM on the given binary
classification training data. It uses the svm_c_trainer to do this. It also
uses find_max_global() and 6-fold cross-validation to automatically determine
the best settings of the SVM's hyper parameters.
- The hyperparameter search will run for about max_runtime and will print
messages to the screen as it runs if be_verbose==true.
!*/
}
#endif // DLIB_AUTO_LEARnING_ABSTRACT_Hh_
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment