Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
754610f2
Commit
754610f2
authored
May 04, 2014
by
Davis King
Browse files
Added the option to set a prior to svm_rank_trainer.
parent
461abe65
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
129 additions
and
1 deletion
+129
-1
dlib/svm/svm_rank_trainer.h
dlib/svm/svm_rank_trainer.h
+51
-1
dlib/svm/svm_rank_trainer_abstract.h
dlib/svm/svm_rank_trainer_abstract.h
+43
-0
dlib/test/ranking.cpp
dlib/test/ranking.cpp
+35
-0
No files found.
dlib/svm/svm_rank_trainer.h
View file @
754610f2
...
@@ -297,6 +297,8 @@ namespace dlib
...
@@ -297,6 +297,8 @@ namespace dlib
)
)
{
{
last_weight_1
=
should_last_weight_be_1
;
last_weight_1
=
should_last_weight_be_1
;
if
(
last_weight_1
)
prior
.
set_size
(
0
);
}
}
void
set_oca
(
void
set_oca
(
...
@@ -326,6 +328,33 @@ namespace dlib
...
@@ -326,6 +328,33 @@ namespace dlib
)
)
{
{
learn_nonnegative_weights
=
value
;
learn_nonnegative_weights
=
value
;
if
(
learn_nonnegative_weights
)
prior
.
set_size
(
0
);
}
void
set_prior
(
const
trained_function_type
&
prior_
)
{
// make sure requires clause is not broken
DLIB_ASSERT
(
prior_
.
basis_vectors
.
size
()
==
1
&&
prior_
.
alpha
(
0
)
==
1
,
"
\t
void svm_rank_trainer::set_prior()"
<<
"
\n\t
The supplied prior could not have been created by this object's train() method."
<<
"
\n\t
prior_.basis_vectors.size(): "
<<
prior_
.
basis_vectors
.
size
()
<<
"
\n\t
prior_.alpha(0): "
<<
prior_
.
alpha
(
0
)
<<
"
\n\t
this: "
<<
this
);
prior
=
prior_
.
basis_vectors
(
0
);
learn_nonnegative_weights
=
false
;
last_weight_1
=
false
;
}
bool
has_prior
(
)
const
{
return
prior
.
size
()
!=
0
;
}
}
void
set_c
(
void
set_c
(
...
@@ -379,10 +408,30 @@ namespace dlib
...
@@ -379,10 +408,30 @@ namespace dlib
force_weight_1_idx
=
num_dims
-
1
;
force_weight_1_idx
=
num_dims
-
1
;
}
}
if
(
has_prior
())
{
if
(
is_matrix
<
sample_type
>::
value
)
{
// make sure requires clause is not broken
DLIB_CASSERT
(
num_dims
==
(
unsigned
long
)
prior
.
size
(),
"
\t
decision_function svm_rank_trainer::train(samples)"
<<
"
\n\t
The dimension of the training vectors must match the dimension of
\n
"
<<
"
\n\t
those used to create the prior."
<<
"
\n\t
num_dims: "
<<
num_dims
<<
"
\n\t
prior.size(): "
<<
prior
.
size
()
);
}
solver
(
make_oca_problem_ranking_svm
<
w_type
>
(
C
,
samples
,
verbose
,
eps
,
max_iterations
),
w
,
prior
);
}
else
{
solver
(
make_oca_problem_ranking_svm
<
w_type
>
(
C
,
samples
,
verbose
,
eps
,
max_iterations
),
solver
(
make_oca_problem_ranking_svm
<
w_type
>
(
C
,
samples
,
verbose
,
eps
,
max_iterations
),
w
,
w
,
num_nonnegative
,
num_nonnegative
,
force_weight_1_idx
);
force_weight_1_idx
);
}
// put the solution into a decision function and then return it
// put the solution into a decision function and then return it
...
@@ -415,6 +464,7 @@ namespace dlib
...
@@ -415,6 +464,7 @@ namespace dlib
unsigned
long
max_iterations
;
unsigned
long
max_iterations
;
bool
learn_nonnegative_weights
;
bool
learn_nonnegative_weights
;
bool
last_weight_1
;
bool
last_weight_1
;
matrix
<
scalar_type
,
0
,
1
>
prior
;
};
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
dlib/svm/svm_rank_trainer_abstract.h
View file @
754610f2
...
@@ -58,6 +58,7 @@ namespace dlib
...
@@ -58,6 +58,7 @@ namespace dlib
- #get_max_iterations() == 10000
- #get_max_iterations() == 10000
- #learns_nonnegative_weights() == false
- #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false
- #forces_last_weight_to_1() == false
- #has_prior() == false
!*/
!*/
explicit
svm_rank_trainer
(
explicit
svm_rank_trainer
(
...
@@ -76,6 +77,7 @@ namespace dlib
...
@@ -76,6 +77,7 @@ namespace dlib
- #get_max_iterations() == 10000
- #get_max_iterations() == 10000
- #learns_nonnegative_weights() == false
- #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false
- #forces_last_weight_to_1() == false
- #has_prior() == false
!*/
!*/
void
set_epsilon
(
void
set_epsilon
(
...
@@ -146,6 +148,8 @@ namespace dlib
...
@@ -146,6 +148,8 @@ namespace dlib
/*!
/*!
ensures
ensures
- #forces_last_weight_to_1() == should_last_weight_be_1
- #forces_last_weight_to_1() == should_last_weight_be_1
- if (should_last_weight_be_1 == true) then
- #has_prior() == false
!*/
!*/
void
set_oca
(
void
set_oca
(
...
@@ -190,6 +194,39 @@ namespace dlib
...
@@ -190,6 +194,39 @@ namespace dlib
/*!
/*!
ensures
ensures
- #learns_nonnegative_weights() == value
- #learns_nonnegative_weights() == value
- if (value == true) then
- #has_prior() == false
!*/
void
set_prior
(
const
trained_function_type
&
prior
);
/*!
requires
- prior == a function produced by a call to this class's train() function.
Therefore, it must be the case that:
- prior.basis_vectors.size() == 1
- prior.alpha(0) == 1
ensures
- Subsequent calls to train() will try to learn a function similar to the
given prior.
- #has_prior() == true
- #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false
!*/
bool
has_prior
(
)
const
/*!
ensures
- returns true if a prior has been set and false otherwise. Having a prior
set means that you have called set_prior() and supplied a previously
trained function as a reference. In this case, any call to train() will
try to learn a function that matches the behavior of the prior as close
as possible but also fits the supplied training data. In more technical
detail, having a prior means we replace the ||w||^2 regularizer with one
of the form ||w-prior||^2 where w is the set of parameters for a learned
function.
!*/
!*/
void
set_c
(
void
set_c
(
...
@@ -219,6 +256,9 @@ namespace dlib
...
@@ -219,6 +256,9 @@ namespace dlib
/*!
/*!
requires
requires
- is_ranking_problem(samples) == true
- is_ranking_problem(samples) == true
- if (has_prior()) then
- The vectors in samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures
ensures
- trains a ranking support vector classifier given the training samples.
- trains a ranking support vector classifier given the training samples.
- returns a decision function F with the following properties:
- returns a decision function F with the following properties:
...
@@ -237,6 +277,9 @@ namespace dlib
...
@@ -237,6 +277,9 @@ namespace dlib
/*!
/*!
requires
requires
- is_ranking_problem(std::vector<ranking_pair<sample_type> >(1, sample)) == true
- is_ranking_problem(std::vector<ranking_pair<sample_type> >(1, sample)) == true
- if (has_prior()) then
- The vectors in samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures
ensures
- This is just a convenience routine for calling the above train()
- This is just a convenience routine for calling the above train()
function. That is, it just copies sample into a std::vector object and
function. That is, it just copies sample into a std::vector object and
...
...
dlib/test/ranking.cpp
View file @
754610f2
...
@@ -73,6 +73,40 @@ namespace
...
@@ -73,6 +73,40 @@ namespace
}
}
}
}
// ----------------------------------------------------------------------------------------
void
run_prior_test
()
{
print_spinner
();
typedef
matrix
<
double
,
3
,
1
>
sample_type
;
typedef
linear_kernel
<
sample_type
>
kernel_type
;
svm_rank_trainer
<
kernel_type
>
trainer
;
ranking_pair
<
sample_type
>
data
;
sample_type
samp
;
samp
=
0
,
0
,
1
;
data
.
relevant
.
push_back
(
samp
);
samp
=
0
,
1
,
0
;
data
.
nonrelevant
.
push_back
(
samp
);
trainer
.
set_c
(
10
);
decision_function
<
kernel_type
>
df
=
trainer
.
train
(
data
);
trainer
.
set_prior
(
df
);
data
.
relevant
.
clear
();
data
.
nonrelevant
.
clear
();
samp
=
1
,
0
,
0
;
data
.
relevant
.
push_back
(
samp
);
samp
=
0
,
1
,
0
;
data
.
nonrelevant
.
push_back
(
samp
);
df
=
trainer
.
train
(
data
);
dlog
<<
LINFO
<<
trans
(
df
.
basis_vectors
(
0
));
DLIB_TEST
(
df
.
basis_vectors
(
0
)(
0
)
>
0
);
DLIB_TEST
(
df
.
basis_vectors
(
0
)(
1
)
<
0
);
DLIB_TEST
(
df
.
basis_vectors
(
0
)(
2
)
>
0
);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void
dotest1
()
void
dotest1
()
...
@@ -355,6 +389,7 @@ namespace
...
@@ -355,6 +389,7 @@ namespace
dotest_sparse_vectors
();
dotest_sparse_vectors
();
test_svmrank_weight_force_dense
<
true
>
();
test_svmrank_weight_force_dense
<
true
>
();
test_svmrank_weight_force_dense
<
false
>
();
test_svmrank_weight_force_dense
<
false
>
();
run_prior_test
();
}
}
}
a
;
}
a
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment