Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
c23ef609
Commit
c23ef609
authored
Nov 04, 2011
by
Davis King
Browse files
added more tests
parent
f18acdf8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
137 additions
and
77 deletions
+137
-77
dlib/test/sequence_labeler.cpp
dlib/test/sequence_labeler.cpp
+137
-77
No files found.
dlib/test/sequence_labeler.cpp
View file @
c23ef609
...
...
@@ -62,8 +62,58 @@ namespace
}
};
bool
called_rejct_labeling
=
false
;
class
feature_extractor2
{
public:
typedef
unsigned
long
sample_type
;
unsigned
long
num_features
()
const
{
return
num_label_states
*
num_label_states
+
num_label_states
*
num_sample_states
;
}
unsigned
long
order
()
const
{
return
1
;
}
unsigned
long
num_labels
()
const
{
return
num_label_states
;
}
template
<
typename
EXP
>
bool
reject_labeling
(
const
std
::
vector
<
sample_type
>&
x
,
const
matrix_exp
<
EXP
>&
y
,
unsigned
long
position
)
const
{
called_rejct_labeling
=
true
;
return
false
;
}
template
<
typename
feature_setter
,
typename
EXP
>
void
get_features
(
feature_setter
&
set_feature
,
const
std
::
vector
<
sample_type
>&
x
,
const
matrix_exp
<
EXP
>&
y
,
unsigned
long
position
)
const
{
if
(
y
.
size
()
>
1
)
set_feature
(
y
(
1
)
*
num_label_states
+
y
(
0
));
set_feature
(
num_label_states
*
num_label_states
+
y
(
0
)
*
num_sample_states
+
x
[
position
]);
}
};
void
serialize
(
const
feature_extractor
&
,
std
::
ostream
&
)
{}
void
deserialize
(
feature_extractor
&
,
std
::
istream
&
)
{}
void
serialize
(
const
feature_extractor2
&
,
std
::
ostream
&
)
{}
void
deserialize
(
feature_extractor2
&
,
std
::
istream
&
)
{}
// ----------------------------------------------------------------------------------------
...
...
@@ -174,114 +224,124 @@ namespace
// ----------------------------------------------------------------------------------------
class
sequence_labeler_tester
:
public
tester
template
<
typename
fe_type
>
void
do_test
()
{
public:
sequence_labeler_tester
(
)
:
tester
(
"test_sequence_labeler"
,
"Runs tests on the sequence labeling code."
)
{}
matrix
<
double
>
transition_probabilities
(
num_label_states
,
num_label_states
);
transition_probabilities
=
0.05
,
0.90
,
0.05
,
0.05
,
0.05
,
0.90
,
0.90
,
0.05
,
0.05
;
void
perform_test
(
)
{
matrix
<
double
>
transition_probabilities
(
num_label_states
,
num_label_states
);
transition_probabilities
=
0.05
,
0.90
,
0.05
,
0.05
,
0.05
,
0.90
,
0.90
,
0.05
,
0.05
;
matrix
<
double
>
emission_probabilities
(
num_label_states
,
num_sample_states
);
emission_probabilities
=
0.5
,
0.5
,
0.0
,
0.0
,
0.5
,
0.5
,
0.5
,
0.0
,
0.5
;
print_spinner
();
matrix
<
double
>
emission_probabilities
(
num_label_states
,
num_sample_states
);
emission_probabilities
=
0.5
,
0.5
,
0.0
,
0.0
,
0.5
,
0.5
,
0.5
,
0.0
,
0.5
;
print_spinner
();
std
::
vector
<
std
::
vector
<
unsigned
long
>
>
samples
;
std
::
vector
<
std
::
vector
<
unsigned
long
>
>
labels
;
make_dataset
(
transition_probabilities
,
emission_probabilities
,
samples
,
labels
,
1000
);
dlog
<<
LINFO
<<
"samples.size(): "
<<
samples
.
size
();
std
::
vector
<
std
::
vector
<
unsigned
long
>
>
samples
;
std
::
vector
<
std
::
vector
<
unsigned
long
>
>
labels
;
make_dataset
(
transition_probabilities
,
emission_probabilities
,
samples
,
labels
,
1000
);
// print out some of the randomly sampled sequences
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
dlog
<<
LINFO
<<
"hidden states: "
<<
trans
(
vector_to_matrix
(
labels
[
i
]));
dlog
<<
LINFO
<<
"observed states: "
<<
trans
(
vector_to_matrix
(
samples
[
i
]));
dlog
<<
LINFO
<<
"******************************"
;
}
dlog
<<
LINFO
<<
"samples.size(): "
<<
samples
.
size
();
print_spinner
();
structural_sequence_labeling_trainer
<
fe_type
>
trainer
;
trainer
.
set_c
(
4
);
DLIB_TEST
(
trainer
.
get_c
()
==
4
);
trainer
.
set_num_threads
(
4
);
DLIB_TEST
(
trainer
.
get_num_threads
()
==
4
);
// print out some of the randomly sampled sequences
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
dlog
<<
LINFO
<<
"hidden states: "
<<
trans
(
vector_to_matrix
(
labels
[
i
]));
dlog
<<
LINFO
<<
"observed states: "
<<
trans
(
vector_to_matrix
(
samples
[
i
]));
dlog
<<
LINFO
<<
"******************************"
;
}
print_spinner
();
structural_sequence_labeling_trainer
<
feature_extractor
>
trainer
;
trainer
.
set_c
(
4
);
DLIB_TEST
(
trainer
.
get_c
()
==
4
);
trainer
.
set_num_threads
(
4
);
DLIB_TEST
(
trainer
.
get_num_threads
()
==
4
);
// Learn to do sequence labeling from the dataset
sequence_labeler
<
fe_type
>
labeler
=
trainer
.
train
(
samples
,
labels
);
std
::
vector
<
unsigned
long
>
predicted_labels
=
labeler
(
samples
[
0
]);
dlog
<<
LINFO
<<
"true hidden states: "
<<
trans
(
vector_to_matrix
(
labels
[
0
]));
dlog
<<
LINFO
<<
"predicted hidden states: "
<<
trans
(
vector_to_matrix
(
predicted_labels
));
// Learn to do sequence labeling from the dataset
sequence_labeler
<
feature_extractor
>
labeler
=
trainer
.
train
(
samples
,
labels
);
DLIB_TEST
(
vector_to_matrix
(
labels
[
0
])
==
vector_to_matrix
(
predicted_labels
));
std
::
vector
<
unsigned
long
>
predicted_labels
=
labeler
(
samples
[
0
]);
dlog
<<
LINFO
<<
"true hidden states: "
<<
trans
(
vector_to_matrix
(
labels
[
0
]));
dlog
<<
LINFO
<<
"predicted hidden states: "
<<
trans
(
vector_to_matrix
(
predicted_labels
));
DLIB_TEST
(
vector_to_matrix
(
labels
[
0
])
==
vector_to_matrix
(
predicted_labels
)
);
print_spinner
(
);
print_spinner
();
// We can also do cross-validation
matrix
<
double
>
confusion_matrix
;
confusion_matrix
=
cross_validate_sequence_labeler
(
trainer
,
samples
,
labels
,
4
);
dlog
<<
LINFO
<<
"cross-validation: "
;
dlog
<<
LINFO
<<
confusion_matrix
;
double
accuracy
=
sum
(
diag
(
confusion_matrix
))
/
sum
(
confusion_matrix
);
dlog
<<
LINFO
<<
"label accuracy: "
<<
accuracy
;
DLIB_TEST
(
std
::
abs
(
accuracy
-
0.882
)
<
0.01
);
print_spinner
();
// We can also do cross-validation
matrix
<
double
>
confusion_matrix
;
confusion_matrix
=
cross_validate_sequence_labeler
(
trainer
,
samples
,
labels
,
4
);
dlog
<<
LINFO
<<
"cross-validation: "
;
dlog
<<
LINFO
<<
confusion_matrix
;
double
accuracy
=
sum
(
diag
(
confusion_matrix
))
/
sum
(
confusion_matrix
);
dlog
<<
LINFO
<<
"label accuracy: "
<<
accuracy
;
DLIB_TEST
(
std
::
abs
(
accuracy
-
0.882
)
<
0.01
);
print_spinner
();
matrix
<
double
,
0
,
1
>
true_hmm_model_weights
=
log
(
join_cols
(
reshape_to_column_vector
(
transition_probabilities
),
reshape_to_column_vector
(
emission_probabilities
)));
sequence_labeler
<
fe_type
>
labeler_true
(
true_hmm_model_weights
);
matrix
<
double
,
0
,
1
>
true_hmm_model_weights
=
log
(
join_cols
(
reshape_to_column_vector
(
transition_probabilities
),
reshape_to_column_vector
(
emission_probabilities
)));
confusion_matrix
=
test_sequence_labeler
(
labeler_true
,
samples
,
labels
);
dlog
<<
LINFO
<<
"True HMM model: "
;
dlog
<<
LINFO
<<
confusion_matrix
;
accuracy
=
sum
(
diag
(
confusion_matrix
))
/
sum
(
confusion_matrix
);
dlog
<<
LINFO
<<
"label accuracy: "
<<
accuracy
;
DLIB_TEST
(
std
::
abs
(
accuracy
-
0.882
)
<
0.01
);
sequence_labeler
<
feature_extractor
>
labeler_true
(
true_hmm_model_weights
);
confusion_matrix
=
test_sequence_labeler
(
labeler_true
,
samples
,
labels
);
dlog
<<
LINFO
<<
"True HMM model: "
;
dlog
<<
LINFO
<<
confusion_matrix
;
accuracy
=
sum
(
diag
(
confusion_matrix
))
/
sum
(
confusion_matrix
);
dlog
<<
LINFO
<<
"label accuracy: "
<<
accuracy
;
DLIB_TEST
(
std
::
abs
(
accuracy
-
0.882
)
<
0.01
);
print_spinner
();
print_spinner
();
// Finally, the labeler can be serialized to disk just like most dlib objects.
ostringstream
sout
;
serialize
(
labeler
,
sout
);
sequence_labeler
<
fe_type
>
labeler2
;
// recall from disk
istringstream
sin
(
sout
.
str
());
deserialize
(
labeler2
,
sin
);
confusion_matrix
=
test_sequence_labeler
(
labeler2
,
samples
,
labels
);
dlog
<<
LINFO
<<
"deserialized labeler: "
;
dlog
<<
LINFO
<<
confusion_matrix
;
accuracy
=
sum
(
diag
(
confusion_matrix
))
/
sum
(
confusion_matrix
);
dlog
<<
LINFO
<<
"label accuracy: "
<<
accuracy
;
DLIB_TEST
(
std
::
abs
(
accuracy
-
0.882
)
<
0.01
);
}
// Finally, the labeler can be serialized to disk just like most dlib objects.
ostringstream
sout
;
serialize
(
labeler
,
sout
);
// ----------------------------------------------------------------------------------------
class
sequence_labeler_tester
:
public
tester
{
public:
sequence_labeler_tester
(
)
:
tester
(
"test_sequence_labeler"
,
"Runs tests on the sequence labeling code."
)
{}
sequence_labeler
<
feature_extractor
>
labeler2
;
// recall from disk
istringstream
sin
(
sout
.
str
());
deserialize
(
labeler2
,
sin
);
confusion_matrix
=
test_sequence_labeler
(
labeler2
,
samples
,
labels
);
dlog
<<
LINFO
<<
"deserialized labeler: "
;
dlog
<<
LINFO
<<
confusion_matrix
;
accuracy
=
sum
(
diag
(
confusion_matrix
))
/
sum
(
confusion_matrix
);
dlog
<<
LINFO
<<
"label accuracy: "
<<
accuracy
;
DLIB_TEST
(
std
::
abs
(
accuracy
-
0.882
)
<
0.01
);
void
perform_test
(
)
{
do_test
<
feature_extractor
>
();
DLIB_TEST
(
called_rejct_labeling
==
false
);
do_test
<
feature_extractor2
>
();
DLIB_TEST
(
called_rejct_labeling
==
true
);
}
}
a
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment