Commit a6155337 authored Sep 19, 2017 by Myle Ott

    Better training support when GPUs are in "exclusive mode"

parent a8bc4d0a
Showing 2 changed files with 26 additions and 36 deletions
fairseq/data.py                      +0  -1
fairseq/multiprocessing_trainer.py   +26 -35
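A note on the setting this commit targets: a GPU in the EXCLUSIVE_PROCESS compute mode (set with nvidia-smi -c EXCLUSIVE_PROCESS) allows only one process to hold a CUDA context on it at a time. The multiprocessing trainer runs one worker process per GPU, so any CUDA work in the parent process, such as allocating pinned memory or recording CUDA events, can claim a device before the worker that is supposed to own it gets there. The diffs below remove exactly those parent-side CUDA touches. The following is a minimal sketch of the constraint, not fairseq code, assuming two visible GPUs:

import torch
import torch.multiprocessing as mp


def worker(rank):
    # First CUDA call in this process: creates this worker's context on GPU `rank`.
    torch.cuda.set_device(rank)
    x = torch.zeros(4).cuda()
    print('rank %d owns device %d' % (rank, torch.cuda.current_device()))


if __name__ == '__main__':
    # Keep the parent CUDA-free: under EXCLUSIVE_PROCESS mode, a stray
    # torch.cuda call here could grab a device before its worker does.
    procs = [mp.Process(target=worker, args=(rank,)) for rank in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()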
fairseq/data.py

@@ -105,7 +105,6 @@ class LanguageDatasets(object):
         return torch.utils.data.DataLoader(
             dataset,
             num_workers=num_workers,
-            pin_memory=torch.cuda.is_available(),
             collate_fn=PaddingCollater(self.src_dict.pad()),
             batch_sampler=batch_sampler)
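The removed argument, pin_memory=torch.cuda.is_available(), was the one place the data-loading process touched CUDA: pinning batches into page-locked host memory requires a CUDA context in the process that iterates the DataLoader, which under exclusive mode can occupy a GPU that one of the trainer workers needs. A small sketch, using a dummy TensorDataset rather than fairseq's PaddingCollater, of a loader that leaves pinning off so the loading process stays CUDA-free:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy stand-in for a dataset split; the point is only the loader flags.
dataset = TensorDataset(torch.arange(32).view(8, 4).float())

# Leaving pin_memory at its default (False) means iterating the loader never
# allocates page-locked host buffers, so this process never initializes CUDA.
loader = DataLoader(dataset, batch_size=2, num_workers=0)

for (batch,) in loader:
    assert not batch.is_pinned()   # batches live in ordinary pageable memory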
fairseq/multiprocessing_trainer.py

@@ -122,14 +122,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         assert isinstance(criterion, FairseqCriterion)

         # scatter sample across GPUs
-        samples, data_events = self._scatter_samples(samples)
+        self._scatter_samples(samples)
         criterion.prepare(samples)

         # forward pass, backward pass and gradient step
         losses = [
-            self.call_async(rank, '_async_train_step', sample=samples[rank],
-                            criterion=criterion, data_event=event)
-            for rank, event in enumerate(data_events)
+            self.call_async(rank, '_async_train_step', criterion=criterion)
+            for rank in range(self.num_replicas)
         ]

         # aggregate losses and gradient norms
@@ -138,8 +137,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return loss, grad_norms[0]

-    def _async_train_step(self, rank, device_id, sample, criterion, data_event):
-        data_event.wait()
+    def _async_train_step(self, rank, device_id, criterion):
         self.model.train()

         # zero grads even if net_input is None, since we will all-reduce them
@@ -147,9 +145,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # calculate loss and grads
         loss = 0
-        if sample is not None:
-            net_output = self.model(**sample['net_input'])
-            loss_ = criterion(net_output, sample)
+        if self._sample is not None:
+            net_output = self.model(**self._sample['net_input'])
+            loss_ = criterion(net_output, self._sample)
             loss_.backward()
             loss = loss_.data[0]
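In the three hunks above, the per-GPU sample no longer travels through call_async and no interprocess data_event is needed: each worker stores the shard it prepared as self._sample, and _async_train_step simply reads that process-local state. A condensed sketch of the same stateful pattern, with hypothetical names (ReplicaState, prepare_sample, train_step are illustrative, not fairseq's API):

import torch


class ReplicaState(object):
    """Sketch of the per-process state a trainer replica keeps: the prepared
    shard lives on the instance, so the train step needs no sample or
    data_event arguments."""

    def __init__(self, device_id):
        self.device_id = device_id
        self._sample = None                      # filled in by prepare_sample()

    def prepare_sample(self, sample):
        # Runs inside the worker that owns self.device_id, so the GPU copy
        # happens in the process that holds that device's CUDA context.
        if sample is None:
            self._sample = None
        else:
            self._sample = {k: v.cuda(self.device_id) for k, v in sample.items()}

    def train_step(self, loss_fn):
        # Mirrors the new _async_train_step: a replica that received no shard
        # (the sample list was padded with None) just reports a zero loss.
        if self._sample is None:
            return 0.0
        return float(loss_fn(self._sample))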
@@ -191,14 +189,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def valid_step(self, samples, criterion):
         """Do forward pass in parallel."""
         # scatter sample across GPUs
-        samples, data_events = self._scatter_samples(samples, volatile=True)
+        self._scatter_samples(samples, volatile=True)
         criterion.prepare(samples)

         # forward pass
         losses = [
-            self.call_async(rank, '_async_valid_step', sample=samples[rank],
-                            criterion=criterion, data_event=event)
-            for rank, event in enumerate(data_events)
+            self.call_async(rank, '_async_valid_step', criterion=criterion)
+            for rank in range(self.num_replicas)
         ]

         # aggregate losses
@@ -206,14 +203,12 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return loss

-    def _async_valid_step(self, rank, device_id, sample, criterion, data_event):
-        if sample is None:
+    def _async_valid_step(self, rank, device_id, criterion):
+        if self._sample is None:
             return 0
-        data_event.wait()
-
         self.model.eval()
-        net_output = self.model(**sample['net_input'])
-        loss = criterion(net_output, sample)
+        net_output = self.model(**self._sample['net_input'])
+        loss = criterion(net_output, self._sample)
         return loss.data[0]

     def get_lr(self):
@@ -241,20 +236,16 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _scatter_samples(self, samples, volatile=False):
         """Split and distribute a sample across GPUs."""
-        res = [
-            utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
-            for sample, device_id in zip(samples, self.device_ids)]
-
         # Pad with None until its size is equal to the number of replicas.
-        res = res + [None] * (self.num_replicas - len(samples))
-
-        # Synchronize GPU devices after data is sent to prevent
-        # race conditions.
-        events = []
-        for d in self.device_ids:
-            with torch.cuda.device(d):
-                event = torch.cuda.Event(interprocess=True)
-                event.record()
-                events.append(event)
-
-        return res, events
+        samples = samples + [None] * (self.num_replicas - len(samples))
+
+        Future.gen_list([
+            self.call_async(rank, '_async_prepare_sample', sample=samples[rank], volatile=volatile)
+            for rank in range(self.num_replicas)
+        ])
+
+    def _async_prepare_sample(self, rank, device_id, sample, volatile):
+        if sample is None:
+            self._sample = None
+        else:
+            self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
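The old _scatter_samples copied every shard onto its GPU from the calling process and then recorded torch.cuda.Event(interprocess=True) on each device to guard against races, both of which force the parent to hold a CUDA context on every GPU. The new version only pads the sample list with None and dispatches _async_prepare_sample to each replica, so the device copy happens inside the process that owns the device and no CUDA events are needed. Below is a self-contained sketch of that scatter-then-step flow using plain torch.multiprocessing queues; the names gpu_worker, inbox and outbox are invented for illustration, and fairseq's real dispatch goes through MultiprocessingEventLoop's call_async and Future.gen_list instead:

import torch
import torch.multiprocessing as mp


def gpu_worker(rank, inbox, outbox):
    torch.cuda.set_device(rank)          # this worker's one and only device
    cached = None                        # plays the role of self._sample
    while True:
        msg, payload = inbox.get()
        if msg == 'prepare_sample':      # analogous to _async_prepare_sample
            cached = None if payload is None else payload.cuda(rank)
            outbox.put(None)             # simple ack; no CUDA events involved
        elif msg == 'train_step':        # analogous to _async_train_step
            outbox.put(0.0 if cached is None else float(cached.sum()))
        elif msg == 'stop':
            return


if __name__ == '__main__':
    num_replicas = 2
    ctx = mp.get_context('spawn')        # spawn keeps the parent CUDA-free
    inboxes = [ctx.Queue() for _ in range(num_replicas)]
    outboxes = [ctx.Queue() for _ in range(num_replicas)]
    workers = [ctx.Process(target=gpu_worker, args=(r, inboxes[r], outboxes[r]))
               for r in range(num_replicas)]
    for w in workers:
        w.start()

    # Scatter: pad with None up to num_replicas, then let every worker move
    # its own shard onto its own GPU (mirrors the new _scatter_samples).
    samples = [torch.randn(8)]
    samples = samples + [None] * (num_replicas - len(samples))
    for r in range(num_replicas):
        inboxes[r].put(('prepare_sample', samples[r]))
    for r in range(num_replicas):
        outboxes[r].get()

    # Step: workers reuse the shard they cached locally.
    for r in range(num_replicas):
        inboxes[r].put(('train_step', None))
    print('losses:', [outboxes[r].get() for r in range(num_replicas)])

    for r in range(num_replicas):
        inboxes[r].put(('stop', None))
    for w in workers:
        w.join()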