Commit 4593ebfa
Authored Sep 28, 2017 by Myle Ott; committed by GitHub on Sep 28, 2017

Fix handling of partially-empty initial batch (#11)

Parent: 03c4a716
Showing 1 changed file with 18 additions and 4 deletions.
fairseq/multiprocessing_trainer.py (+18, -4) @ 4593ebfa

@@ -10,6 +10,7 @@
 Train a network on multiple GPUs using multiprocessing.
 """

+from itertools import cycle, islice
 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
@@ -48,6 +49,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             for rank in range(self.num_replicas)
         ])
+
+        self._grads_initialized = False

     def _async_init(self, rank, device_id, args, model, nccl_uid):
         """Initialize child processes."""
         self.args = args
@@ -121,8 +124,15 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         """Do forward, backward and gradient step in parallel."""
         assert isinstance(criterion, FairseqCriterion)

+        # PyTorch initializes gradient buffers lazily, so the first
+        # train step needs to send non-empty samples to all replicas
+        replace_empty_samples = False
+        if not self._grads_initialized:
+            replace_empty_samples = True
+            self._grads_initialized = True
+
         # scatter sample across GPUs
-        self._scatter_samples(samples)
+        self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
         criterion.prepare(samples)

         # forward pass, backward pass and gradient step
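The new comment refers to PyTorch's lazy allocation of gradient buffers: a parameter's .grad stays None until a backward pass actually touches it, which is presumably why the first train step must send a non-empty sample to every replica before gradients are synchronized. A minimal standalone sketch of that lazy allocation (the Linear module and tensor shapes are arbitrary placeholders, not taken from fairseq):

import torch

# Sketch only: shows that .grad is allocated lazily, on the first backward pass.
lin = torch.nn.Linear(4, 2)
print(lin.weight.grad)        # None -- no gradient buffer allocated yet

loss = lin(torch.randn(3, 4)).sum()
loss.backward()
print(lin.weight.grad.shape)  # torch.Size([2, 4]) -- allocated by the first backward pass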
@@ -234,10 +244,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self.lr_scheduler.step(val_loss, epoch)
         return self.optimizer.param_groups[0]['lr']

-    def _scatter_samples(self, samples, volatile=False):
+    def _scatter_samples(self, samples, volatile=False, replace_empty_samples=False):
         """Split and distribute a sample across GPUs."""
-        # Pad with None until its size is equal to the number of replicas.
-        samples = samples + [None] * (self.num_replicas - len(samples))
+        if not replace_empty_samples:
+            # pad with None until its size is equal to the number of replicas
+            samples = samples + [None] * (self.num_replicas - len(samples))
+        else:
+            # pad by cycling through the given samples
+            samples = list(islice(cycle(samples), self.num_replicas))

         Future.gen_list([
             self.call_async(rank, '_async_prepare_sample', sample=samples[rank], volatile=volatile)
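The reworked _scatter_samples keeps the old None-padding for ordinary steps, but when replace_empty_samples is set (the first step) it pads a partially-empty batch by cycling through the samples it does have, so every replica receives real data. A standalone sketch of the two padding behaviors, with short strings standing in for the actual sample dicts and a made-up helper name pad_samples used purely for illustration:

from itertools import cycle, islice

def pad_samples(samples, num_replicas, replace_empty_samples=False):
    # Mirrors the padding logic added to _scatter_samples above (sketch only).
    if not replace_empty_samples:
        # pad with None until its size is equal to the number of replicas
        return samples + [None] * (num_replicas - len(samples))
    # pad by cycling through the given samples
    return list(islice(cycle(samples), num_replicas))

print(pad_samples(['a', 'b'], 4))                              # ['a', 'b', None, None]
print(pad_samples(['a', 'b'], 4, replace_empty_samples=True))  # ['a', 'b', 'a', 'b']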