OpenDAS / Fairseq, commit a6155337
Authored Sep 19, 2017 by Myle Ott
Parent: a8bc4d0a

Better training support when GPUs are in "exclusive mode"
Showing 2 changed files with 26 additions and 36 deletions:

    fairseq/data.py                        +0  -1
    fairseq/multiprocessing_trainer.py    +26 -35
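Background on the commit title: when a GPU is set to the "exclusive process" compute mode, the driver allows only one process to hold a CUDA context on that device at a time, so a multi-process trainer must ensure that only the worker assigned to a GPU ever initializes CUDA on it. The snippet below is not part of the commit; it is a sketch, assuming the third-party pynvml bindings are installed and an NVIDIA driver is present, of how one could inspect the compute mode of each GPU before training. The same information also appears in the output of `nvidia-smi -q` under "Compute Mode".

# Hypothetical helper, not from this repository: report each GPU's compute
# mode using the third-party pynvml bindings (pip install nvidia-ml-py).
import pynvml

MODES = {
    pynvml.NVML_COMPUTEMODE_DEFAULT: "Default",
    pynvml.NVML_COMPUTEMODE_PROHIBITED: "Prohibited",
    pynvml.NVML_COMPUTEMODE_EXCLUSIVE_PROCESS: "Exclusive Process",
}

pynvml.nvmlInit()
try:
    for i in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        mode = pynvml.nvmlDeviceGetComputeMode(handle)
        print(f"GPU {i}: {MODES.get(mode, mode)}")
finally:
    pynvml.nvmlShutdown()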
fairseq/data.py

@@ -105,7 +105,6 @@ class LanguageDatasets(object):
         return torch.utils.data.DataLoader(
             dataset,
             num_workers=num_workers,
-            pin_memory=torch.cuda.is_available(),
             collate_fn=PaddingCollater(self.src_dict.pad()),
             batch_sampler=batch_sampler)
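The single deleted line removes `pin_memory=torch.cuda.is_available()` from the DataLoader. A likely reading, consistent with the commit title: pinning host memory allocates page-locked buffers through the CUDA driver, which initializes a CUDA context in the process that builds the DataLoader, and under exclusive-process mode that parent process would then compete with the training workers for the GPUs. The sketch below is not from the repository; the toy dataset is made up, and `torch.cuda.is_initialized()` is assumed to be available (it is in recent PyTorch releases).

# Sketch only: confirm that building and iterating a DataLoader without
# pin_memory leaves this (parent) process CUDA-free.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(8, 4))        # toy stand-in dataset
loader = DataLoader(dataset, num_workers=0, pin_memory=False)

for (batch,) in loader:                           # CPU-only iteration
    pass

# With pin_memory=True, a pinning step would allocate page-locked host
# memory, which creates a CUDA context in this process.
print("CUDA initialized in parent:", torch.cuda.is_initialized())  # expect False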
fairseq/multiprocessing_trainer.py

@@ -122,14 +122,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         assert isinstance(criterion, FairseqCriterion)

         # scatter sample across GPUs
-        samples, data_events = self._scatter_samples(samples)
+        self._scatter_samples(samples)
         criterion.prepare(samples)

         # forward pass, backward pass and gradient step
         losses = [
-            self.call_async(rank, '_async_train_step', sample=samples[rank],
-                            criterion=criterion, data_event=event)
-            for rank, event in enumerate(data_events)
+            self.call_async(rank, '_async_train_step', criterion=criterion)
+            for rank in range(self.num_replicas)
         ]

         # aggregate losses and gradient norms
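After this hunk, `train_step` no longer ships a pre-scattered CUDA sample and a `data_event` to each replica; it passes only the criterion, and each replica trains on the sample it staged for itself (see the `_scatter_samples` hunk further down). The toy sketch below is not fairseq code: it uses a thread pool and plain Python objects in place of fairseq's persistent GPU processes, just to show the fan-out where the coordinator sends a small argument and per-replica state supplies the data.

# Toy sketch (hypothetical names): fan out one async call per replica,
# passing only the criterion; each replica consumes its own staged sample.
from concurrent.futures import ThreadPoolExecutor

def toy_criterion(x):
    return x * 0.5                        # stand-in loss function

class ToyReplica:
    def __init__(self, sample):
        self._sample = sample             # staged before the train step

    def train_step(self, criterion):
        if self._sample is None:          # padded replica: contributes 0
            return 0.0
        return criterion(self._sample)

replicas = [ToyReplica(2.0), ToyReplica(None)]

with ThreadPoolExecutor(max_workers=len(replicas)) as pool:
    futures = [pool.submit(r.train_step, toy_criterion) for r in replicas]
    losses = [f.result() for f in futures]

print(losses)   # [1.0, 0.0]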
@@ -138,8 +137,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return loss, grad_norms[0]

-    def _async_train_step(self, rank, device_id, sample, criterion, data_event):
-        data_event.wait()
+    def _async_train_step(self, rank, device_id, criterion):
         self.model.train()

         # zero grads even if net_input is None, since we will all-reduce them
@@ -147,9 +145,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # calculate loss and grads
         loss = 0
-        if sample is not None:
-            net_output = self.model(**sample['net_input'])
-            loss_ = criterion(net_output, sample)
+        if self._sample is not None:
+            net_output = self.model(**self._sample['net_input'])
+            loss_ = criterion(net_output, self._sample)
             loss_.backward()
             loss = loss_.data[0]
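The worker-side change reads `self._sample` instead of an argument, and keeps the rule from the surrounding context: gradients are zeroed even when a replica has no batch, because every replica still participates in the gradient all-reduce. The toy arithmetic below is plain Python, not fairseq's collective, and the averaging rule is only a stand-in; it shows why an idle replica must contribute exact zeros rather than whatever was left in its buffers.

# Sketch: a replica whose sample is None still zeroes its gradients so the
# collective sees zeros, not stale values from the previous step.
def all_reduce_mean(per_replica_grads):
    # Toy stand-in for the collective: element-wise mean across replicas.
    n = len(per_replica_grads)
    return [sum(grads[i] for grads in per_replica_grads) / n
            for i in range(len(per_replica_grads[0]))]

grads_gpu0 = [0.2, -0.4, 0.1]   # replica that had a batch
grads_gpu1 = [0.0, 0.0, 0.0]    # replica whose sample was None: zeroed, not stale

print(all_reduce_mean([grads_gpu0, grads_gpu1]))  # [0.1, -0.2, 0.05]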
@@ -191,14 +189,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def valid_step(self, samples, criterion):
         """Do forward pass in parallel."""
         # scatter sample across GPUs
-        samples, data_events = self._scatter_samples(samples, volatile=True)
+        self._scatter_samples(samples, volatile=True)
         criterion.prepare(samples)

         # forward pass
         losses = [
-            self.call_async(rank, '_async_valid_step', sample=samples[rank],
-                            criterion=criterion, data_event=event)
-            for rank, event in enumerate(data_events)
+            self.call_async(rank, '_async_valid_step', criterion=criterion)
+            for rank in range(self.num_replicas)
         ]

         # aggregate losses
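The validation path mirrors `train_step`, with `_scatter_samples(samples, volatile=True)` so the forward pass does not build an autograd graph; `volatile` was the pre-0.4 PyTorch flag for this. The short sketch below uses the modern equivalent, `torch.no_grad()`, which is my substitution and not something this commit uses.

# Sketch (recent PyTorch): a validation-style forward pass without recording
# an autograd graph, analogous to the old volatile=True flag.
import torch

model = torch.nn.Linear(4, 2)       # toy model
x = torch.randn(3, 4)

with torch.no_grad():               # inference-only forward pass
    out = model(x)

print(out.requires_grad)            # False: no graph was recorded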
@@ -206,14 +203,12 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return loss

-    def _async_valid_step(self, rank, device_id, sample, criterion, data_event):
-        if sample is None:
+    def _async_valid_step(self, rank, device_id, criterion):
+        if self._sample is None:
             return 0

-        data_event.wait()
         self.model.eval()
-        net_output = self.model(**sample['net_input'])
-        loss = criterion(net_output, sample)
+        net_output = self.model(**self._sample['net_input'])
+        loss = criterion(net_output, self._sample)
         return loss.data[0]

     def get_lr(self):
@@ -241,20 +236,16 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _scatter_samples(self, samples, volatile=False):
         """Split and distribute a sample across GPUs."""
-        res = [utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
-               for sample, device_id in zip(samples, self.device_ids)]
         # Pad with None until its size is equal to the number of replicas.
-        res = res + [None] * (self.num_replicas - len(samples))
+        samples = samples + [None] * (self.num_replicas - len(samples))

-        # Synchronize GPU devices after data is sent to prevent
-        # race conditions.
-        events = []
-        for d in self.device_ids:
-            with torch.cuda.device(d):
-                event = torch.cuda.Event(interprocess=True)
-                event.record()
-                events.append(event)
-        return res, events
+        Future.gen_list([
+            self.call_async(rank, '_async_prepare_sample', sample=samples[rank], volatile=volatile)
+            for rank in range(self.num_replicas)
+        ])
+
+    def _async_prepare_sample(self, rank, device_id, sample, volatile):
+        if sample is None:
+            self._sample = None
+        else:
+            self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
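This last hunk carries the main fix. The removed `_scatter_samples` moved every sample to its GPU from the parent process and recorded a `torch.cuda.Event(interprocess=True)` on each device, which requires the parent to hold a CUDA context on every GPU; exclusive-process mode forbids that, since each device already belongs to one training worker. The replacement only pads the sample list with `None` and asks each worker to run `_async_prepare_sample` on its own device, so a GPU is touched only by the process that owns it. The sketch below mimics that ownership pattern with the standard multiprocessing module and made-up message names; it is an illustration, not fairseq's event loop.

# Sketch only: one long-lived process per (hypothetical) device. The parent
# sends plain Python samples over a queue; each worker "prepares" its own
# sample locally, which is where the real code moves data onto its GPU.
import multiprocessing as mp

def criterion(sample):
    # Stand-in loss function; top-level so it can be pickled to the workers.
    return sample * 0.5

def worker(rank, inbox, outbox):
    sample = None                                  # worker-local state, like self._sample
    while True:
        msg, payload = inbox.get()
        if msg == "prepare_sample":
            sample = payload                       # real code: utils.prepare_sample(..., cuda_device=device_id)
        elif msg == "train_step":
            loss = 0.0 if sample is None else payload(sample)
            outbox.put((rank, loss))
        elif msg == "stop":
            return

if __name__ == "__main__":
    num_replicas = 2
    inboxes = [mp.Queue() for _ in range(num_replicas)]
    outbox = mp.Queue()
    workers = [mp.Process(target=worker, args=(rank, inboxes[rank], outbox))
               for rank in range(num_replicas)]
    for p in workers:
        p.start()

    samples = [4.0]                                # fewer samples than replicas
    samples = samples + [None] * (num_replicas - len(samples))

    # Each worker stages its own sample; the parent never touches a "device".
    for rank in range(num_replicas):
        inboxes[rank].put(("prepare_sample", samples[rank]))
    # The train-step message carries only the criterion, as in the new code above.
    for rank in range(num_replicas):
        inboxes[rank].put(("train_step", criterion))

    losses = dict(outbox.get() for _ in range(num_replicas))
    print(losses)                                  # e.g. {0: 2.0, 1: 0.0}

    for rank in range(num_replicas):
        inboxes[rank].put(("stop", None))
    for p in workers:
        p.join()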