Commit 3c7c8fee authored by Davis King's avatar Davis King
Browse files

Refactored a bunch of the svm training code into a much cleaner form.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402380
parent b51cc5c2
This diff is collapsed.
...@@ -17,8 +17,24 @@ namespace dlib ...@@ -17,8 +17,24 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// Functions that perform SVM training
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class invalid_svm_nu_error : public dlib::error
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is an exception class used to indicate that a
value of nu used for svm training is incompatible with a
particular data set.
this->nu will be set to the invalid value of nu used.
!*/
public:
invalid_svm_nu_error(const std::string& msg, double nu_) : dlib::error(msg), nu(nu_) {};
const double nu;
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -29,102 +45,242 @@ namespace dlib ...@@ -29,102 +45,242 @@ namespace dlib
); );
/*! /*!
requires requires
- T == a matrix object - T == a matrix object or an object convertible to a matrix via
vector_to_matrix()
- y.nc() == 1 - y.nc() == 1
- y.nr() > 1 - y.nr() > 1
- for all valid i: - for all valid i:
- y(i) == -1 or +1 - y(i) == -1 or +1
ensures ensures
- returns the maximum valid nu that can be used with svm_nu_train(). - returns the maximum valid nu that can be used with the svm_nu_trainer and
the training set labels from the given y vector.
(i.e. 2.0*min(number of +1 examples in y, number of -1 examples in y)/y.nr()) (i.e. 2.0*min(number of +1 examples in y, number of -1 examples in y)/y.nr())
!*/ !*/
template < // ----------------------------------------------------------------------------------------
typename K
bool template <
typename T,
typename U
> >
const decision_function<K> svm_nu_train ( bool is_binary_classification_problem (
const typename decision_function<K>::sample_vector_type& x, const T& x,
const typename decision_function<K>::scalar_vector_type& y, const U& x_labels
const K& kernel_function,
const typename K::scalar_type nu,
const long cache_size = 200,
const typename K::scalar_type eps = 0.001
); );
/*! /*!
requires requires
- eps > 0 - T == a matrix or something convertible to a matrix via vector_to_matrix()
- x.nc() == 1 (i.e. x is a column vector) - U == a matrix or something convertible to a matrix via vector_to_matrix()
- y.nc() == 1 (i.e. y is a column vector)
- x.nr() == y.nr()
- x.nr() > 1
- cache_size > 0
- for all valid i:
- y(i) == -1 or +1
- y(i) is the class that should be assigned to training example x(i)
- 0 < nu < maximum_nu(y)
- kernel_function == a kernel function object type as defined at the
top of dlib/svm/kernel_abstract.h
ensures ensures
- trains a nu support vector classifier given the training samples in x and - returns true if all of the following are true and false otherwise:
labels in y. Training is done when the error is less than eps. - x.nc() == 1 (i.e. x is a column vector)
- caches approximately at most cache_size megabytes of the kernel matrix. - x_labels.nc() == 1 (i.e. x_labels is a column vector)
(bigger values of this may make training go faster but doesn't affect the - x.nr() == x_labels.nr()
result. However, too big a value will cause you to run out of memory.) - x.nr() > 1
- returns a decision function F with the following properties: - for all valid i:
- if (new_x is a sample predicted have +1 label) then - x_labels(i) == -1 or +1
- F(new_x) >= 0
- else
- F(new_x) < 0
!*/ !*/
/* // ----------------------------------------------------------------------------------------
The implementation of the nu-svm training algorithm used by this library is based // ----------------------------------------------------------------------------------------
on the following excellent papers: // ----------------------------------------------------------------------------------------
- Chang and Lin, Training {nu}-Support Vector Classifiers: Theory and Algorithms
- Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector template <
machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm typename K
*/ >
class svm_nu_trainer
{
/*!
REQUIREMENTS ON K
is a kernel function object as defined in dlib/svm/kernel_abstract.h
WHAT THIS OBJECT REPRESENTS
This object implements a trainer for a nu support vector machine for
solving binary classification problems.
The implementation of the nu-svm training algorithm used by this object is based
on the following excellent papers:
- Chang and Lin, Training {nu}-Support Vector Classifiers: Theory and Algorithms
- Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector
machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm
!*/
public:
typedef K kernel_type;
typedef typename kernel_type::scalar_type scalar_type;
typedef typename kernel_type::sample_type sample_type;
typedef typename kernel_type::mem_manager_type mem_manager_type;
typedef decision_function<kernel_type> trained_function_type;
svm_nu_trainer (
);
/*!
ensures
- This object is properly initialized and ready to be used
to train a support vector machine.
- #get_kernel() == kernel_type()
- #get_nu() == 0.1
- #get_cache_size() == 200
- #get_epsilon() == 0.001
!*/
svm_nu_trainer (
const kernel_type& kernel,
const scalar_type& nu
);
/*!
requires
- 0 < nu <= 1
ensures
- This object is properly initialized and ready to be used
to train a support vector machine.
- #get_kernel() == kernel
- #get_nu() == nu
- #get_cache_size() == 200
- #get_epsilon() == 0.001
!*/
void set_cache_size (
long cache_size
);
/*!
requires
- cache_size > 0
ensures
- #get_cache_size() == cache_size
!*/
const long get_cache_size (
) const;
/*!
ensures
- returns the number of megabytes of cache this object will use
when it performs training via the this->train() function.
(bigger values of this may make training go faster but doesn't affect
the result. However, too big a value will cause you to run out of
memory obviously.)
!*/
void set_epsilon (
scalar_type eps
);
/*!
requires
- eps > 0
ensures
- #get_epsilon() == eps
!*/
const scalar_type get_epsilon (
) const;
/*!
ensures
- returns the error epsilon that determines when training should stop.
Generally a good value for this is 0.001. Smaller values may result
in a more accurate solution but take longer to execute.
!*/
void set_kernel (
const kernel_type& k
);
/*!
ensures
- #get_kernel() == k
!*/
const kernel_type& get_kernel (
) const;
/*!
ensures
- returns a copy of the kernel function in use by this object
!*/
void set_nu (
scalar_type nu
);
/*!
requires
- 0 < nu <= 1
ensures
- #get_nu() == nu
!*/
const scalar_type get_nu (
) const;
/*!
ensures
- returns the nu svm parameter. This is a value between 0 and
1. It is the parameter that determines the trade off between
trying to fit the training data exactly or allowing more errors
but hopefully improving the generalization ability of the
resulting classifier. For more information you should consult
the papers referenced above.
!*/
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> train (
const in_sample_vector_type& x,
const in_scalar_vector_type& y
) const;
/*!
requires
- is_binary_classification_problem(x,y) == true
ensures
- trains a nu support vector classifier given the training samples in x and
labels in y. Training is done when the error is less than get_epsilon().
- returns a decision function F with the following properties:
- if (new_x is a sample predicted have +1 label) then
- F(new_x) >= 0
- else
- F(new_x) < 0
throws
- invalid_svm_nu_error
This exception is thrown if get_nu() > maximum_nu(y)
- std::bad_alloc
!*/
void swap (
svm_nu_trainer& item
);
/*!
ensures
- swaps *this and item
!*/
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
typename K typename trainer_type,
typename in_sample_vector_type,
typename in_scalar_vector_type
> >
const probabilistic_decision_function<K> svm_nu_train_prob ( const probabilistic_decision_function<typename trainer_type::kernel_type>
const typename decision_function<K>::sample_vector_type& x, train_probabilistic_decision_function (
const typename decision_function<K>::scalar_vector_type& y, const trainer_type& trainer,
const K& kernel_function, const in_sample_vector_type& x,
const typename K::scalar_type nu, const in_scalar_vector_type& y,
const long folds, const long folds
const long cache_size = 200, )
const typename K::scalar_type eps = 0.001
);
/*! /*!
requires requires
- eps > 0
- 1 < folds <= x.nr() - 1 < folds <= x.nr()
- x.nc() == 1 (i.e. x is a column vector) - is_binary_classification_problem(x,y) == true
- y.nc() == 1 (i.e. y is a column vector) - trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
- x.nr() == y.nr()
- x.nr() > 1
- cache_size > 0
- for all valid i:
- y(i) == -1 or +1
- y(i) is the class that should be assigned to training example x(i)
- 0 < nu < maximum_nu(y)
- kernel_function == a kernel function object type as defined at the
top of dlib/svm/kernel_abstract.h
ensures ensures
- trains a nu support vector classifier given the training samples in x and - trains a nu support vector classifier given the training samples in x and
labels in y. Training is done when the error is less than eps. labels in y.
- caches approximately at most cache_size megabytes of the kernel matrix. - returns a probabilistic_decision_function that represents the trained svm.
(bigger values of this may make training go faster but doesn't affect the
result. However, too big a value will cause you to run out of memory.)
- returns a probabilistic_decision_function that represents the trained
svm.
- The parameters of the probability model are estimated by performing k-fold - The parameters of the probability model are estimated by performing k-fold
cross validation. cross validation.
- The number of folds used is given by the folds argument. - The number of folds used is given by the folds argument.
throws
- any exceptions thrown by trainer.train()
- std::bad_alloc
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -134,46 +290,36 @@ namespace dlib ...@@ -134,46 +290,36 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
typename K typename trainer_type,
typename in_sample_vector_type,
typename in_scalar_vector_type
> >
const matrix<typename K::scalar_type, 1, 2, typename K::mem_manager_type> svm_nu_cross_validate ( const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type>
const typename decision_function<K>::sample_vector_type& x, cross_validate_trainer (
const typename decision_function<K>::scalar_vector_type& y, const trainer_type& trainer,
const K& kernel_function, const in_sample_vector_type& x,
const typename K::scalar_type nu, const in_scalar_vector_type& y,
const long folds, const long folds
const long cache_size = 200,
const typename K::scalar_type eps = 0.001
); );
/*! /*!
requires requires
- eps > 0 - is_binary_classification_problem(x,y) == true
- 1 < folds <= x.nr() - 1 < folds <= x.nr()
- x.nc() == 1 (i.e. x is a column vector) - trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
- y.nc() == 1 (i.e. y is a column vector)
- x.nr() == y.nr()
- x.nr() > 1
- cache_size > 0
- for all valid i:
- y(i) == -1 or +1
- y(i) is the class that should be assigned to training example x(i)
- 0 < nu < maximum_nu(y)
- kernel_function == a kernel function object type as defined at the
top of dlib/svm/kernel_abstract.h
ensures ensures
- performs k-fold cross validation by training a nu-svm using the svm_nu_train() - performs k-fold cross validation by using the given trainer to solve the
function. Each fold is tested using the learned decision_function and the given binary classification problem for the given number of folds.
average accuracy from all folds is returned. The accuracy is returned in Each fold is tested using the output of the trainer and the average
a column vector, let us call it R. Both quantities in R are numbers between classification accuracy from all folds is returned.
0 and 1 which represent the fraction of examples correctly classified. R(0) - The accuracy is returned in a column vector, let us call it R. Both
is the fraction of +1 examples correctly classified and R(1) is the fraction quantities in R are numbers between 0 and 1 which represent the fraction
of -1 examples correctly classified. of examples correctly classified. R(0) is the fraction of +1 examples
correctly classified and R(1) is the fraction of -1 examples correctly
classified.
- The number of folds used is given by the folds argument. - The number of folds used is given by the folds argument.
- in each fold: trains a nu support vector classifier given the training samples throws
in x and labels in y. Training is done when the error is less than eps. - any exceptions thrown by trainer.train()
- caches approximately at most cache_size megabytes of the kernel matrix. - std::bad_alloc
(bigger values of this may make training go faster but doesn't affect the
result. However, too big a value will cause you to run out of memory.)
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment