"src/array/vscode:/vscode.git/clone" did not exist on "3e72c53ae0ecd16ae057aa63e81f0c57f5828fa8"
Commit 2f34594f authored by Davis King's avatar Davis King
Browse files

Added cross validation functions for ranking tools and slightly improved documentation

for other cross validation functions.
parent 97f82b1e
......@@ -6,6 +6,7 @@
#include "serialize_pickle.h"
#include <dlib/svm_threaded.h>
#include "pyassert.h"
#include <boost/python/args.hpp>
using namespace dlib;
using namespace std;
......@@ -166,34 +167,43 @@ const binary_test _cross_validate_trainer_t (
void bind_svm_c_trainer()
{
using boost::python::arg;
{
typedef svm_c_trainer<radial_basis_kernel<sample_type> > T;
setup_trainer2<T>("svm_c_trainer_radial_basis")
.add_property("gamma", get_gamma, set_gamma);
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
{
typedef svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > T;
setup_trainer2<T>("svm_c_trainer_sparse_radial_basis")
.add_property("gamma", get_gamma_sparse, set_gamma_sparse);
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
{
typedef svm_c_trainer<histogram_intersection_kernel<sample_type> > T;
setup_trainer2<T>("svm_c_trainer_histogram_intersection");
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
{
typedef svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T;
setup_trainer2<T>("svm_c_trainer_sparse_histogram_intersection");
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
{
......@@ -205,8 +215,10 @@ void bind_svm_c_trainer()
.def("be_verbose", &T::be_verbose)
.def("be_quiet", &T::be_quiet);
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
{
......@@ -218,8 +230,10 @@ void bind_svm_c_trainer()
.def("be_verbose", &T::be_verbose)
.def("be_quiet", &T::be_quiet);
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
}
......
......@@ -6,6 +6,8 @@
#include <dlib/svm.h>
#include "pyassert.h"
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
#include "testing_results.h"
#include <boost/python/args.hpp>
using namespace dlib;
using namespace std;
......@@ -99,8 +101,26 @@ void add_ranker (
// ----------------------------------------------------------------------------------------
template <
typename trainer_type,
typename T
>
const ranking_test _cross_ranking_validate_trainer (
const trainer_type& trainer,
const std::vector<ranking_pair<T> >& samples,
const unsigned long folds
)
{
pyassert(is_ranking_problem(samples), "Training data does not make a valid training set.");
pyassert(1 < folds && folds <= samples.size(), "Invalid number of folds given.");
return cross_validate_ranking_trainer(trainer, samples, folds);
}
// ----------------------------------------------------------------------------------------
void bind_svm_rank_trainer()
{
using boost::python::arg;
class_<ranking_pair<sample_type> >("ranking_pair")
.add_property("relevant", &ranking_pair<sample_type>::relevant)
.add_property("nonrelevant", &ranking_pair<sample_type>::nonrelevant)
......@@ -127,6 +147,13 @@ void bind_svm_rank_trainer()
add_ranker<svm_rank_trainer<linear_kernel<sample_type> > >("svm_rank_trainer");
add_ranker<svm_rank_trainer<sparse_linear_kernel<sparse_vect> > >("svm_rank_trainer_sparse");
def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer<
svm_rank_trainer<linear_kernel<sample_type> >,sample_type>,
(arg("trainer"), arg("samples"), arg("folds")) );
def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer<
svm_rank_trainer<sparse_linear_kernel<sparse_vect> > ,sparse_vect>,
(arg("trainer"), arg("samples"), arg("folds")) );
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment