Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
a0c3c224
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "e03637176d90cc5e298e13dfd5e583b2989b3aee"
Commit
a0c3c224
authored
Apr 30, 2012
by
Davis King
Browse files
Added validation functions for graph labeling problems.
parent
aa8f3f2b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
177 additions
and
0 deletions
+177
-0
dlib/svm/cross_validate_graph_labeling_trainer.h
dlib/svm/cross_validate_graph_labeling_trainer.h
+175
-0
dlib/svm/cross_validate_graph_labeling_trainer_abstract.h
dlib/svm/cross_validate_graph_labeling_trainer_abstract.h
+2
-0
No files found.
dlib/svm/cross_validate_graph_labeling_trainer.h
0 → 100644
View file @
a0c3c224
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_H__
#define DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_H__
#include "../array.h"
#include "../graph_cuts/min_cut.h"
#include "svm.h"
#include "cross_validate_graph_labeling_trainer_abstract.h"
namespace
dlib
{
// ----------------------------------------------------------------------------------------
template
<
typename
graph_labeler
,
typename
graph_type
>
matrix
<
double
,
1
,
2
>
test_graph_labeling_function
(
const
graph_labeler
&
labeler
,
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
std
::
vector
<
node_label
>
>&
labels
)
{
DLIB_ASSERT
(
is_graph_labeling_problem
(
samples
,
labels
)
,
"
\t
matrix test_graph_labeling_function()"
<<
"
\n\t
invalid inputs were given to this function"
<<
"
\n\t
samples.size(): "
<<
samples
.
size
()
<<
"
\n\t
is_graph_labeling_problem(samples,labels): "
<<
is_graph_labeling_problem
(
samples
,
labels
)
<<
"
\n\t
is_learning_problem(samples,labels): "
<<
is_learning_problem
(
samples
,
labels
)
);
std
::
vector
<
node_label
>
temp
;
unsigned
long
num_pos_correct
=
0
;
unsigned
long
num_pos
=
0
;
unsigned
long
num_neg_correct
=
0
;
unsigned
long
num_neg
=
0
;
for
(
unsigned
long
i
=
0
;
i
<
samples
.
size
();
++
i
)
{
labeler
(
samples
[
i
],
temp
);
for
(
unsigned
long
j
=
0
;
j
<
labels
[
i
].
size
();
++
j
)
{
if
(
labels
[
i
][
j
])
{
++
num_pos
;
if
(
temp
[
j
])
++
num_pos_correct
;
}
else
{
++
num_neg
;
if
(
!
temp
[
j
])
++
num_neg_correct
;
}
}
}
matrix
<
double
,
1
,
2
>
res
;
res
(
0
)
=
(
double
)
num_pos_correct
/
(
double
)(
num_pos
);
res
(
1
)
=
(
double
)
num_neg_correct
/
(
double
)(
num_neg
);
return
res
;
}
// ----------------------------------------------------------------------------------------
template
<
typename
trainer_type
,
typename
graph_type
>
matrix
<
double
,
1
,
2
>
cross_validate_graph_labeling_trainer
(
const
trainer_type
&
trainer
,
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
std
::
vector
<
node_label
>
>&
labels
,
const
long
folds
)
{
DLIB_ASSERT
(
is_graph_labeling_problem
(
samples
,
labels
)
&&
1
<
folds
&&
folds
<=
static_cast
<
long
>
(
samples
.
size
()),
"
\t
matrix cross_validate_graph_labeling_trainer()"
<<
"
\n\t
invalid inputs were given to this function"
<<
"
\n\t
samples.size(): "
<<
samples
.
size
()
<<
"
\n\t
folds: "
<<
folds
<<
"
\n\t
is_graph_labeling_problem(samples,labels): "
<<
is_graph_labeling_problem
(
samples
,
labels
)
<<
"
\n\t
is_learning_problem(samples,labels): "
<<
is_learning_problem
(
samples
,
labels
)
);
typedef
std
::
vector
<
node_label
>
label_type
;
const
long
num_in_test
=
samples
.
size
()
/
folds
;
const
long
num_in_train
=
samples
.
size
()
-
num_in_test
;
dlib
::
array
<
graph_type
>
samples_test
,
samples_train
;
std
::
vector
<
label_type
>
labels_test
,
labels_train
;
long
next_test_idx
=
0
;
std
::
vector
<
node_label
>
temp
;
unsigned
long
num_pos_correct
=
0
;
unsigned
long
num_pos
=
0
;
unsigned
long
num_neg_correct
=
0
;
unsigned
long
num_neg
=
0
;
graph_type
gtemp
;
for
(
long
i
=
0
;
i
<
folds
;
++
i
)
{
samples_test
.
clear
();
labels_test
.
clear
();
samples_train
.
clear
();
labels_train
.
clear
();
// load up the test samples
for
(
long
cnt
=
0
;
cnt
<
num_in_test
;
++
cnt
)
{
copy_graph
(
samples
[
next_test_idx
],
gtemp
);
samples_test
.
push_back
(
gtemp
);
labels_test
.
push_back
(
labels
[
next_test_idx
]);
next_test_idx
=
(
next_test_idx
+
1
)
%
samples
.
size
();
}
// load up the training samples
long
next
=
next_test_idx
;
for
(
long
cnt
=
0
;
cnt
<
num_in_train
;
++
cnt
)
{
copy_graph
(
samples
[
next
],
gtemp
);
samples_train
.
push_back
(
gtemp
);
labels_train
.
push_back
(
labels
[
next
]);
next
=
(
next
+
1
)
%
samples
.
size
();
}
const
typename
trainer_type
::
trained_function_type
&
labeler
=
trainer
.
train
(
samples_train
,
labels_train
);
// check how good labeler is on the test data
for
(
unsigned
long
i
=
0
;
i
<
samples_test
.
size
();
++
i
)
{
labeler
(
samples_test
[
i
],
temp
);
for
(
unsigned
long
j
=
0
;
j
<
labels_test
[
i
].
size
();
++
j
)
{
if
(
labels_test
[
i
][
j
])
{
++
num_pos
;
if
(
temp
[
j
])
++
num_pos_correct
;
}
else
{
++
num_neg
;
if
(
!
temp
[
j
])
++
num_neg_correct
;
}
}
}
}
// for (long i = 0; i < folds; ++i)
matrix
<
double
,
1
,
2
>
res
;
res
(
0
)
=
(
double
)
num_pos_correct
/
(
double
)(
num_pos
);
res
(
1
)
=
(
double
)
num_neg_correct
/
(
double
)(
num_neg
);
return
res
;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_H__
dlib/svm/cross_validate_graph_labeling_trainer_abstract.h
0 → 100644
View file @
a0c3c224
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment