Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
cc9ff97a
Commit
cc9ff97a
authored
Jul 07, 2013
by
Davis King
Browse files
Cleaned up python svm struct code a little.
parent
d0a054f1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
19 deletions
+39
-19
python_examples/svm_struct.py
python_examples/svm_struct.py
+25
-17
tools/python/src/svm_struct.cpp
tools/python/src/svm_struct.cpp
+14
-2
No files found.
python_examples/svm_struct.py
View file @
cc9ff97a
#!/usr/bin/python
#!/usr/bin/python
# The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
# The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
#
#
#
# This is an example illustrating the use of the structural SVM solver from the dlib C++
# Library. This example will briefly introduce it and then walk through an example showing
# how to use it to create a simple multi-class classifier.
#
#
#
# COMPILING THE DLIB PYTHON INTERFACE
# COMPILING THE DLIB PYTHON INTERFACE
# Dlib comes with a compiled python interface for python 2.7 on MS Windows. If
# Dlib comes with a compiled python interface for python 2.7 on MS Windows. If
...
@@ -15,6 +18,7 @@
...
@@ -15,6 +18,7 @@
import
dlib
import
dlib
def
dot
(
a
,
b
):
def
dot
(
a
,
b
):
"Compute the dot product between the two vectors a and b."
return
sum
(
i
*
j
for
i
,
j
in
zip
(
a
,
b
))
return
sum
(
i
*
j
for
i
,
j
in
zip
(
a
,
b
))
...
@@ -23,30 +27,35 @@ class three_class_classifier_problem:
...
@@ -23,30 +27,35 @@ class three_class_classifier_problem:
be_verbose
=
True
be_verbose
=
True
epsilon
=
0.0001
epsilon
=
0.0001
def
__init__
(
self
,
samples
,
labels
):
def
__init__
(
self
,
samples
,
labels
):
self
.
num_samples
=
len
(
samples
)
self
.
num_samples
=
len
(
samples
)
self
.
num_dimensions
=
len
(
samples
[
0
])
*
3
self
.
num_dimensions
=
len
(
samples
[
0
])
*
3
self
.
samples
=
samples
self
.
samples
=
samples
self
.
labels
=
labels
self
.
labels
=
labels
def
make_psi
(
self
,
psi
,
vector
,
label
):
def
make_psi
(
self
,
vector
,
label
):
psi
=
dlib
.
vector
()
psi
.
resize
(
self
.
num_dimensions
)
psi
.
resize
(
self
.
num_dimensions
)
dims
=
len
(
vector
)
dims
=
len
(
vector
)
if
(
label
==
1
):
if
(
label
==
0
):
for
i
in
range
(
0
,
dims
):
for
i
in
range
(
0
,
dims
):
psi
[
i
]
=
vector
[
i
]
psi
[
i
]
=
vector
[
i
]
elif
(
label
==
2
):
elif
(
label
==
1
):
for
i
in
range
(
dims
,
2
*
dims
):
for
i
in
range
(
dims
,
2
*
dims
):
psi
[
i
]
=
vector
[
i
-
dims
]
psi
[
i
]
=
vector
[
i
-
dims
]
else
:
else
:
# the label must be 2
for
i
in
range
(
2
*
dims
,
3
*
dims
):
for
i
in
range
(
2
*
dims
,
3
*
dims
):
psi
[
i
]
=
vector
[
i
-
2
*
dims
]
psi
[
i
]
=
vector
[
i
-
2
*
dims
]
return
psi
def
get_truth_joint_feature_vector
(
self
,
idx
):
return
self
.
make_psi
(
self
.
samples
[
idx
],
self
.
labels
[
idx
])
def
get_truth_joint_feature_vector
(
self
,
idx
,
psi
):
self
.
make_psi
(
psi
,
self
.
samples
[
idx
],
self
.
labels
[
idx
])
def
separation_oracle
(
self
,
idx
,
current_solution
,
psi
):
def
separation_oracle
(
self
,
idx
,
current_solution
):
samp
=
samples
[
idx
]
samp
=
samples
[
idx
]
dims
=
len
(
samp
)
dims
=
len
(
samp
)
scores
=
[
0
,
0
,
0
]
scores
=
[
0
,
0
,
0
]
...
@@ -56,29 +65,28 @@ class three_class_classifier_problem:
...
@@ -56,29 +65,28 @@ class three_class_classifier_problem:
scores
[
2
]
=
dot
(
current_solution
[
2
*
dims
:
3
*
dims
],
samp
)
scores
[
2
]
=
dot
(
current_solution
[
2
*
dims
:
3
*
dims
],
samp
)
# Add in the loss-augmentation
# Add in the loss-augmentation
if
(
labels
[
idx
]
!=
1
):
if
(
labels
[
idx
]
!=
0
):
scores
[
0
]
+=
1
scores
[
0
]
+=
1
if
(
labels
[
idx
]
!=
2
):
if
(
labels
[
idx
]
!=
1
):
scores
[
1
]
+=
1
scores
[
1
]
+=
1
if
(
labels
[
idx
]
!=
3
):
if
(
labels
[
idx
]
!=
2
):
scores
[
2
]
+=
1
scores
[
2
]
+=
1
# Now figure out which classifier has the largest loss-augmented score.
# Now figure out which classifier has the largest loss-augmented score.
max_scoring_label
=
scores
.
index
(
max
(
scores
))
+
1
max_scoring_label
=
scores
.
index
(
max
(
scores
))
if
(
max_scoring_label
==
labels
[
idx
]):
if
(
max_scoring_label
==
labels
[
idx
]):
loss
=
0
loss
=
0
else
:
else
:
loss
=
1
loss
=
1
self
.
make_psi
(
psi
,
samp
,
max_scoring_label
)
psi
=
self
.
make_psi
(
samp
,
max_scoring_label
)
return
loss
return
loss
,
psi
samples
=
[
[
0
,
0
,
1
],
[
0
,
1
,
0
],
[
1
,
0
,
0
]];
samples
=
[[
0
,
0
,
1
],
[
0
,
1
,
0
],
[
1
,
0
,
0
]];
labels
=
[
1
,
2
,
3
]
labels
=
[
0
,
1
,
2
]
problem
=
three_class_classifier_problem
(
samples
,
labels
)
problem
=
three_class_classifier_problem
(
samples
,
labels
)
weights
=
dlib
.
solve_structural_svm_problem
(
problem
)
weights
=
dlib
.
solve_structural_svm_problem
(
problem
)
...
...
tools/python/src/svm_struct.cpp
View file @
cc9ff97a
...
@@ -37,7 +37,7 @@ public:
...
@@ -37,7 +37,7 @@ public:
feature_vector_type
&
psi
feature_vector_type
&
psi
)
const
)
const
{
{
problem
.
attr
(
"get_truth_joint_feature_vector"
)(
idx
,
boost
::
ref
(
psi
));
psi
=
extract
<
feature_vector_type
&>
(
problem
.
attr
(
"get_truth_joint_feature_vector"
)(
idx
));
}
}
virtual
void
separation_oracle
(
virtual
void
separation_oracle
(
...
@@ -47,7 +47,19 @@ public:
...
@@ -47,7 +47,19 @@ public:
feature_vector_type
&
psi
feature_vector_type
&
psi
)
const
)
const
{
{
loss
=
extract
<
double
>
(
problem
.
attr
(
"separation_oracle"
)(
idx
,
boost
::
ref
(
current_solution
),
boost
::
ref
(
psi
)));
object
res
=
problem
.
attr
(
"separation_oracle"
)(
idx
,
boost
::
ref
(
current_solution
));
pyassert
(
len
(
res
)
==
2
,
"separation_oracle() must return two objects, the loss and the psi vector"
);
// let the user supply the output arguments in any order.
if
(
extract
<
double
>
(
res
[
0
]).
check
())
{
loss
=
extract
<
double
>
(
res
[
0
]);
psi
=
extract
<
feature_vector_type
&>
(
res
[
1
]);
}
else
{
psi
=
extract
<
feature_vector_type
&>
(
res
[
0
]);
loss
=
extract
<
double
>
(
res
[
1
]);
}
}
}
private:
private:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment