OpenDAS / dlib · Commits · 44fd6f42

Commit 44fd6f42 authored Jul 07, 2013 by Davis King

refined example

parent aa46752d
Showing 1 changed file with 12 additions and 8 deletions
python_examples/svm_struct.py  +12 -8  (view file @ 44fd6f42)
...
...
@@ -152,7 +152,7 @@ class three_class_classifier_problem:
     # There are also a number of optional arguments:
     # epsilon is the stopping tolerance. The optimizer will run until R(w) is within
-    # epsilon of its optimal value. If you don't set this then it defaults to 0.001
+    # epsilon of its optimal value. If you don't set this then it defaults to 0.001.
     #epsilon = 1e-13
     # Uncomment this and the optimizer will print its progress to standard out. You will
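Note that the epsilon discussed here is an optional attribute set directly on the problem class; a minimal sketch of tightening the tolerance (simply uncommenting the line shown above inside this example's class) would be:

class three_class_classifier_problem:
    # Run the optimizer until R(w) is within 1e-13 of its optimal value,
    # instead of the default 0.001 stopping tolerance.
    epsilon = 1e-13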
...
...
@@ -172,9 +172,9 @@ class three_class_classifier_problem:
     def __init__(self, samples, labels):
-        # dlib.solve_structural_svm_problem() also expects the class to have num_samples
-        # and num_dimensions fields. These fields are expected to contain the number of
-        # training samples and the dimensionality of the psi feature vector respectively.
+        # dlib.solve_structural_svm_problem() expects the class to have num_samples and
+        # num_dimensions fields. These fields should contain the number of training
+        # samples and the dimensionality of the PSI feature vector respectively.
         self.num_samples = len(samples)
         self.num_dimensions = len(samples[0])*3
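For context, here is a hedged usage sketch showing where these two fields matter (the toy samples and labels below are invented for illustration; the solver call is the dlib.solve_structural_svm_problem() entry point this example revolves around):

import dlib

samples = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]  # hypothetical 3-dimensional samples
labels = [2, 0, 1]                           # one class label per sample

problem = three_class_classifier_problem(samples, labels)
# The solver reads problem.num_samples and problem.num_dimensions and then
# repeatedly calls problem.separation_oracle() while it optimizes the weights.
weights = dlib.solve_structural_svm_problem(problem)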
...
...
@@ -237,7 +237,8 @@ class three_class_classifier_problem:
     # it the current value of the parameter weights and the separation_oracle() is supposed
     # to find the label that most violates the structural SVM objective function for the
     # idx-th sample. Then the separation oracle reports the corresponding PSI vector and
-    # loss value. # To be more precise, separation_oracle() has the following contract:
+    # loss value. To state this more precisely, the separation_oracle() member function
+    # has the following contract:
     # requires
     #   - 0 <= idx < self.num_samples
     #   - len(current_solution) == self.num_dimensions
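Restated as code, the contract above maps directly onto the method's arguments; a small sketch with the requires-clause written as assertions (the real body appears in the following hunks):

def separation_oracle(self, idx, current_solution):
    # Contract from the comment block above:
    #   - 0 <= idx < self.num_samples
    #   - len(current_solution) == self.num_dimensions
    assert 0 <= idx < self.num_samples
    assert len(current_solution) == self.num_dimensions
    # ... find the most violating label, then report its PSI vector and loss value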
...
...
@@ -266,6 +267,9 @@ class three_class_classifier_problem:
         # Add in the loss-augmentation. Recall that we maximize LOSS(idx,y) + F(X,y) in
         # the separate oracle, not just F(X,y) as we normally would in predict_label().
+        # Therefore, we must add in this extra amount to account for the loss-augmentation.
+        # For our simple multi-class classifier, we incur a loss of 1 if we don't predict
+        # the correct label and a loss of 0 if we get the right label.
         if (self.labels[idx] != 0):
             scores[0] += 1
         if (self.labels[idx] != 1):
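As a concrete illustration of the loss-augmentation added in this hunk (the numbers are invented purely for illustration):

# Suppose the three per-class scores F(X, y) for the idx-th sample are:
scores = [2.0, 3.5, 1.0]
true_label = 1
# Add 1 to every label's score except the true one, since predicting any of
# those labels would incur a loss of 1:
scores = [s + (1 if y != true_label else 0) for y, s in enumerate(scores)]
# scores is now [3.0, 3.5, 2.0]; the most violating label is still 1, and the
# loss reported for it will be 0 because it matches the true label.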
...
...
@@ -275,8 +279,8 @@ class three_class_classifier_problem:
         # Now figure out which classifier has the largest loss-augmented score.
         max_scoring_label = scores.index(max(scores))
-        # We incur a loss of 1 if we don't predict the correct label and a loss of 0 if we
-        # get the right answer.
+        # And finally record the loss that was associated with that predicted label.
+        # Again, the loss is 1 if the label is incorrect and 0 otherwise.
         if (max_scoring_label == self.labels[idx]):
             loss = 0
         else:
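Putting the two changed comment blocks together, the tail of separation_oracle() reads roughly like the sketch below. This is a hedged reconstruction: that the samples are stored as self.samples in __init__, that a make_psi(x, label) helper builds the joint feature vector, and that the loss is returned before the PSI vector are assumptions about parts of the file not shown in this diff.

max_scoring_label = scores.index(max(scores))
# Record the loss associated with that predicted label: 1 if it is wrong,
# 0 if it matches the true label.
if (max_scoring_label == self.labels[idx]):
    loss = 0
else:
    loss = 1
# Report the PSI vector corresponding to the most violating label, along
# with its loss.
psi = self.make_psi(self.samples[idx], max_scoring_label)
return loss, psi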
...
...