Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
0246088a
Commit
0246088a
authored
Jul 29, 2012
by
Davis King
Browse files
Added a per node loss interface for the structural_graph_labeling_trainer.
parent
7aab9f71
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
17 deletions
+70
-17
dlib/svm/structural_graph_labeling_trainer.h
dlib/svm/structural_graph_labeling_trainer.h
+30
-12
dlib/svm/structural_graph_labeling_trainer_abstract.h
dlib/svm/structural_graph_labeling_trainer_abstract.h
+40
-5
No files found.
dlib/svm/structural_graph_labeling_trainer.h
View file @
0246088a
...
...
@@ -167,20 +167,23 @@ namespace dlib
>
const
graph_labeler
<
vector_type
>
train
(
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
label_type
>&
labels
const
std
::
vector
<
label_type
>&
labels
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
losses
)
const
{
// make sure requires clause is not broken
DLIB_ASSERT
(
is_graph_labeling_problem
(
samples
,
labels
),
DLIB_ASSERT
(
is_graph_labeling_problem
(
samples
,
labels
)
==
true
&&
(
losses
.
size
()
==
0
||
sizes_match
(
labels
,
losses
)
==
true
)
&&
all_values_are_nonnegative
(
losses
)
==
true
,
"
\t
void structural_graph_labeling_trainer::train()"
<<
"
\n\t
Invalid inputs were given to this function."
<<
"
\n\t
samples.size(): "
<<
samples
.
size
()
<<
"
\n\t
labels.size(): "
<<
labels
.
size
()
<<
"
\n\t
this: "
<<
this
);
<<
"
\n\t
losses.size(): "
<<
losses
.
size
()
<<
"
\n\t
sizes_match(labels,losses): "
<<
sizes_match
(
labels
,
losses
)
<<
"
\n\t
all_values_are_nonnegative(losses): "
<<
all_values_are_nonnegative
(
losses
)
<<
"
\n\t
this: "
<<
this
);
std
::
vector
<
std
::
vector
<
double
>
>
losses
;
structural_svm_graph_labeling_problem
<
graph_type
>
prob
(
samples
,
labels
,
losses
,
num_threads
);
if
(
verbose
)
...
...
@@ -189,8 +192,11 @@ namespace dlib
prob
.
set_c
(
C
);
prob
.
set_epsilon
(
eps
);
prob
.
set_max_cache_size
(
max_cache_size
);
if
(
prob
.
get_losses
().
size
()
==
0
)
{
prob
.
set_loss_on_positive_class
(
loss_pos
);
prob
.
set_loss_on_negative_class
(
loss_neg
);
}
matrix
<
double
,
0
,
1
>
w
;
solver
(
prob
,
w
,
prob
.
get_num_edge_weights
());
...
...
@@ -201,6 +207,18 @@ namespace dlib
return
graph_labeler
<
vector_type
>
(
edge_weights
,
node_weights
);
}
template
<
typename
graph_type
>
const
graph_labeler
<
vector_type
>
train
(
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
label_type
>&
labels
)
const
{
std
::
vector
<
std
::
vector
<
double
>
>
losses
;
return
train
(
samples
,
labels
,
losses
);
}
private:
template
<
typename
T
>
...
...
dlib/svm/structural_graph_labeling_trainer_abstract.h
View file @
0246088a
...
...
@@ -212,14 +212,49 @@ namespace dlib
requires
- is_graph_labeling_problem(samples,labels) == true
ensures
- Uses the structural_svm_graph_labeling_problem to train a
graph_labeler on the given samples/labels training pairs.
The idea is to learn to predict a label given an input sample.
- Uses the structural_svm_graph_labeling_problem to train a graph_labeler
on the given samples/labels training pairs. The idea is to learn to
predict a label given an input sample.
- The values of get_loss_on_positive_class() and get_loss_on_negative_class()
are used to determine how to value mistakes on each node during training.
- returns a function F with the following properties:
- F(new_sample) == The predicted labels for the nodes in the
graph
new_sample.
- F(new_sample) == The predicted labels for the nodes in the
graph
new_sample.
!*/
template
<
typename
graph_type
>
const
graph_labeler
<
vector_type
>
train
(
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
label_type
>&
labels
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
losses
)
const
;
/*!
requires
- is_graph_labeling_problem(samples,labels) == true
- if (losses.size() != 0) then
- sizes_match(labels, losses) == true
- all_values_are_nonnegative(losses) == true
ensures
- Uses the structural_svm_graph_labeling_problem to train a graph_labeler
on the given samples/labels training pairs. The idea is to learn to
predict a label given an input sample.
- returns a function F with the following properties:
- F(new_sample) == The predicted labels for the nodes in the graph
new_sample.
- if (losses.size() == 0) then
- The values of get_loss_on_positive_class() and get_loss_on_negative_class()
are used to determine how to value mistakes on each node during training.
- The losses argument is effectively ignored if its size is zero.
- else
- Each node in the training data has its own loss value defined by the
corresponding entry of losses. In particular, this means that the
node with label labels[i][j] incurs a loss of losses[i][j] if it is
incorrectly labeled.
- The get_loss_on_positive_class() and get_loss_on_negative_class()
parameters are ignored. Only losses is used in this case.
!*/
};
// ----------------------------------------------------------------------------------------
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment