Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
d9d6fa12
Commit
d9d6fa12
authored
May 01, 2014
by
Davis King
Browse files
Added the ability to set a previously trained function as a prior to the
svm_multiclass_linear_trainer.
parent
a7047b35
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
164 additions
and
8 deletions
+164
-8
dlib/svm/svm_multiclass_linear_trainer.h
dlib/svm/svm_multiclass_linear_trainer.h
+73
-8
dlib/svm/svm_multiclass_linear_trainer_abstract.h
dlib/svm/svm_multiclass_linear_trainer_abstract.h
+32
-0
dlib/test/svm_multiclass_linear.cpp
dlib/test/svm_multiclass_linear.cpp
+59
-0
No files found.
dlib/svm/svm_multiclass_linear_trainer.h
View file @
d9d6fa12
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "../matrix.h"
#include "../matrix.h"
#include "sparse_vector.h"
#include "sparse_vector.h"
#include "function.h"
#include "function.h"
#include <algorithm>
namespace
dlib
namespace
dlib
{
{
...
@@ -46,13 +47,15 @@ namespace dlib
...
@@ -46,13 +47,15 @@ namespace dlib
multiclass_svm_problem
(
multiclass_svm_problem
(
const
std
::
vector
<
sample_type
>&
samples_
,
const
std
::
vector
<
sample_type
>&
samples_
,
const
std
::
vector
<
label_type
>&
labels_
,
const
std
::
vector
<
label_type
>&
labels_
,
const
std
::
vector
<
label_type
>&
distinct_labels_
,
const
unsigned
long
dims_
,
const
unsigned
long
num_threads
const
unsigned
long
num_threads
)
:
)
:
structural_svm_problem_threaded
<
matrix_type
,
std
::
vector
<
std
::
pair
<
unsigned
long
,
typename
matrix_type
::
type
>
>
>
(
num_threads
),
structural_svm_problem_threaded
<
matrix_type
,
std
::
vector
<
std
::
pair
<
unsigned
long
,
typename
matrix_type
::
type
>
>
>
(
num_threads
),
samples
(
samples_
),
samples
(
samples_
),
labels
(
labels_
),
labels
(
labels_
),
distinct_labels
(
select_all_
distinct_labels
(
labels_
)
),
distinct_labels
(
distinct_labels
_
),
dims
(
max_index_plus_one
(
sample
s_
)
+
1
)
// +1 for the bias
dims
(
dim
s_
+
1
)
// +1 for the bias
{}
{}
virtual
long
get_num_dimensions
(
virtual
long
get_num_dimensions
(
...
@@ -151,7 +154,7 @@ namespace dlib
...
@@ -151,7 +154,7 @@ namespace dlib
const
std
::
vector
<
sample_type
>&
samples
;
const
std
::
vector
<
sample_type
>&
samples
;
const
std
::
vector
<
label_type
>&
labels
;
const
std
::
vector
<
label_type
>&
labels
;
const
std
::
vector
<
label_type
>
distinct_labels
;
const
std
::
vector
<
label_type
>
&
distinct_labels
;
const
long
dims
;
const
long
dims
;
};
};
...
@@ -260,6 +263,7 @@ namespace dlib
...
@@ -260,6 +263,7 @@ namespace dlib
)
)
{
{
learn_nonnegative_weights
=
value
;
learn_nonnegative_weights
=
value
;
prior
=
trained_function_type
();
}
}
void
set_c
(
void
set_c
(
...
@@ -283,6 +287,20 @@ namespace dlib
...
@@ -283,6 +287,20 @@ namespace dlib
return
C
;
return
C
;
}
}
void
set_prior
(
const
trained_function_type
&
prior_
)
{
prior
=
prior_
;
learn_nonnegative_weights
=
false
;
}
bool
has_prior
(
)
const
{
return
prior
.
labels
.
size
()
!=
0
;
}
trained_function_type
train
(
trained_function_type
train
(
const
std
::
vector
<
sample_type
>&
all_samples
,
const
std
::
vector
<
sample_type
>&
all_samples
,
const
std
::
vector
<
label_type
>&
all_labels
const
std
::
vector
<
label_type
>&
all_labels
...
@@ -306,9 +324,33 @@ namespace dlib
...
@@ -306,9 +324,33 @@ namespace dlib
<<
"
\n\t
all_labels.size(): "
<<
all_labels
.
size
()
<<
"
\n\t
all_labels.size(): "
<<
all_labels
.
size
()
);
);
trained_function_type
df
;
df
.
labels
=
select_all_distinct_labels
(
all_labels
);
if
(
has_prior
())
{
df
.
labels
.
insert
(
df
.
labels
.
end
(),
prior
.
labels
.
begin
(),
prior
.
labels
.
end
());
df
.
labels
=
select_all_distinct_labels
(
df
.
labels
);
}
const
long
input_sample_dimensionality
=
max_index_plus_one
(
all_samples
);
// If the samples are sparse then the right thing to do is to take the max
// dimensionality between the prior and the new samples. But if the samples
// are dense vectors then they definitely all have to have exactly the same
// dimensionality.
const
long
dims
=
std
::
max
(
df
.
weights
.
nc
(),
input_sample_dimensionality
);
if
(
is_matrix
<
sample_type
>::
value
&&
has_prior
())
{
DLIB_ASSERT
(
input_sample_dimensionality
==
prior
.
weights
.
nc
(),
"
\t
trained_function_type svm_multiclass_linear_trainer::train(all_samples,all_labels)"
<<
"
\n\t
The training samples given to this function are not the same kind of training "
<<
"
\n\t
samples used to create the prior."
<<
"
\n\t
input_sample_dimensionality: "
<<
input_sample_dimensionality
<<
"
\n\t
prior.weights.nc(): "
<<
prior
.
weights
.
nc
()
);
}
typedef
matrix
<
scalar_type
,
0
,
1
>
w_type
;
typedef
matrix
<
scalar_type
,
0
,
1
>
w_type
;
w_type
weights
;
w_type
weights
;
multiclass_svm_problem
<
w_type
,
sample_type
,
label_type
>
problem
(
all_samples
,
all_labels
,
num_threads
);
multiclass_svm_problem
<
w_type
,
sample_type
,
label_type
>
problem
(
all_samples
,
all_labels
,
df
.
labels
,
dims
,
num_threads
);
if
(
verbose
)
if
(
verbose
)
problem
.
be_verbose
();
problem
.
be_verbose
();
...
@@ -322,12 +364,33 @@ namespace dlib
...
@@ -322,12 +364,33 @@ namespace dlib
num_nonnegative
=
problem
.
get_num_dimensions
();
num_nonnegative
=
problem
.
get_num_dimensions
();
}
}
svm_objective
=
solver
(
problem
,
weights
,
num_nonnegative
);
if
(
!
has_prior
())
{
svm_objective
=
solver
(
problem
,
weights
,
num_nonnegative
);
}
else
{
matrix
<
scalar_type
>
temp
(
df
.
labels
.
size
(),
dims
);
w_type
b
(
df
.
labels
.
size
());
temp
=
0
;
b
=
0
;
// Copy the prior into the temp and b matrices. We have to do this row
// by row copy because the new training data might have new labels we
// haven't seen before and therefore the sizes of these matrices could be
// different.
for
(
unsigned
long
i
=
0
;
i
<
prior
.
labels
.
size
();
++
i
)
{
const
long
r
=
std
::
find
(
df
.
labels
.
begin
(),
df
.
labels
.
end
(),
prior
.
labels
[
i
])
-
df
.
labels
.
begin
();
set_rowm
(
temp
,
r
)
=
rowm
(
prior
.
weights
,
i
);
b
(
r
)
=
prior
.
b
(
i
);
}
const
w_type
prior_vect
=
reshape_to_column_vector
(
join_rows
(
temp
,
b
));
svm_objective
=
solver
(
problem
,
weights
,
prior_vect
);
}
trained_function_type
df
;
const
long
dims
=
max_index_plus_one
(
all_samples
);
df
.
labels
=
select_all_distinct_labels
(
all_labels
);
df
.
weights
=
colm
(
reshape
(
weights
,
df
.
labels
.
size
(),
dims
+
1
),
range
(
0
,
dims
-
1
));
df
.
weights
=
colm
(
reshape
(
weights
,
df
.
labels
.
size
(),
dims
+
1
),
range
(
0
,
dims
-
1
));
df
.
b
=
colm
(
reshape
(
weights
,
df
.
labels
.
size
(),
dims
+
1
),
dims
);
df
.
b
=
colm
(
reshape
(
weights
,
df
.
labels
.
size
(),
dims
+
1
),
dims
);
return
df
;
return
df
;
...
@@ -341,6 +404,8 @@ namespace dlib
...
@@ -341,6 +404,8 @@ namespace dlib
bool
verbose
;
bool
verbose
;
oca
solver
;
oca
solver
;
bool
learn_nonnegative_weights
;
bool
learn_nonnegative_weights
;
trained_function_type
prior
;
};
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
dlib/svm/svm_multiclass_linear_trainer_abstract.h
View file @
d9d6fa12
...
@@ -37,6 +37,7 @@ namespace dlib
...
@@ -37,6 +37,7 @@ namespace dlib
- get_c() == 1
- get_c() == 1
- this object will not be verbose unless be_verbose() is called
- this object will not be verbose unless be_verbose() is called
- #get_oca() == oca() (i.e. an instance of oca with default parameters)
- #get_oca() == oca() (i.e. an instance of oca with default parameters)
- has_prior() == false
WHAT THIS OBJECT REPRESENTS
WHAT THIS OBJECT REPRESENTS
This object represents a tool for training a multiclass support
This object represents a tool for training a multiclass support
...
@@ -176,6 +177,29 @@ namespace dlib
...
@@ -176,6 +177,29 @@ namespace dlib
- #learns_nonnegative_weights() == value
- #learns_nonnegative_weights() == value
!*/
!*/
void
set_prior
(
const
trained_function_type
&
prior
);
/*!
ensures
- #has_prior() == true
- #learns_nonnegative_weights() == false
!*/
bool
has_prior
(
)
const
/*!
ensures
- returns true if a prior has been set and false otherwise. Having a prior
set means that you have called set_prior() and supplied a previously
trained function as a reference. In this case, any call to train() will
try to learn a function that matches the behavior of the prior as close
as possible but also fits the supplied training data. In more technical
detail, having a prior means we replace the ||w||^2 regularizer with one
of the form ||w-prior||^2 where w is the set of parameters for a learned
function.
!*/
trained_function_type
train
(
trained_function_type
train
(
const
std
::
vector
<
sample_type
>&
all_samples
,
const
std
::
vector
<
sample_type
>&
all_samples
,
const
std
::
vector
<
label_type
>&
all_labels
const
std
::
vector
<
label_type
>&
all_labels
...
@@ -183,6 +207,10 @@ namespace dlib
...
@@ -183,6 +207,10 @@ namespace dlib
/*!
/*!
requires
requires
- is_learning_problem(all_samples, all_labels)
- is_learning_problem(all_samples, all_labels)
- All the vectors in all_samples must have the same dimensionality.
- if (has_prior()) then
- The vectors in all_samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures
ensures
- trains a multiclass SVM to solve the given multiclass classification problem.
- trains a multiclass SVM to solve the given multiclass classification problem.
- returns a multiclass_linear_decision_function F with the following properties:
- returns a multiclass_linear_decision_function F with the following properties:
...
@@ -200,6 +228,10 @@ namespace dlib
...
@@ -200,6 +228,10 @@ namespace dlib
/*!
/*!
requires
requires
- is_learning_problem(all_samples, all_labels)
- is_learning_problem(all_samples, all_labels)
- All the vectors in all_samples must have the same dimensionality.
- if (has_prior()) then
- The vectors in all_samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures
ensures
- trains a multiclass SVM to solve the given multiclass classification problem.
- trains a multiclass SVM to solve the given multiclass classification problem.
- returns a multiclass_linear_decision_function F with the following properties:
- returns a multiclass_linear_decision_function F with the following properties:
...
...
dlib/test/svm_multiclass_linear.cpp
View file @
d9d6fa12
...
@@ -35,6 +35,63 @@ namespace
...
@@ -35,6 +35,63 @@ namespace
}
}
void
test_prior
()
{
print_spinner
();
typedef
matrix
<
double
,
4
,
1
>
sample_type
;
typedef
linear_kernel
<
sample_type
>
kernel_type
;
std
::
vector
<
sample_type
>
samples
;
std
::
vector
<
int
>
labels
;
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
if
(
i
==
2
)
++
i
;
for
(
int
iter
=
0
;
iter
<
5
;
++
iter
)
{
sample_type
samp
;
samp
=
0
;
samp
(
i
)
=
1
;
samples
.
push_back
(
samp
);
labels
.
push_back
(
i
);
}
}
svm_multiclass_linear_trainer
<
kernel_type
,
int
>
trainer
;
multiclass_linear_decision_function
<
kernel_type
,
int
>
df
=
trainer
.
train
(
samples
,
labels
);
//cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl;
//cout << df.weights << endl;
//cout << df.b << endl;
std
::
vector
<
sample_type
>
samples2
;
std
::
vector
<
int
>
labels2
;
int
i
=
2
;
for
(
int
iter
=
0
;
iter
<
5
;
++
iter
)
{
sample_type
samp
;
samp
=
0
;
samp
(
i
)
=
1
;
samples2
.
push_back
(
samp
);
labels2
.
push_back
(
i
);
samples
.
push_back
(
samp
);
labels
.
push_back
(
i
);
}
trainer
.
set_prior
(
df
);
trainer
.
set_c
(
0.1
);
df
=
trainer
.
train
(
samples2
,
labels2
);
matrix
<
double
>
res
=
test_multiclass_decision_function
(
df
,
samples
,
labels
);
dlog
<<
LINFO
<<
"test:
\n
"
<<
res
;
dlog
<<
LINFO
<<
df
.
weights
;
dlog
<<
LINFO
<<
df
.
b
;
DLIB_TEST
((
unsigned
int
)
sum
(
diag
(
res
))
==
samples
.
size
());
}
template
<
typename
sample_type
>
template
<
typename
sample_type
>
void
run_test
()
void
run_test
()
{
{
...
@@ -99,6 +156,8 @@ namespace
...
@@ -99,6 +156,8 @@ namespace
run_test
<
std
::
map
<
unsigned
int
,
float
>
>
();
run_test
<
std
::
map
<
unsigned
int
,
float
>
>
();
run_test
<
std
::
vector
<
std
::
pair
<
unsigned
int
,
float
>
>
>
();
run_test
<
std
::
vector
<
std
::
pair
<
unsigned
int
,
float
>
>
>
();
run_test
<
std
::
vector
<
std
::
pair
<
unsigned
long
,
double
>
>
>
();
run_test
<
std
::
vector
<
std
::
pair
<
unsigned
long
,
double
>
>
>
();
test_prior
();
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment