OpenDAS / dlib / Commits

Commit 03ec260c authored Jan 17, 2013 by Davis King

reformatted comments.

parent 91e8594b

Showing 1 changed file with 66 additions and 63 deletions.

examples/svm_ex.cpp (view file @ 03ec260c)  +66 −63
@@ -27,19 +27,19 @@ using namespace dlib;

int main()
{
    // The svm functions use column vectors to contain a lot of the data on which they
    // operate. So the first thing we do here is declare a convenient typedef.

    // This typedef declares a matrix with 2 rows and 1 column. It will be the
    // object that contains each of our 2 dimensional samples. (Note that if you wanted
    // more than 2 features in this vector you can simply change the 2 to something else.
    // Or if you don't know how many features you want until runtime then you can put a 0
    // here and use the matrix.set_size() member function)
    typedef matrix<double, 2, 1> sample_type;

    // This is a typedef for the type of kernel we are going to use in this example.
    // In this case I have selected the radial basis kernel that can operate on our
    // 2D sample_type objects
    typedef radial_basis_kernel<sample_type> kernel_type;
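As a side note for readers following along (this sketch is not part of the commit): a sample_type is just a dlib column vector, indexed with operator(), and the 0-row form mentioned in the comment is sized at runtime with set_size(). A minimal self-contained example:

    #include <dlib/matrix.h>
    using namespace dlib;

    int main()
    {
        typedef matrix<double, 2, 1> sample_type;
        sample_type samp;
        samp(0) = 3.0;   // first feature
        samp(1) = -4.5;  // second feature

        // the runtime-sized variant the comment above describes: 0 rows means "set later"
        matrix<double, 0, 1> dyn;
        dyn.set_size(2);
    }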
@@ -47,9 +47,9 @@ int main()

    std::vector<sample_type> samples;
    std::vector<double> labels;

    // Now lets put some data into our samples and labels objects. We do this
    // by looping over a bunch of points and labeling them according to their
    // distance from the origin.
    for (int r = -20; r <= 20; ++r)
    {
        for (int c = -20; c <= 20; ++c)
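The body of this nested loop is outside the hunk. A hypothetical sketch of what the comment describes follows; the radius-10 threshold is an assumption for illustration, not taken from the diff (and std::sqrt needs <cmath>):

        {
            sample_type samp;
            samp(0) = r;
            samp(1) = c;
            samples.push_back(samp);

            // assumed labeling rule: points within 10 units of the origin get +1, others -1
            labels.push_back(std::sqrt(double(r*r + c*c)) <= 10 ? +1.0 : -1.0);
        }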
@@ -69,11 +69,11 @@ int main()

    }

    // Here we normalize all the samples by subtracting their mean and dividing by their
    // standard deviation. This is generally a good idea since it often heads off
    // numerical stability problems and also prevents one large feature from smothering
    // others. Doing this doesn't matter much in this example so I'm just doing this here
    // so you can see an easy way to accomplish this with the library.
    vector_normalizer<sample_type> normalizer;
    // let the normalizer learn the mean and standard deviation of the samples
    normalizer.train(samples);
@@ -82,19 +82,20 @@ int main()

        samples[i] = normalizer(samples[i]);

    // Now that we have some data we want to train on it. However, there are two
    // parameters to the training. These are the nu and gamma parameters. Our choice for
    // these parameters will influence how good the resulting decision function is. To
    // test how good a particular choice of these parameters is we can use the
    // cross_validate_trainer() function to perform n-fold cross validation on our training
    // data. However, there is a problem with the way we have sampled our distribution
    // above. The problem is that there is a definite ordering to the samples. That is,
    // the first half of the samples look like they are from a different distribution than
    // the second half. This would screw up the cross validation process but we can fix it
    // by randomizing the order of the samples with the following function call.
    randomize_samples(samples, labels);

    // The nu parameter has a maximum value that is dependent on the ratio of the +1 to -1
    // labels in the training data. This function finds that value.
    const double max_nu = maximum_nu(labels);
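The example's actual nu loop falls in lines elided from this diff. As a sketch of how max_nu is typically used to bound the grid searched in the cross-validation loop shown below (an illustration, not the diff's code):

    for (double nu = 0.00001; nu < max_nu; nu *= 5)
    {
        trainer.set_nu(nu);
        // ... evaluate this (gamma, nu) pair with cross_validate_trainer(), as below ...
    }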
@@ -102,8 +103,8 @@ int main()

    svm_nu_trainer<kernel_type> trainer;

    // Now we loop over some different nu and gamma values to see how good they are. Note
    // that this is a very simple way to try out a few possible parameter choices. You
    // should look at the model_selection_ex.cpp program for examples of more sophisticated
    // strategies for determining good parameter choices.
    cout << "doing cross validation" << endl;
    for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
@@ -115,29 +116,31 @@ int main()

            trainer.set_nu(nu);
            cout << "gamma: " << gamma << " nu: " << nu;

            // Print out the cross validation accuracy for 3-fold cross validation using
            // the current gamma and nu. cross_validate_trainer() returns a row vector.
            // The first element of the vector is the fraction of +1 training examples
            // correctly classified and the second number is the fraction of -1 training
            // examples correctly classified.
            cout << " cross validation accuracy: "
                 << cross_validate_trainer(trainer, samples, labels, 3);
        }
    }

    // From looking at the output of the above loop it turns out that a good value for
    // nu and gamma for this problem is 0.15625 for both. So that is what we will use.

    // Now we train on the full set of data and obtain the resulting decision function. We
    // use the value of 0.15625 for nu and gamma. The decision function will return values
    // >= 0 for samples it predicts are in the +1 class and numbers < 0 for samples it
    // predicts to be in the -1 class.
    trainer.set_kernel(kernel_type(0.15625));
    trainer.set_nu(0.15625);
    typedef decision_function<kernel_type> dec_funct_type;
    typedef normalized_function<dec_funct_type> funct_type;

    // Here we are making an instance of the normalized_function object. This object
    // provides a convenient way to store the vector normalization information along with
    // the decision function we are going to learn.
    funct_type learned_function;
    learned_function.normalizer = normalizer;  // save normalization information
    learned_function.function = trainer.train(samples, labels);  // perform the actual SVM training and save the results
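Because normalized_function applies the stored normalizer before invoking the decision function, learned_function can be called directly on raw, unnormalized samples. A minimal usage sketch (the feature values are made up for illustration):

    sample_type samp;
    samp(0) = 3.123;
    samp(1) = 2.0;
    cout << learned_function(samp) << endl;  // >= 0 predicts class +1, < 0 predicts class -1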
@@ -166,8 +169,8 @@ int main()

    cout << "This sample should be < 0 and it is classified as a " << learned_function(sample) << endl;

    // We can also train a decision function that reports a well conditioned probability
    // instead of just a number > 0 for the +1 class and < 0 for the -1 class. An example
    // of doing that follows:
    typedef probabilistic_decision_function<kernel_type> probabilistic_funct_type;
    typedef normalized_function<probabilistic_funct_type> pfunct_type;
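The lines that actually build a pfunct_type sit outside this hunk. A sketch of the usual construction, assuming dlib's train_probabilistic_decision_function() (which fits the probability model via an internal cross validation) is what the elided code calls:

    pfunct_type learned_pfunct;
    learned_pfunct.normalizer = normalizer;
    learned_pfunct.function = train_probabilistic_decision_function(trainer, samples, labels, 3);
    // learned_pfunct(samp) now returns the probability, between 0 and 1, that samp is in the +1 class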
@@ -200,8 +203,9 @@ int main()

    // Another thing that is worth knowing is that just about everything in dlib is
    // serializable. So for example, you can save the learned_pfunct object to disk and
    // recall it later like so:
    ofstream fout("saved_function.dat", ios::binary);
    serialize(learned_pfunct, fout);
    fout.close();
@@ -210,27 +214,27 @@ int main()

    ifstream fin("saved_function.dat", ios::binary);
    deserialize(learned_pfunct, fin);

    // Note that there is also an example program that comes with dlib called the
    // file_to_code_ex.cpp example. It is a simple program that takes a file and outputs a
    // piece of C++ code that is able to fully reproduce the file's contents in the form of
    // a std::string object. So you can use that along with the std::istringstream to save
    // learned decision functions inside your actual C++ code files if you want.

    // Lastly, note that the decision functions we trained above involved well over 200
    // basis vectors. Support vector machines in general tend to find decision functions
    // that involve a lot of basis vectors. This is significant because the more
    // basis vectors in a decision function, the longer it takes to classify new examples.
    // So dlib provides the ability to find an approximation to the normal output of a
    // trainer using fewer basis vectors.

    // Here we determine the cross validation accuracy when we approximate the output
    // using only 10 basis vectors. To do this we use the reduced2() function. It
    // takes a trainer object and the number of basis vectors to use and returns
    // a new trainer object that applies the necessary post processing during the creation
    // of decision function objects.
    cout << "\ncross validation accuracy with only 10 support vectors: "
         << cross_validate_trainer(reduced2(trainer, 10), samples, labels, 3);
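Since deserialize() reads from any std::istream, the file_to_code_ex.cpp approach mentioned in the comment above amounts to pairing an embedded string with a std::istringstream. A hypothetical sketch, where saved_blob is a made-up name for the string such generated code would produce:

    #include <sstream>

    std::istringstream sin(saved_blob);  // saved_blob: hypothetical string from file_to_code_ex.cpp output
    deserialize(learned_pfunct, sin);    // recover the function without touching the filesystem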
@@ -238,9 +242,8 @@ int main()

    cout << "cross validation accuracy with all the original support vectors: "
         << cross_validate_trainer(trainer, samples, labels, 3);

    // When you run this program you should see that, for this problem, you can reduce the
    // number of basis vectors down to 10 without hurting the cross validation accuracy.

    // To get the reduced decision function out we would just do this:
    ...
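The line that follows that comment is outside the hunk. A plausible sketch of it, relying only on the comment's statement that reduced2() returns a trainer object like any other (an assumption from the comment, not from the diff):

    learned_function.function = reduced2(trainer, 10).train(samples, labels);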