"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "adfff4864327752c84a9ce662bd194c73ca4d075"
svm_rank_trainer.cpp 5.5 KB
Newer Older
1
2
// Copyright (C) 2013  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
3

4
#include <dlib/python.h>
5
6
7
8
#include <boost/shared_ptr.hpp>
#include <dlib/matrix.h>
#include <dlib/svm.h>
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
9
10
#include "testing_results.h"
#include <boost/python/args.hpp>
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103

using namespace dlib;
using namespace std;
using namespace boost::python;

typedef matrix<double,0,1> sample_type; 
typedef std::vector<std::pair<unsigned long,double> > sparse_vect;


// ----------------------------------------------------------------------------------------

namespace dlib
{
    // Equality stub for ranking_pair<T>.
    //
    // Ranking pairs are not meant to be compared: every call trips the pyassert
    // below with a fixed message.  The trailing `return false` only exists so the
    // function has a value on every path.
    // NOTE(review): presumably provided because vector_indexing_suite (applied to
    // vectors of ranking_pair further down in this file) needs operator== on the
    // element type -- confirm.
    template <typename T>
    bool operator== (const ranking_pair<T>& /*lhs*/, const ranking_pair<T>& /*rhs*/)
    {
        pyassert(false, "It is illegal to compare ranking pair objects for equality.");
        return false;
    }
}

// Free-function adapter around a container's own resize() member, so it can be
// handed to the binding code as a plain function.
//
//   v - container to resize (modified in place)
//   n - new number of elements
template <typename T>
void resize(T& v, unsigned long n)
{
    v.resize(n);
}

// ----------------------------------------------------------------------------------------

template <typename trainer_type>
typename trainer_type::trained_function_type train1 (
    const trainer_type& trainer,
    const ranking_pair<typename trainer_type::sample_type>& sample
)
{
    typedef ranking_pair<typename trainer_type::sample_type> st;
    pyassert(is_ranking_problem(std::vector<st>(1, sample)), "Invalid inputs");
    return trainer.train(sample);
}

template <typename trainer_type>
typename trainer_type::trained_function_type train2 (
    const trainer_type& trainer,
    const std::vector<ranking_pair<typename trainer_type::sample_type> >& samples
)
{
    pyassert(is_ranking_problem(samples), "Invalid inputs");
    return trainer.train(samples);
}

// Checked setter for the trainer's epsilon.
//
//   trainer - object whose set_epsilon() member is called
//   eps     - new epsilon; must be strictly greater than zero, otherwise the
//             pyassert fires with "epsilon must be > 0" and the trainer is
//             left untouched
template <typename trainer_type>
void set_epsilon (trainer_type& trainer, double eps)
{
    const bool eps_is_positive = eps > 0;
    pyassert(eps_is_positive, "epsilon must be > 0");

    trainer.set_epsilon(eps);
}

// Getter counterpart of set_epsilon(): forwards to trainer.get_epsilon() and
// returns the result as a double.
template <typename trainer_type>
double get_epsilon (const trainer_type& trainer)
{
    const double current_eps = trainer.get_epsilon();
    return current_eps;
}

// Checked setter for the trainer's C parameter.
//
//   trainer - object whose set_c() member is called
//   C       - new value; must be strictly greater than zero, otherwise the
//             pyassert fires with "C must be > 0" and the trainer is left
//             untouched
template <typename trainer_type>
void set_c (trainer_type& trainer, double C)
{
    const bool c_is_positive = C > 0;
    pyassert(c_is_positive, "C must be > 0");

    trainer.set_c(C);
}

// Getter counterpart of set_c(): forwards to trainer.get_c() and returns the
// result as a double.
template <typename trainer_type>
double get_c (const trainer_type& trainer)
{
    const double current_c = trainer.get_c();
    return current_c;
}


template <typename trainer>
void add_ranker (
    const char* name
)
{
    class_<trainer>(name)
        .add_property("epsilon", get_epsilon<trainer>, set_epsilon<trainer>)
        .add_property("c", get_c<trainer>, set_c<trainer>)
        .add_property("max_iterations", &trainer::get_max_iterations, &trainer::set_max_iterations)
        .add_property("force_last_weight_to_1", &trainer::forces_last_weight_to_1, &trainer::force_last_weight_to_1)
        .add_property("learns_nonnegative_weights", &trainer::learns_nonnegative_weights, &trainer::set_learns_nonnegative_weights)
        .def("train", train1<trainer>)
        .def("train", train2<trainer>)
        .def("be_verbose", &trainer::be_verbose)
        .def("be_quiet", &trainer::be_quiet);
}

// ----------------------------------------------------------------------------------------

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// Checked wrapper around cross_validate_ranking_trainer().
//
//   trainer - trainer to evaluate
//   samples - ranking pairs used for the cross validation
//   folds   - number of folds; must satisfy 1 < folds <= samples.size()
//
// Returns whatever cross_validate_ranking_trainer(trainer, samples, folds)
// returns.  Two checks run first, in this order, each reported through
// pyassert:
//   1. samples must pass is_ranking_problem();
//   2. folds must lie in the range described above.
template <
    typename trainer_type,
    typename T
    >
const ranking_test _cross_ranking_validate_trainer (
    const trainer_type& trainer,
    const std::vector<ranking_pair<T> >& samples,
    const unsigned long folds
)
{
    const bool data_ok = is_ranking_problem(samples);
    pyassert(data_ok, "Training data does not make a valid training set.");

    const bool folds_ok = (1 < folds) && (folds <= samples.size());
    pyassert(folds_ok, "Invalid number of folds given.");

    return cross_validate_ranking_trainer(trainer, samples, folds);
}

// ----------------------------------------------------------------------------------------

121
122
void bind_svm_rank_trainer()
{
123
    using boost::python::arg;
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
    class_<ranking_pair<sample_type> >("ranking_pair")
        .add_property("relevant", &ranking_pair<sample_type>::relevant)
        .add_property("nonrelevant", &ranking_pair<sample_type>::nonrelevant)
        .def_pickle(serialize_pickle<ranking_pair<sample_type> >());

    class_<ranking_pair<sparse_vect> >("sparse_ranking_pair")
        .add_property("relevant", &ranking_pair<sparse_vect>::relevant)
        .add_property("nonrelevant", &ranking_pair<sparse_vect>::nonrelevant)
        .def_pickle(serialize_pickle<ranking_pair<sparse_vect> >());

    typedef std::vector<ranking_pair<sample_type> > ranking_pairs;
    class_<ranking_pairs>("ranking_pairs")
        .def(vector_indexing_suite<ranking_pairs>())
        .def("clear", &ranking_pairs::clear)
        .def("resize", resize<ranking_pairs>)
        .def_pickle(serialize_pickle<ranking_pairs>());

    typedef std::vector<ranking_pair<sparse_vect> > sparse_ranking_pairs;
    class_<sparse_ranking_pairs>("sparse_ranking_pairs")
        .def(vector_indexing_suite<sparse_ranking_pairs>())
        .def("clear", &sparse_ranking_pairs::clear)
        .def("resize", resize<sparse_ranking_pairs>)
        .def_pickle(serialize_pickle<sparse_ranking_pairs>());

    add_ranker<svm_rank_trainer<linear_kernel<sample_type> > >("svm_rank_trainer");
    add_ranker<svm_rank_trainer<sparse_linear_kernel<sparse_vect> > >("svm_rank_trainer_sparse");
150
151
152
153
154
155
156

    def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer<
                svm_rank_trainer<linear_kernel<sample_type> >,sample_type>,
                (arg("trainer"), arg("samples"), arg("folds")) );
    def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer<
                svm_rank_trainer<sparse_linear_kernel<sparse_vect> > ,sparse_vect>,
                (arg("trainer"), arg("samples"), arg("folds")) );
157
158
159
160
}