Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
c23ef609
Commit
c23ef609
authored
Nov 04, 2011
by
Davis King
Browse files
added more tests
parent
f18acdf8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
137 additions
and
77 deletions
+137
-77
dlib/test/sequence_labeler.cpp
dlib/test/sequence_labeler.cpp
+137
-77
No files found.
dlib/test/sequence_labeler.cpp
View file @
c23ef609
...
...
@@ -62,8 +62,58 @@ namespace
}
};
bool
called_rejct_labeling
=
false
;
class
feature_extractor2
{
public:
typedef
unsigned
long
sample_type
;
unsigned
long
num_features
()
const
{
return
num_label_states
*
num_label_states
+
num_label_states
*
num_sample_states
;
}
unsigned
long
order
()
const
{
return
1
;
}
unsigned
long
num_labels
()
const
{
return
num_label_states
;
}
template
<
typename
EXP
>
bool
reject_labeling
(
const
std
::
vector
<
sample_type
>&
x
,
const
matrix_exp
<
EXP
>&
y
,
unsigned
long
position
)
const
{
called_rejct_labeling
=
true
;
return
false
;
}
template
<
typename
feature_setter
,
typename
EXP
>
void
get_features
(
feature_setter
&
set_feature
,
const
std
::
vector
<
sample_type
>&
x
,
const
matrix_exp
<
EXP
>&
y
,
unsigned
long
position
)
const
{
if
(
y
.
size
()
>
1
)
set_feature
(
y
(
1
)
*
num_label_states
+
y
(
0
));
set_feature
(
num_label_states
*
num_label_states
+
y
(
0
)
*
num_sample_states
+
x
[
position
]);
}
};
void
serialize
(
const
feature_extractor
&
,
std
::
ostream
&
)
{}
void
deserialize
(
feature_extractor
&
,
std
::
istream
&
)
{}
void
serialize
(
const
feature_extractor2
&
,
std
::
ostream
&
)
{}
void
deserialize
(
feature_extractor2
&
,
std
::
istream
&
)
{}
// ----------------------------------------------------------------------------------------
...
...
@@ -174,18 +224,8 @@ namespace
// ----------------------------------------------------------------------------------------
class
sequence_labeler_tester
:
public
tester
{
public:
sequence_labeler_tester
(
)
:
tester
(
"test_sequence_labeler"
,
"Runs tests on the sequence labeling code."
)
{}
void
perform_test
(
)
template
<
typename
fe_type
>
void
do_test
()
{
matrix
<
double
>
transition_probabilities
(
num_label_states
,
num_label_states
);
transition_probabilities
=
0.05
,
0.90
,
0.05
,
...
...
@@ -216,7 +256,7 @@ namespace
}
print_spinner
();
structural_sequence_labeling_trainer
<
fe
ature_extractor
>
trainer
;
structural_sequence_labeling_trainer
<
fe
_type
>
trainer
;
trainer
.
set_c
(
4
);
DLIB_TEST
(
trainer
.
get_c
()
==
4
);
trainer
.
set_num_threads
(
4
);
...
...
@@ -225,7 +265,7 @@ namespace
// Learn to do sequence labeling from the dataset
sequence_labeler
<
fe
ature_extractor
>
labeler
=
trainer
.
train
(
samples
,
labels
);
sequence_labeler
<
fe
_type
>
labeler
=
trainer
.
train
(
samples
,
labels
);
std
::
vector
<
unsigned
long
>
predicted_labels
=
labeler
(
samples
[
0
]);
dlog
<<
LINFO
<<
"true hidden states: "
<<
trans
(
vector_to_matrix
(
labels
[
0
]));
...
...
@@ -252,7 +292,7 @@ namespace
matrix
<
double
,
0
,
1
>
true_hmm_model_weights
=
log
(
join_cols
(
reshape_to_column_vector
(
transition_probabilities
),
reshape_to_column_vector
(
emission_probabilities
)));
sequence_labeler
<
fe
ature_extractor
>
labeler_true
(
true_hmm_model_weights
);
sequence_labeler
<
fe
_type
>
labeler_true
(
true_hmm_model_weights
);
confusion_matrix
=
test_sequence_labeler
(
labeler_true
,
samples
,
labels
);
dlog
<<
LINFO
<<
"True HMM model: "
;
...
...
@@ -272,7 +312,7 @@ namespace
ostringstream
sout
;
serialize
(
labeler
,
sout
);
sequence_labeler
<
fe
ature_extractor
>
labeler2
;
sequence_labeler
<
fe
_type
>
labeler2
;
// recall from disk
istringstream
sin
(
sout
.
str
());
deserialize
(
labeler2
,
sin
);
...
...
@@ -283,6 +323,26 @@ namespace
dlog
<<
LINFO
<<
"label accuracy: "
<<
accuracy
;
DLIB_TEST
(
std
::
abs
(
accuracy
-
0.882
)
<
0.01
);
}
// ----------------------------------------------------------------------------------------
class
sequence_labeler_tester
:
public
tester
{
public:
sequence_labeler_tester
(
)
:
tester
(
"test_sequence_labeler"
,
"Runs tests on the sequence labeling code."
)
{}
void
perform_test
(
)
{
do_test
<
feature_extractor
>
();
DLIB_TEST
(
called_rejct_labeling
==
false
);
do_test
<
feature_extractor2
>
();
DLIB_TEST
(
called_rejct_labeling
==
true
);
}
}
a
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment