Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
859ccf5e
Commit
859ccf5e
authored
Apr 27, 2013
by
Davis King
Browse files
Added some cross validation wrappers.
parent
b8f2b522
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
98 additions
and
43 deletions
+98
-43
tools/python/src/decision_functions.cpp
tools/python/src/decision_functions.cpp
+1
-36
tools/python/src/other.cpp
tools/python/src/other.cpp
+11
-6
tools/python/src/svm_c_trainer.cpp
tools/python/src/svm_c_trainer.cpp
+43
-1
tools/python/src/testing_results.h
tools/python/src/testing_results.h
+43
-0
No files found.
tools/python/src/decision_functions.cpp
View file @
859ccf5e
#include "testing_results.h"
#include <boost/python.hpp>
#include <boost/python.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/shared_ptr.hpp>
#include "serialize_pickle.h"
#include "serialize_pickle.h"
...
@@ -100,18 +101,6 @@ void add_linear_df (
...
@@ -100,18 +101,6 @@ void add_linear_df (
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
struct
binary_test
{
binary_test
()
:
class1_accuracy
(
0
),
class2_accuracy
(
0
)
{}
binary_test
(
const
matrix
<
double
,
1
,
2
>&
m
)
:
class1_accuracy
(
m
(
0
)),
class2_accuracy
(
m
(
1
))
{}
double
class1_accuracy
;
double
class2_accuracy
;
};
std
::
string
binary_test__str__
(
const
binary_test
&
item
)
std
::
string
binary_test__str__
(
const
binary_test
&
item
)
{
{
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
...
@@ -120,18 +109,6 @@ std::string binary_test__str__(const binary_test& item)
...
@@ -120,18 +109,6 @@ std::string binary_test__str__(const binary_test& item)
}
}
std
::
string
binary_test__repr__
(
const
binary_test
&
item
)
{
return
"< "
+
binary_test__str__
(
item
)
+
" >"
;}
std
::
string
binary_test__repr__
(
const
binary_test
&
item
)
{
return
"< "
+
binary_test__str__
(
item
)
+
" >"
;}
struct
regression_test
{
regression_test
()
:
mean_squared_error
(
0
),
R_squared
(
0
)
{}
regression_test
(
const
matrix
<
double
,
1
,
2
>&
m
)
:
mean_squared_error
(
m
(
0
)),
R_squared
(
m
(
1
))
{}
double
mean_squared_error
;
double
R_squared
;
};
std
::
string
regression_test__str__
(
const
regression_test
&
item
)
std
::
string
regression_test__str__
(
const
regression_test
&
item
)
{
{
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
...
@@ -140,18 +117,6 @@ std::string regression_test__str__(const regression_test& item)
...
@@ -140,18 +117,6 @@ std::string regression_test__str__(const regression_test& item)
}
}
std
::
string
regression_test__repr__
(
const
regression_test
&
item
)
{
return
"< "
+
regression_test__str__
(
item
)
+
" >"
;}
std
::
string
regression_test__repr__
(
const
regression_test
&
item
)
{
return
"< "
+
regression_test__str__
(
item
)
+
" >"
;}
struct
ranking_test
{
ranking_test
()
:
ranking_accuracy
(
0
),
mean_ap
(
0
)
{}
ranking_test
(
const
matrix
<
double
,
1
,
2
>&
m
)
:
ranking_accuracy
(
m
(
0
)),
mean_ap
(
m
(
1
))
{}
double
ranking_accuracy
;
double
mean_ap
;
};
std
::
string
ranking_test__str__
(
const
ranking_test
&
item
)
std
::
string
ranking_test__str__
(
const
ranking_test
&
item
)
{
{
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
...
...
tools/python/src/other.cpp
View file @
859ccf5e
...
@@ -13,12 +13,17 @@ tuple get_training_data()
...
@@ -13,12 +13,17 @@ tuple get_training_data()
std
::
vector
<
double
>
labels
;
std
::
vector
<
double
>
labels
;
sample_type
samp
(
3
);
sample_type
samp
(
3
);
samp
=
1
,
2
,
3
;
samples
.
push_back
(
samp
);
for
(
int
i
=
0
;
i
<
10
;
++
i
)
labels
.
push_back
(
+
1
);
{
samp
=
-
1
,
-
2
,
-
3
;
samp
=
1
,
2
,
3
;
samples
.
push_back
(
samp
);
samples
.
push_back
(
samp
);
labels
.
push_back
(
-
1
);
labels
.
push_back
(
+
1
);
samp
=
-
1
,
-
2
,
-
3
;
samples
.
push_back
(
samp
);
labels
.
push_back
(
-
1
);
}
return
make_tuple
(
samples
,
labels
);
return
make_tuple
(
samples
,
labels
);
}
}
...
...
tools/python/src/svm_c_trainer.cpp
View file @
859ccf5e
#include "testing_results.h"
#include <boost/python.hpp>
#include <boost/python.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/shared_ptr.hpp>
#include <dlib/matrix.h>
#include <dlib/matrix.h>
#include "serialize_pickle.h"
#include "serialize_pickle.h"
#include <dlib/svm.h>
#include <dlib/svm
_threaded
.h>
#include "pyassert.h"
#include "pyassert.h"
using
namespace
dlib
;
using
namespace
dlib
;
...
@@ -118,6 +119,39 @@ double get_gamma_sparse (
...
@@ -118,6 +119,39 @@ double get_gamma_sparse (
return
trainer
.
get_kernel
().
gamma
;
return
trainer
.
get_kernel
().
gamma
;
}
}
// ----------------------------------------------------------------------------------------
template
<
typename
trainer_type
>
const
binary_test
_cross_validate_trainer
(
const
trainer_type
&
trainer
,
const
std
::
vector
<
typename
trainer_type
::
sample_type
>&
x
,
const
std
::
vector
<
double
>&
y
,
const
long
folds
)
{
pyassert
(
is_binary_classification_problem
(
x
,
y
),
"Training data does not make a valid training set."
);
pyassert
(
1
<
folds
&&
folds
<=
x
.
size
(),
"Invalid number of folds given."
);
return
cross_validate_trainer
(
trainer
,
x
,
y
,
folds
);
}
template
<
typename
trainer_type
>
const
binary_test
_cross_validate_trainer_t
(
const
trainer_type
&
trainer
,
const
std
::
vector
<
typename
trainer_type
::
sample_type
>&
x
,
const
std
::
vector
<
double
>&
y
,
const
unsigned
long
folds
,
const
unsigned
long
num_threads
)
{
pyassert
(
is_binary_classification_problem
(
x
,
y
),
"Training data does not make a valid training set."
);
pyassert
(
1
<
folds
&&
folds
<=
x
.
size
(),
"Invalid number of folds given."
);
pyassert
(
1
<
num_threads
,
"The number of threads specified must not be zero."
);
return
cross_validate_trainer_threaded
(
trainer
,
x
,
y
,
folds
,
num_threads
);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
@@ -125,13 +159,21 @@ void bind_svm_c_trainer()
...
@@ -125,13 +159,21 @@ void bind_svm_c_trainer()
{
{
setup_trainer
<
svm_c_trainer
<
radial_basis_kernel
<
sample_type
>
>
>
(
"svm_c_trainer_radial_basis"
)
setup_trainer
<
svm_c_trainer
<
radial_basis_kernel
<
sample_type
>
>
>
(
"svm_c_trainer_radial_basis"
)
.
add_property
(
"gamma"
,
get_gamma
,
set_gamma
);
.
add_property
(
"gamma"
,
get_gamma
,
set_gamma
);
def
(
"cross_validate_trainer"
,
_cross_validate_trainer
<
svm_c_trainer
<
radial_basis_kernel
<
sample_type
>
>
>
);
def
(
"cross_validate_trainer_threaded"
,
_cross_validate_trainer_t
<
svm_c_trainer
<
radial_basis_kernel
<
sample_type
>
>
>
);
setup_trainer
<
svm_c_trainer
<
sparse_radial_basis_kernel
<
sparse_vect
>
>
>
(
"svm_c_trainer_sparse_radial_basis"
)
setup_trainer
<
svm_c_trainer
<
sparse_radial_basis_kernel
<
sparse_vect
>
>
>
(
"svm_c_trainer_sparse_radial_basis"
)
.
add_property
(
"gamma"
,
get_gamma
,
set_gamma
);
.
add_property
(
"gamma"
,
get_gamma
,
set_gamma
);
def
(
"cross_validate_trainer"
,
_cross_validate_trainer
<
svm_c_trainer
<
sparse_radial_basis_kernel
<
sparse_vect
>
>
>
);
def
(
"cross_validate_trainer_threaded"
,
_cross_validate_trainer_t
<
svm_c_trainer
<
sparse_radial_basis_kernel
<
sparse_vect
>
>
>
);
setup_trainer
<
svm_c_trainer
<
histogram_intersection_kernel
<
sample_type
>
>
>
(
"svm_c_trainer_histogram_intersection"
);
setup_trainer
<
svm_c_trainer
<
histogram_intersection_kernel
<
sample_type
>
>
>
(
"svm_c_trainer_histogram_intersection"
);
def
(
"cross_validate_trainer"
,
_cross_validate_trainer
<
svm_c_trainer
<
histogram_intersection_kernel
<
sample_type
>
>
>
);
def
(
"cross_validate_trainer_threaded"
,
_cross_validate_trainer_t
<
svm_c_trainer
<
histogram_intersection_kernel
<
sample_type
>
>
>
);
setup_trainer
<
svm_c_trainer
<
sparse_histogram_intersection_kernel
<
sparse_vect
>
>
>
(
"svm_c_trainer_sparse_histogram_intersection"
);
setup_trainer
<
svm_c_trainer
<
sparse_histogram_intersection_kernel
<
sparse_vect
>
>
>
(
"svm_c_trainer_sparse_histogram_intersection"
);
def
(
"cross_validate_trainer"
,
_cross_validate_trainer
<
svm_c_trainer
<
sparse_histogram_intersection_kernel
<
sparse_vect
>
>
>
);
def
(
"cross_validate_trainer_threaded"
,
_cross_validate_trainer_t
<
svm_c_trainer
<
sparse_histogram_intersection_kernel
<
sparse_vect
>
>
>
);
}
}
tools/python/src/testing_results.h
0 → 100644
View file @
859ccf5e
#ifndef DLIB_TESTING_ReSULTS_H__
#define DLIB_TESTING_ReSULTS_H__
#include <dlib/matrix.h>
struct
binary_test
{
binary_test
()
:
class1_accuracy
(
0
),
class2_accuracy
(
0
)
{}
binary_test
(
const
dlib
::
matrix
<
double
,
1
,
2
>&
m
)
:
class1_accuracy
(
m
(
0
)),
class2_accuracy
(
m
(
1
))
{}
double
class1_accuracy
;
double
class2_accuracy
;
};
struct
regression_test
{
regression_test
()
:
mean_squared_error
(
0
),
R_squared
(
0
)
{}
regression_test
(
const
dlib
::
matrix
<
double
,
1
,
2
>&
m
)
:
mean_squared_error
(
m
(
0
)),
R_squared
(
m
(
1
))
{}
double
mean_squared_error
;
double
R_squared
;
};
struct
ranking_test
{
ranking_test
()
:
ranking_accuracy
(
0
),
mean_ap
(
0
)
{}
ranking_test
(
const
dlib
::
matrix
<
double
,
1
,
2
>&
m
)
:
ranking_accuracy
(
m
(
0
)),
mean_ap
(
m
(
1
))
{}
double
ranking_accuracy
;
double
mean_ap
;
};
#endif // DLIB_TESTING_ReSULTS_H__
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment