Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
ea2a5184
Commit
ea2a5184
authored
May 16, 2011
by
Davis King
Browse files
Cleaned up the code a little by pulling the caching logic out into its
own class.
parent
007e218e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
158 additions
and
100 deletions
+158
-100
dlib/svm/structural_svm_problem.h
dlib/svm/structural_svm_problem.h
+152
-95
dlib/svm/structural_svm_problem_abstract.h
dlib/svm/structural_svm_problem_abstract.h
+6
-5
No files found.
dlib/svm/structural_svm_problem.h
View file @
ea2a5184
...
...
@@ -17,10 +17,134 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template
<
typename
matrix_type
,
typename
feature_vector_type_
=
matrix_type
typename
structural_svm_problem
>
class
structural_svm_problem
:
public
oca_problem
<
matrix_type
>
class cache_element_structural_svm
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This object is a cache of separation oracle results for a single sample
            of a structural_svm_problem.  It stores the sample's true joint feature
            vector along with (loss, psi) pairs previously produced by the problem's
            separation_oracle().  The number of stored pairs is limited by
            prob->get_max_cache_size().  Each stored pair has a use counter in
            lru_count; the pair with the smallest counter is the one overwritten
            once the cache is full.
    !*/

public:

    cache_element_structural_svm (
    ) : prob(0), sample_idx(0) {}

    typedef typename structural_svm_problem::scalar_type scalar_type;
    typedef typename structural_svm_problem::matrix_type matrix_type;
    typedef typename structural_svm_problem::feature_vector_type feature_vector_type;

    void init (
        const structural_svm_problem* prob_,
        const long idx
    )
    /*!
        ensures
            - This object will be a cache for the idx-th sample in the given
              structural_svm_problem.
            - Any previously cached oracle results are discarded.
            - The true joint feature vector for the idx-th sample is loaded from
              the problem and stored in this object.
    !*/
    {
        prob = prob_;
        sample_idx = idx;

        loss.clear();
        psi.clear();
        lru_count.clear();

        prob->get_truth_joint_feature_vector(idx, true_psi);
    }

    void get_truth_joint_feature_vector_cached (
        feature_vector_type& psi
    ) const
    /*!
        ensures
            - #psi == the true joint feature vector that was loaded by init().
    !*/
    {
        psi = true_psi;
    }

    void separation_oracle_cached (
        const bool skip_cache,
        const scalar_type& cur_risk_lower_bound,
        const matrix_type& current_solution,
        scalar_type& out_loss,
        feature_vector_type& out_psi
    ) const
    /*!
        ensures
            - if (skip_cache == false) then
                - Finds the cached (loss, psi) pair with the largest risk, where the
                  risk of the i-th pair is
                  loss[i] + dot(psi[i],current_solution) - dot(true_psi,current_solution).
                - If that risk exceeds cur_risk_lower_bound by more than
                  prob->get_epsilon() then the cached pair is returned in #out_loss
                  and #out_psi, it is marked as the most recently used entry, and
                  the problem's separation_oracle() is not called.
            - Otherwise, prob->separation_oracle() is called for this sample and its
              result is returned in #out_loss and #out_psi.  The result is also
              saved in the cache, overwriting the least recently used entry if the
              cache already holds prob->get_max_cache_size() entries.
    !*/
    {
        if (!skip_cache)
        {
            scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
            unsigned long best_idx = 0;

            using sparse_vector::dot;
            using dlib::dot;

            const scalar_type dot_true_psi = dot(true_psi, current_solution);

            // figure out which element in the cache is the best (i.e. has the biggest risk)
            long max_lru_count = 0;
            for (unsigned long i = 0; i < loss.size(); ++i)
            {
                const scalar_type risk = loss[i] + dot(psi[i], current_solution) - dot_true_psi;
                if (risk > best_risk)
                {
                    best_risk = risk;
                    out_loss = loss[i];
                    best_idx = i;
                }
                if (lru_count[i] > max_lru_count)
                    max_lru_count = lru_count[i];
            }

            // Only use the cached result if it is good enough relative to the
            // current lower bound on the risk.
            if (best_risk - cur_risk_lower_bound > prob->get_epsilon())
            {
                out_psi = psi[best_idx];
                // This entry was just used so give it the highest use counter.
                lru_count[best_idx] = max_lru_count + 1;
                return;
            }
        }

        prob->separation_oracle(sample_idx, current_solution, out_loss, out_psi);

        // if the cache is full
        if (loss.size() >= prob->get_max_cache_size())
        {
            // find least recently used cache entry for idx-th sample
            const long i = index_of_min(vector_to_matrix(lru_count));

            // save our new data in the cache
            loss[i] = out_loss;
            psi[i]  = out_psi;

            const long max_use = max(vector_to_matrix(lru_count));
            // Make sure this new cache entry has the best lru count since we have used
            // it most recently.
            lru_count[i] = max_use + 1;
        }
        else
        {
            // There is still room, so just append the new result to the cache.
            loss.push_back(out_loss);
            psi.push_back(out_psi);

            long max_use = 1;
            if (lru_count.size() != 0)
                max_use = max(vector_to_matrix(lru_count)) + 1;

            // The new entry must get the highest use counter since it is the most
            // recently used one.  Using lru_count.size() here instead would let
            // the new entry rank below older entries whose counters were bumped by
            // cache hits, making it the next one to be evicted.
            lru_count.push_back(max_use);
        }
    }

    // The problem this cache belongs to and the index of the sample being cached.
    const structural_svm_problem* prob;
    long sample_idx;

    // The true joint feature vector for the cached sample (loaded by init()).
    mutable feature_vector_type true_psi;

    // Parallel arrays: loss[i] and psi[i] are a cached oracle result and
    // lru_count[i] is its use counter (bigger means more recently used).
    mutable std::vector<scalar_type> loss;
    mutable std::vector<feature_vector_type> psi;
    mutable std::vector<long> lru_count;
};
// ----------------------------------------------------------------------------------------
template
<
typename
matrix_type_
,
typename
feature_vector_type_
=
matrix_type_
>
class
structural_svm_problem
:
public
oca_problem
<
matrix_type_
>
{
public:
/*!
...
...
@@ -35,11 +159,11 @@ namespace dlib
- if (cache.size() != 0) then
- cache.size() == get_num_samples()
- true_psis.size() == get_num_samples()
- for all i: cache[i] == the cached results of calls to separation_oracle()
for the i-th sample.
!*/
typedef
matrix_type_
matrix_type
;
typedef
typename
matrix_type
::
type
scalar_type
;
typedef
feature_vector_type_
feature_vector_type
;
...
...
@@ -193,36 +317,30 @@ namespace dlib
feature_vector_type
ftemp
;
const
unsigned
long
num
=
get_num_samples
();
// initialize psi_true and a few other things if we haven't done so already.
if
(
psi_true
.
size
()
==
0
)
{
// initialize the cache if necessary.
if
(
cache
.
size
()
==
0
&&
max_cache_size
!=
0
)
{
cache
.
resize
(
get_num_samples
());
for
(
unsigned
long
i
=
0
;
i
<
cache
.
size
();
++
i
)
cache
[
i
].
init
(
this
,
i
);
}
// initialize psi_true if necessary.
if
(
psi_true
.
size
()
==
0
)
{
psi_true
.
set_size
(
w
.
size
(),
1
);
psi_true
=
0
;
// If the cache is enabled then populate the true_psis array. But
// in either case sum them all up and store the result in psi_true.
if
(
max_cache_size
!=
0
)
{
true_psis
.
resize
(
num
);
for
(
unsigned
long
i
=
0
;
i
<
num
;
++
i
)
{
get_truth_joint_feature_vector
(
i
,
true_psis
[
i
]);
sparse_vector
::
subtract_from
(
psi_true
,
true_psis
[
i
]);
}
}
else
{
for
(
unsigned
long
i
=
0
;
i
<
num
;
++
i
)
{
if
(
cache
.
size
()
==
0
)
get_truth_joint_feature_vector
(
i
,
ftemp
);
else
cache
[
i
].
get_truth_joint_feature_vector_cached
(
ftemp
);
sparse_vector
::
subtract_from
(
psi_true
,
ftemp
);
}
}
}
subgradient
=
psi_true
;
scalar_type
total_loss
=
0
;
...
...
@@ -259,90 +377,29 @@ namespace dlib
feature_vector_type
&
psi
)
const
{
if
(
!
skip_cache
&&
max_cache_size
!=
0
)
{
scalar_type
best_risk
=
-
std
::
numeric_limits
<
scalar_type
>::
infinity
();
unsigned
long
best_idx
=
0
;
cache_record
&
rec
=
cache
[
idx
];
using
sparse_vector
::
dot
;
using
dlib
::
dot
;
const
scalar_type
dot_true_psi
=
dot
(
true_psis
[
idx
],
current_solution
);
// figure out which element in the cache is the best (i.e. has the biggest risk)
long
max_lru_count
=
0
;
for
(
unsigned
long
i
=
0
;
i
<
rec
.
loss
.
size
();
++
i
)
{
const
scalar_type
risk
=
rec
.
loss
[
i
]
+
dot
(
rec
.
psi
[
i
],
current_solution
)
-
dot_true_psi
;
if
(
risk
>
best_risk
)
{
best_risk
=
risk
;
loss
=
rec
.
loss
[
i
];
best_idx
=
i
;
}
if
(
rec
.
lru_count
[
i
]
>
max_lru_count
)
max_lru_count
=
rec
.
lru_count
[
i
];
}
if
(
best_risk
-
cur_risk_lower_bound
>
eps
)
if
(
cache
.
size
()
==
0
)
{
psi
=
rec
.
psi
[
best_idx
];
rec
.
lru_count
[
best_idx
]
=
max_lru_count
+
1
;
return
;
}
}
separation_oracle
(
idx
,
current_solution
,
loss
,
psi
);
if
(
cache
.
size
()
!=
0
)
{
if
(
cache
[
idx
].
loss
.
size
()
<
max_cache_size
)
{
cache
[
idx
].
loss
.
push_back
(
loss
);
cache
[
idx
].
psi
.
push_back
(
psi
);
long
max_use
=
1
;
if
(
cache
[
idx
].
lru_count
.
size
()
!=
0
)
max_use
=
max
(
vector_to_matrix
(
cache
[
idx
].
lru_count
))
+
1
;
cache
[
idx
].
lru_count
.
push_back
(
cache
[
idx
].
lru_count
.
size
());
}
else
{
// find least recently used cache entry for idx-th sample
const
long
i
=
index_of_min
(
vector_to_matrix
(
cache
[
idx
].
lru_count
));
// save our new data in the cache
cache
[
idx
].
loss
[
i
]
=
loss
;
cache
[
idx
].
psi
[
i
]
=
psi
;
const
long
max_use
=
max
(
vector_to_matrix
(
cache
[
idx
].
lru_count
));
// Make sure this new cache entry has the best lru count since we have used
// it most recently.
cache
[
idx
].
lru_count
[
i
]
=
max_use
+
1
;
}
cache
[
idx
].
separation_oracle_cached
(
skip_cache
,
cur_risk_lower_bound
,
current_solution
,
loss
,
psi
);
}
}
private:
struct
cache_record
{
std
::
vector
<
scalar_type
>
loss
;
std
::
vector
<
feature_vector_type
>
psi
;
std
::
vector
<
long
>
lru_count
;
};
mutable
scalar_type
cur_risk_lower_bound
;
mutable
matrix_type
psi_true
;
scalar_type
eps
;
mutable
bool
verbose
;
mutable
std
::
vector
<
feature_vector_type
>
true_psis
;
mutable
std
::
vector
<
cache_
record
>
cache
;
mutable
std
::
vector
<
cache_
element_structural_svm
<
structural_svm_problem
>
>
cache
;
mutable
bool
skip_cache
;
unsigned
long
max_cache_size
;
...
...
dlib/svm/structural_svm_problem_abstract.h
View file @
ea2a5184
...
...
@@ -13,15 +13,15 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template
<
typename
matrix_type
,
typename
feature_vector_type_
=
matrix_type
typename
matrix_type
_
,
typename
feature_vector_type_
=
matrix_type
_
>
class
structural_svm_problem
:
public
oca_problem
<
matrix_type
>
class
structural_svm_problem
:
public
oca_problem
<
matrix_type
_
>
{
public:
/*!
REQUIREMENTS ON matrix_type
- matrix_type == a dlib::matrix capable of storing column vectors
REQUIREMENTS ON matrix_type
_
- matrix_type
_
== a dlib::matrix capable of storing column vectors
REQUIREMENTS ON feature_vector_type_
- feature_vector_type_ == a dlib::matrix capable of storing column vectors
...
...
@@ -81,6 +81,7 @@ namespace dlib
paper.
!*/
typedef
matrix_type_
matrix_type
;
typedef
typename
matrix_type
::
type
scalar_type
;
typedef
feature_vector_type_
feature_vector_type
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment