Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
0246088a
You need to sign in or sign up before continuing.
Commit
0246088a
authored
Jul 29, 2012
by
Davis King
Browse files
Added a per node loss interface for the structural_graph_labeling_trainer.
parent
7aab9f71
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
17 deletions
+70
-17
dlib/svm/structural_graph_labeling_trainer.h
dlib/svm/structural_graph_labeling_trainer.h
+30
-12
dlib/svm/structural_graph_labeling_trainer_abstract.h
dlib/svm/structural_graph_labeling_trainer_abstract.h
+40
-5
No files found.
dlib/svm/structural_graph_labeling_trainer.h
View file @
0246088a
...
@@ -167,20 +167,23 @@ namespace dlib
...
@@ -167,20 +167,23 @@ namespace dlib
>
>
const
graph_labeler
<
vector_type
>
train
(
const
graph_labeler
<
vector_type
>
train
(
const
dlib
::
array
<
graph_type
>&
samples
,
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
label_type
>&
labels
const
std
::
vector
<
label_type
>&
labels
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
losses
)
const
)
const
{
{
// make sure requires clause is not broken
DLIB_ASSERT
(
is_graph_labeling_problem
(
samples
,
labels
)
==
true
&&
DLIB_ASSERT
(
is_graph_labeling_problem
(
samples
,
labels
),
(
losses
.
size
()
==
0
||
sizes_match
(
labels
,
losses
)
==
true
)
&&
"
\t
void structural_graph_labeling_trainer::train()"
all_values_are_nonnegative
(
losses
)
==
true
,
<<
"
\n\t
Invalid inputs were given to this function."
"
\t
void structural_graph_labeling_trainer::train()"
<<
"
\n\t
samples.size(): "
<<
samples
.
size
()
<<
"
\n\t
Invalid inputs were given to this function."
<<
"
\n\t
labels.size(): "
<<
labels
.
size
()
<<
"
\n\t
samples.size(): "
<<
samples
.
size
()
<<
"
\n\t
this: "
<<
this
<<
"
\n\t
labels.size(): "
<<
labels
.
size
()
);
<<
"
\n\t
losses.size(): "
<<
losses
.
size
()
<<
"
\n\t
sizes_match(labels,losses): "
<<
sizes_match
(
labels
,
losses
)
<<
"
\n\t
all_values_are_nonnegative(losses): "
<<
all_values_are_nonnegative
(
losses
)
<<
"
\n\t
this: "
<<
this
);
std
::
vector
<
std
::
vector
<
double
>
>
losses
;
structural_svm_graph_labeling_problem
<
graph_type
>
prob
(
samples
,
labels
,
losses
,
num_threads
);
structural_svm_graph_labeling_problem
<
graph_type
>
prob
(
samples
,
labels
,
losses
,
num_threads
);
if
(
verbose
)
if
(
verbose
)
...
@@ -189,8 +192,11 @@ namespace dlib
...
@@ -189,8 +192,11 @@ namespace dlib
prob
.
set_c
(
C
);
prob
.
set_c
(
C
);
prob
.
set_epsilon
(
eps
);
prob
.
set_epsilon
(
eps
);
prob
.
set_max_cache_size
(
max_cache_size
);
prob
.
set_max_cache_size
(
max_cache_size
);
prob
.
set_loss_on_positive_class
(
loss_pos
);
if
(
prob
.
get_losses
().
size
()
==
0
)
prob
.
set_loss_on_negative_class
(
loss_neg
);
{
prob
.
set_loss_on_positive_class
(
loss_pos
);
prob
.
set_loss_on_negative_class
(
loss_neg
);
}
matrix
<
double
,
0
,
1
>
w
;
matrix
<
double
,
0
,
1
>
w
;
solver
(
prob
,
w
,
prob
.
get_num_edge_weights
());
solver
(
prob
,
w
,
prob
.
get_num_edge_weights
());
...
@@ -201,6 +207,18 @@ namespace dlib
...
@@ -201,6 +207,18 @@ namespace dlib
return
graph_labeler
<
vector_type
>
(
edge_weights
,
node_weights
);
return
graph_labeler
<
vector_type
>
(
edge_weights
,
node_weights
);
}
}
template
<
typename
graph_type
>
const
graph_labeler
<
vector_type
>
train
(
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
label_type
>&
labels
)
const
{
std
::
vector
<
std
::
vector
<
double
>
>
losses
;
return
train
(
samples
,
labels
,
losses
);
}
private:
private:
template
<
typename
T
>
template
<
typename
T
>
...
...
dlib/svm/structural_graph_labeling_trainer_abstract.h
View file @
0246088a
...
@@ -212,14 +212,49 @@ namespace dlib
...
@@ -212,14 +212,49 @@ namespace dlib
requires
requires
- is_graph_labeling_problem(samples,labels) == true
- is_graph_labeling_problem(samples,labels) == true
ensures
ensures
- Uses the structural_svm_graph_labeling_problem to train a
- Uses the structural_svm_graph_labeling_problem to train a graph_labeler
graph_labeler on the given samples/labels training pairs.
on the given samples/labels training pairs. The idea is to learn to
The idea is to learn to predict a label given an input sample.
predict a label given an input sample.
- The values of get_loss_on_positive_class() and get_loss_on_negative_class()
are used to determine how to value mistakes on each node during training.
- returns a function F with the following properties:
- returns a function F with the following properties:
- F(new_sample) == The predicted labels for the nodes in the
- F(new_sample) == The predicted labels for the nodes in the
graph
graph
new_sample.
new_sample.
!*/
!*/
template
<
typename
graph_type
>
const
graph_labeler
<
vector_type
>
train
(
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
label_type
>&
labels
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
losses
)
const
;
/*!
requires
- is_graph_labeling_problem(samples,labels) == true
- if (losses.size() != 0) then
- sizes_match(labels, losses) == true
- all_values_are_nonnegative(losses) == true
ensures
- Uses the structural_svm_graph_labeling_problem to train a graph_labeler
on the given samples/labels training pairs. The idea is to learn to
predict a label given an input sample.
- returns a function F with the following properties:
- F(new_sample) == The predicted labels for the nodes in the graph
new_sample.
- if (losses.size() == 0) then
- The values of get_loss_on_positive_class() and get_loss_on_negative_class()
are used to determine how to value mistakes on each node during training.
- The losses argument is effectively ignored if its size is zero.
- else
- Each node in the training data has its own loss value defined by the
corresponding entry of losses. In particular, this means that the
node with label labels[i][j] incurs a loss of losses[i][j] if it is
incorrectly labeled.
- The get_loss_on_positive_class() and get_loss_on_negative_class()
parameters are ignored. Only losses is used in this case.
!*/
};
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment