Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
160337da
Commit
160337da
authored
Nov 17, 2013
by
Davis King
Browse files
Made the one_vs_one_trainer and one_vs_all_trainer objects multithreaded
so they can run each binary trainer on a different core.
parent
525f2a52
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
254 additions
and
76 deletions
+254
-76
dlib/svm.h
dlib/svm.h
+0
-2
dlib/svm/cross_validate_multiclass_trainer.h
dlib/svm/cross_validate_multiclass_trainer.h
+0
-1
dlib/svm/one_vs_all_trainer.h
dlib/svm/one_vs_all_trainer.h
+104
-34
dlib/svm/one_vs_all_trainer_abstract.h
dlib/svm/one_vs_all_trainer_abstract.h
+22
-4
dlib/svm/one_vs_one_trainer.h
dlib/svm/one_vs_one_trainer.h
+102
-29
dlib/svm/one_vs_one_trainer_abstract.h
dlib/svm/one_vs_one_trainer_abstract.h
+22
-4
dlib/svm_threaded.h
dlib/svm_threaded.h
+2
-0
dlib/test/one_vs_all_trainer.cpp
dlib/test/one_vs_all_trainer.cpp
+1
-1
dlib/test/one_vs_one_trainer.cpp
dlib/test/one_vs_one_trainer.cpp
+1
-1
No files found.
dlib/svm.h
View file @
160337da
...
...
@@ -32,7 +32,6 @@
#include "svm/svr_trainer.h"
#include "svm/one_vs_one_decision_function.h"
#include "svm/one_vs_one_trainer.h"
#include "svm/multiclass_tools.h"
#include "svm/cross_validate_multiclass_trainer.h"
#include "svm/cross_validate_regression_trainer.h"
...
...
@@ -42,7 +41,6 @@
#include "svm/cross_validate_assignment_trainer.h"
#include "svm/one_vs_all_decision_function.h"
#include "svm/one_vs_all_trainer.h"
#include "svm/structural_svm_problem.h"
#include "svm/sequence_labeler.h"
...
...
dlib/svm/cross_validate_multiclass_trainer.h
View file @
160337da
...
...
@@ -5,7 +5,6 @@
#include <vector>
#include "../matrix.h"
#include "one_vs_one_trainer.h"
#include "cross_validate_multiclass_trainer_abstract.h"
#include <sstream>
...
...
dlib/svm/one_vs_all_trainer.h
View file @
160337da
...
...
@@ -16,6 +16,7 @@
#include "../any.h"
#include <map>
#include <set>
#include "../threads.h"
namespace
dlib
{
...
...
@@ -39,7 +40,8 @@ namespace dlib
one_vs_all_trainer
(
)
:
verbose
(
false
)
verbose
(
false
),
num_threads
(
4
)
{}
void
set_trainer
(
...
...
@@ -70,6 +72,19 @@ namespace dlib
verbose
=
false
;
}
void
set_num_threads
(
unsigned
long
num
)
{
num_threads
=
num
;
}
unsigned
long
get_num_threads
(
)
const
{
return
num_threads
;
}
struct
invalid_label
:
public
dlib
::
error
{
invalid_label
(
const
std
::
string
&
msg
,
const
label_type
&
l_
...
...
@@ -96,62 +111,117 @@ namespace dlib
const
std
::
vector
<
label_type
>
distinct_labels
=
select_all_distinct_labels
(
all_labels
);
std
::
vector
<
scalar_type
>
labels
;
typename
trained_function_type
::
binary_function_table
dfs
;
// make sure we have a trainer object for each of the label types.
for
(
unsigned
long
i
=
0
;
i
<
distinct_labels
.
size
();
++
i
)
{
labels
.
clear
();
const
label_type
l
=
distinct_labels
[
i
];
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
l
);
// setup one of the one vs all training sets
for
(
unsigned
long
k
=
0
;
k
<
all_samples
.
size
();
++
k
)
if
(
itr
==
trainers
.
end
()
&&
default_trainer
.
is_empty
())
{
if
(
all_labels
[
k
]
==
l
)
labels
.
push_back
(
+
1
);
else
labels
.
push_back
(
-
1
);
std
::
ostringstream
sout
;
sout
<<
"In one_vs_all_trainer, no trainer registered for the "
<<
l
<<
" label."
;
throw
invalid_label
(
sout
.
str
(),
l
);
}
}
if
(
verbose
)
{
std
::
cout
<<
"Training classifier for "
<<
l
<<
" vs. all"
<<
std
::
endl
;
}
// now do the training
parallel_for_helper
helper
(
all_samples
,
all_labels
,
default_trainer
,
trainers
,
verbose
,
distinct_labels
);
parallel_for
(
num_threads
,
0
,
distinct_labels
.
size
(),
helper
,
500
);
// now train a binary classifier using the samples we selected
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
l
);
if
(
helper
.
error_message
.
size
()
!=
0
)
{
throw
dlib
::
error
(
"binary trainer threw while training one vs. all classifier. Error was: "
+
helper
.
error_message
);
}
return
trained_function_type
(
helper
.
dfs
);
}
if
(
itr
!=
trainers
.
end
())
{
dfs
[
l
]
=
itr
->
second
.
train
(
all_samples
,
labels
);
}
else
if
(
default_trainer
.
is_empty
()
==
false
)
private:
typedef
std
::
map
<
label_type
,
any_trainer
>
binary_function_table
;
struct
parallel_for_helper
{
parallel_for_helper
(
const
std
::
vector
<
sample_type
>&
all_samples_
,
const
std
::
vector
<
label_type
>&
all_labels_
,
const
any_trainer
&
default_trainer_
,
const
binary_function_table
&
trainers_
,
const
bool
verbose_
,
const
std
::
vector
<
label_type
>&
distinct_labels_
)
:
all_samples
(
all_samples_
),
all_labels
(
all_labels_
),
default_trainer
(
default_trainer_
),
trainers
(
trainers_
),
verbose
(
verbose_
),
distinct_labels
(
distinct_labels_
)
{}
void
operator
()(
long
i
)
const
{
try
{
dfs
[
l
]
=
default_trainer
.
train
(
all_samples
,
labels
);
std
::
vector
<
scalar_type
>
labels
;
const
label_type
l
=
distinct_labels
[
i
];
// setup one of the one vs all training sets
for
(
unsigned
long
k
=
0
;
k
<
all_samples
.
size
();
++
k
)
{
if
(
all_labels
[
k
]
==
l
)
labels
.
push_back
(
+
1
);
else
labels
.
push_back
(
-
1
);
}
if
(
verbose
)
{
auto_mutex
lock
(
class_mutex
);
std
::
cout
<<
"Training classifier for "
<<
l
<<
" vs. all"
<<
std
::
endl
;
}
any_trainer
trainer
;
// now train a binary classifier using the samples we selected
{
auto_mutex
lock
(
class_mutex
);
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
l
);
if
(
itr
!=
trainers
.
end
())
trainer
=
itr
->
second
;
else
trainer
=
default_trainer
;
}
any_decision_function
<
sample_type
,
scalar_type
>
binary_df
=
trainer
.
train
(
all_samples
,
labels
);
auto_mutex
lock
(
class_mutex
);
dfs
[
l
]
=
binary_df
;
}
else
catch
(
std
::
exception
&
e
)
{
std
::
ostringstream
sout
;
sout
<<
"In one_vs_all_trainer, no trainer registered for the "
<<
l
<<
" label."
;
throw
invalid_label
(
sout
.
str
(),
l
);
auto_mutex
lock
(
class_mutex
);
error_message
=
e
.
what
();
}
}
return
trained_function_type
(
dfs
);
}
mutable
typename
trained_function_type
::
binary_function_table
dfs
;
mutex
class_mutex
;
mutable
std
::
string
error_message
;
private:
const
std
::
vector
<
sample_type
>&
all_samples
;
const
std
::
vector
<
label_type
>&
all_labels
;
const
any_trainer
&
default_trainer
;
const
binary_function_table
&
trainers
;
const
bool
verbose
;
const
std
::
vector
<
label_type
>&
distinct_labels
;
};
any_trainer
default_trainer
;
typedef
std
::
map
<
label_type
,
any_trainer
>
binary_function_table
;
binary_function_table
trainers
;
bool
verbose
;
unsigned
long
num_threads
;
};
...
...
dlib/svm/one_vs_all_trainer_abstract.h
View file @
160337da
...
...
@@ -55,10 +55,11 @@ namespace dlib
);
/*!
ensures
- this object is properly initialized
- this object will not be verbose unless be_verbose() is called
- no binary trainers are associated with *this. I.e. you have to
call set_trainer() before calling train()
- This object is properly initialized.
- This object will not be verbose unless be_verbose() is called.
- No binary trainers are associated with *this. I.e. you have to
call set_trainer() before calling train().
- #get_num_threads() == 4
!*/
void
set_trainer
(
...
...
@@ -96,6 +97,23 @@ namespace dlib
- this object will not print anything to standard out
!*/
void
set_num_threads
(
unsigned
long
num
);
/*!
ensures
- #get_num_threads() == num
!*/
unsigned
long
get_num_threads
(
)
const
;
/*!
ensures
- returns the number of threads used during training. You should
usually set this equal to the number of processing cores on your
machine.
!*/
struct
invalid_label
:
public
dlib
::
error
{
/*!
...
...
dlib/svm/one_vs_one_trainer.h
View file @
160337da
...
...
@@ -17,6 +17,7 @@
#include "../any.h"
#include <map>
#include <set>
#include "../threads.h"
namespace
dlib
{
...
...
@@ -40,7 +41,8 @@ namespace dlib
one_vs_one_trainer
(
)
:
verbose
(
false
)
verbose
(
false
),
num_threads
(
4
)
{}
void
set_trainer
(
...
...
@@ -72,6 +74,19 @@ namespace dlib
verbose
=
false
;
}
void
set_num_threads
(
unsigned
long
num
)
{
num_threads
=
num
;
}
unsigned
long
get_num_threads
(
)
const
{
return
num_threads
;
}
struct
invalid_label
:
public
dlib
::
error
{
invalid_label
(
const
std
::
string
&
msg
,
const
label_type
&
l1_
,
const
label_type
&
l2_
...
...
@@ -98,20 +113,70 @@ namespace dlib
const
std
::
vector
<
label_type
>
distinct_labels
=
select_all_distinct_labels
(
all_labels
);
std
::
vector
<
sample_type
>
samples
;
std
::
vector
<
scalar_type
>
labels
;
typename
trained_function_type
::
binary_function_table
dfs
;
// fill pairs with all the pairs of labels.
std
::
vector
<
unordered_pair
<
label_type
>
>
pairs
;
for
(
unsigned
long
i
=
0
;
i
<
distinct_labels
.
size
();
++
i
)
{
for
(
unsigned
long
j
=
i
+
1
;
j
<
distinct_labels
.
size
();
++
j
)
{
samples
.
clear
();
labels
.
clear
();
pairs
.
push_back
(
unordered_pair
<
label_type
>
(
distinct_labels
[
i
],
distinct_labels
[
j
]));
const
unordered_pair
<
label_type
>
p
(
distinct_labels
[
i
],
distinct_labels
[
j
]);
// make sure we have a trainer for this pair
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
pairs
.
back
());
if
(
itr
==
trainers
.
end
()
&&
default_trainer
.
is_empty
())
{
std
::
ostringstream
sout
;
sout
<<
"In one_vs_one_trainer, no trainer registered for the ("
<<
pairs
.
back
().
first
<<
", "
<<
pairs
.
back
().
second
<<
") label pair."
;
throw
invalid_label
(
sout
.
str
(),
pairs
.
back
().
first
,
pairs
.
back
().
second
);
}
}
}
// Now train on all the label pairs.
parallel_for_helper
helper
(
all_samples
,
all_labels
,
default_trainer
,
trainers
,
verbose
,
pairs
);
parallel_for
(
num_threads
,
0
,
pairs
.
size
(),
helper
,
500
);
if
(
helper
.
error_message
.
size
()
!=
0
)
{
throw
dlib
::
error
(
"binary trainer threw while training one vs. one classifier. Error was: "
+
helper
.
error_message
);
}
return
trained_function_type
(
helper
.
dfs
);
}
private:
typedef
std
::
map
<
unordered_pair
<
label_type
>
,
any_trainer
>
binary_function_table
;
struct
parallel_for_helper
{
parallel_for_helper
(
const
std
::
vector
<
sample_type
>&
all_samples_
,
const
std
::
vector
<
label_type
>&
all_labels_
,
const
any_trainer
&
default_trainer_
,
const
binary_function_table
&
trainers_
,
const
bool
verbose_
,
const
std
::
vector
<
unordered_pair
<
label_type
>
>&
pairs_
)
:
all_samples
(
all_samples_
),
all_labels
(
all_labels_
),
default_trainer
(
default_trainer_
),
trainers
(
trainers_
),
verbose
(
verbose_
),
pairs
(
pairs_
)
{}
void
operator
()(
long
i
)
const
{
try
{
std
::
vector
<
sample_type
>
samples
;
std
::
vector
<
scalar_type
>
labels
;
const
unordered_pair
<
label_type
>
p
=
pairs
[
i
];
// pick out the samples corresponding to these two classes
for
(
unsigned
long
k
=
0
;
k
<
all_samples
.
size
();
++
k
)
...
...
@@ -128,43 +193,51 @@ namespace dlib
}
}
if
(
verbose
)
{
auto_mutex
lock
(
class_mutex
);
std
::
cout
<<
"Training classifier for "
<<
p
.
first
<<
" vs. "
<<
p
.
second
<<
std
::
endl
;
}
any_trainer
trainer
;
// now train a binary classifier using the samples we selected
{
auto_mutex
lock
(
class_mutex
);
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
p
);
if
(
itr
!=
trainers
.
end
())
{
dfs
[
p
]
=
itr
->
second
.
train
(
samples
,
labels
);
}
else
if
(
default_trainer
.
is_empty
()
==
false
)
{
dfs
[
p
]
=
default_trainer
.
train
(
samples
,
labels
);
}
else
{
std
::
ostringstream
sout
;
sout
<<
"In one_vs_one_trainer, no trainer registered for the ("
<<
p
.
first
<<
", "
<<
p
.
second
<<
") label pair."
;
throw
invalid_label
(
sout
.
str
(),
p
.
first
,
p
.
second
);
trainer
=
itr
->
second
;
else
trainer
=
default_trainer
;
}
any_decision_function
<
sample_type
,
scalar_type
>
binary_df
=
trainer
.
train
(
samples
,
labels
);
auto_mutex
lock
(
class_mutex
);
dfs
[
p
]
=
binary_df
;
}
catch
(
std
::
exception
&
e
)
{
auto_mutex
lock
(
class_mutex
);
error_message
=
e
.
what
();
}
}
return
trained_function_type
(
dfs
);
}
mutable
typename
trained_function_type
::
binary_function_table
dfs
;
mutex
class_mutex
;
mutable
std
::
string
error_message
;
private:
const
std
::
vector
<
sample_type
>&
all_samples
;
const
std
::
vector
<
label_type
>&
all_labels
;
const
any_trainer
&
default_trainer
;
const
binary_function_table
&
trainers
;
const
bool
verbose
;
const
std
::
vector
<
unordered_pair
<
label_type
>
>&
pairs
;
};
any_trainer
default_trainer
;
typedef
std
::
map
<
unordered_pair
<
label_type
>
,
any_trainer
>
binary_function_table
;
binary_function_table
trainers
;
bool
verbose
;
unsigned
long
num_threads
;
};
...
...
dlib/svm/one_vs_one_trainer_abstract.h
View file @
160337da
...
...
@@ -55,10 +55,11 @@ namespace dlib
);
/*!
ensures
- this object is properly initialized
- this object will not be verbose unless be_verbose() is called
- no binary trainers are associated with *this. I.e. you have to
call set_trainer() before calling train()
- This object is properly initialized
- This object will not be verbose unless be_verbose() is called.
- No binary trainers are associated with *this. I.e. you have to
call set_trainer() before calling train().
- #get_num_threads() == 4
!*/
void
set_trainer
(
...
...
@@ -99,6 +100,23 @@ namespace dlib
- this object will not print anything to standard out
!*/
void
set_num_threads
(
unsigned
long
num
);
/*!
ensures
- #get_num_threads() == num
!*/
unsigned
long
get_num_threads
(
)
const
;
/*!
ensures
- returns the number of threads used during training. You should
usually set this equal to the number of processing cores on your
machine.
!*/
struct
invalid_label
:
public
dlib
::
error
{
/*!
...
...
dlib/svm_threaded.h
View file @
160337da
...
...
@@ -19,6 +19,8 @@
#include "svm/structural_graph_labeling_trainer.h"
#include "svm/cross_validate_graph_labeling_trainer.h"
#include "svm/svm_multiclass_linear_trainer.h"
#include "svm/one_vs_one_trainer.h"
#include "svm/one_vs_all_trainer.h"
#endif // DLIB_SVm_THREADED_HEADER
...
...
dlib/test/one_vs_all_trainer.cpp
View file @
160337da
...
...
@@ -2,7 +2,7 @@
// License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h"
#include <dlib/svm.h>
#include <dlib/svm
_threaded
.h>
#include <vector>
#include <sstream>
...
...
dlib/test/one_vs_one_trainer.cpp
View file @
160337da
...
...
@@ -2,7 +2,7 @@
// License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h"
#include <dlib/svm.h>
#include <dlib/svm
_threaded
.h>
#include <dlib/statistics.h>
#include <vector>
#include <sstream>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment