Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
97474376
Commit
97474376
authored
May 23, 2016
by
Davis King
Browse files
Changed code to avoid recreating thread_local cuda context objects.
parent
e55afabd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
4 deletions
+13
-4
dlib/dnn/trainer.h
dlib/dnn/trainer.h
+13
-4
No files found.
dlib/dnn/trainer.h
View file @
97474376
...
...
@@ -535,7 +535,15 @@ namespace dlib
std
::
vector
<
tensor
*>
reference_params
;
visit_layer_parameters
(
devices
[
0
]
->
net
,
[
&
](
size_t
,
tensor
&
t
)
{
reference_params
.
push_back
(
&
t
);
});
thread_pool
tp
(
devices
.
size
());
// We make separate thread pools with just one thread in them because we want
// to make sure each device is always executed on the same thread. We care
// about this because there are thread_local context variables for some cuda
// components and they get regenerated when the current cuda device changes.
// Recreating them over and over is somewhat expensive so we want to avoid
// that.
std
::
vector
<
std
::
shared_ptr
<
thread_pool
>>
tp
;
for
(
size_t
i
=
0
;
i
<
devices
.
size
();
++
i
)
tp
.
push_back
(
std
::
make_shared
<
thread_pool
>
(
1
));
size_t
iteration
=
0
;
...
...
@@ -546,7 +554,7 @@ namespace dlib
// right version for unsupervised or supervised training based on the type
// of label_type.
for
(
size_t
i
=
0
;
i
<
devices
.
size
();
++
i
)
tp
.
add_task_by_value
([
&
,
i
](
double
&
loss
){
loss
=
compute_parameter_gradients
(
i
,
next_job
,
pick_which_run_update
);
},
losses
[
i
]);
tp
[
i
]
->
add_task_by_value
([
&
,
i
](
double
&
loss
){
loss
=
compute_parameter_gradients
(
i
,
next_job
,
pick_which_run_update
);
},
losses
[
i
]);
// aggregate loss values from all the network computations.
double
theloss
=
0
;
for
(
auto
&&
loss
:
losses
)
...
...
@@ -597,9 +605,10 @@ namespace dlib
// Now apply all the updates to each device.
for
(
size_t
i
=
0
;
i
<
devices
.
size
();
++
i
)
tp
.
add_task_by_value
([
&
,
i
](){
if
(
next_job
.
have_data
[
i
])
update_parameters
(
i
);
});
tp
[
i
]
->
add_task_by_value
([
&
,
i
](){
if
(
next_job
.
have_data
[
i
])
update_parameters
(
i
);
});
// and wait for the updates to all happen.
tp
.
wait_for_all_tasks
();
for
(
size_t
i
=
0
;
i
<
devices
.
size
();
++
i
)
tp
[
i
]
->
wait_for_all_tasks
();
// Evey now and then force all the parameters to be the same just to make
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment