OpenDAS / Fairseq · Commits

Commit a233fceb
Authored Nov 17, 2017 by Myle Ott

Improve memory handling (recover from OOM and periodically empty caching allocator)
Parent: be274623

Showing 1 changed file with 89 additions and 60 deletions.

fairseq/multiprocessing_trainer.py (+89, -60)
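The commit message describes the two mechanisms the diff below introduces: catching CUDA out-of-memory errors so the offending batch can be skipped instead of crashing the run, and emptying PyTorch's caching allocator so previously cached blocks are returned to the device. For orientation, here is a minimal, self-contained sketch of that recovery pattern; the run_batch function and its arguments are illustrative and not part of the commit.

import torch

def run_batch(model, criterion, optimizer, sample):
    """Run one training batch, skipping it if the GPU runs out of memory."""
    optimizer.zero_grad()
    try:
        loss = criterion(model(sample['input']), sample['target'])
        loss.backward()
        optimizer.step()
        return False  # no OOM
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise
        print('| WARNING: ran out of memory, skipping batch')
        optimizer.zero_grad()          # drop any partially accumulated gradients
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()   # release cached blocks back to the device
        return True                    # OOM: caller may retry with a smaller batch

The hasattr guard mirrors the diff: it keeps the code working on PyTorch builds that do not yet provide torch.cuda.empty_cache().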
@@ -11,10 +11,11 @@ Train a network on multiple GPUs using multiprocessing.
 """

 from itertools import cycle, islice
+import math
 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

-from fairseq import nccl, utils
+from fairseq import meters, nccl, utils
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG
@@ -74,6 +75,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()
+
+        self._max_bsz_seen = 0

     def _build_optimizer(self):
         if self.args.optimizer == 'adagrad':
             return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr,
@@ -161,14 +164,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)

         # forward pass
-        sample_sizes, logging_outputs = Future.gen_tuple_list([
+        sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward')
             for rank in range(self.num_replicas)
         ])

         # backward pass, all-reduce gradients and take an optimization step
         grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        grad_norms = Future.gen_list([
+        grad_norms, ooms_bwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
         ])
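With this change each replica's _async_forward and _async_backward_and_opt return tuples that include an OOM flag, and Future.gen_tuple_list gathers the per-rank tuples and transposes them into parallel lists (sample_sizes, logging_outputs, ooms_fwd, and so on). A minimal sketch of that transposition, assuming each future simply yields its tuple when waited on; FakeFuture and its gen() method are illustrative stand-ins, not fairseq's Future API.

def gen_tuple_list(futures):
    """Wait on futures that each return a tuple and transpose the
    results into one list per tuple position."""
    results = [f.gen() for f in futures]          # e.g. [(12, {...}, 0), (11, {...}, 1)]
    return tuple(list(col) for col in zip(*results))

class FakeFuture:
    def __init__(self, value):
        self.value = value
    def gen(self):
        return self.value

sample_sizes, logging_outputs, ooms_fwd = gen_tuple_list(
    [FakeFuture((12, {'loss': 3.1}, 0)), FakeFuture((11, {'loss': 2.9}, 1))])
print(sample_sizes, ooms_fwd)  # [12, 11] [0, 1]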
@@ -176,6 +179,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
         logging_output['gnorm'] = grad_norms[0]  # log the gradient norm
+        logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd)

         return logging_output
@@ -186,34 +190,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self.model.train()
             self.optimizer.zero_grad()

-        if self._sample is None:
-            return 0, {}
-
-        # calculate loss and sample size
-        self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
-
-        return sample_size, logging_output
+        sample_size, logging_output, oom = 0, {}, False
+        if self._sample is not None:
+            try:
+                # calculate loss and sample size
+                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+            except RuntimeError as e:
+                if not eval and 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    self.loss = None
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                else:
+                    raise e
+
+        return sample_size, logging_output, oom

     def _async_backward_and_opt(self, rank, device_id, grad_denom):
-        # backward pass
-        self.loss.backward()
+        oom = False
+        if self.loss is not None:
+            try:
+                # backward pass
+                self.loss.backward()
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                    self.optimizer.zero_grad()
+                else:
+                    raise e

-        # get model parameters as a flattened (contiguous) tensor
-        flat_grads = self._flat_model_grads()
-
-        # all-reduce grads
-        nccl.all_reduce(flat_grads)
-
-        # normalize grads
-        if grad_denom != 0:
-            flat_grads.div_(grad_denom)
+        # all-reduce grads and rescale by grad_denom
+        self._all_reduce_and_rescale_grads(grad_denom)

         # clip grads
-        grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
-
-        # copy reduced grads back
-        self._set_model_grads_(flat_grads)
+        grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)

         # take an optimization step
         self.optimizer.step()
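In both the old and new versions the summed gradients are rescaled by grad_denom, which the criterion computes from the per-replica sample sizes. Because all-reduce sums gradients across replicas, dividing by the combined sample size makes the update equivalent to training on one large batch. A small numeric sketch of the idea, assuming grad_denom is simply the sum of the per-replica sample sizes (the exact formula is left to the criterion):

# Two replicas compute gradients over 12 and 11 target tokens respectively.
sample_sizes = [12, 11]
grad_denom = sum(sample_sizes)        # 23

# Per-replica gradients of the *summed* (not averaged) loss for one parameter.
per_replica_grads = [4.6, 6.9]

# all-reduce sums the gradients across replicas ...
summed_grad = sum(per_replica_grads)  # 11.5

# ... and dividing by grad_denom gives the per-token average gradient,
# exactly what a single process would compute for the combined batch.
print(summed_grad / grad_denom)       # 0.5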
@@ -221,41 +235,49 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # reset loss
         self.loss = None

-        return grad_norm
-
-    def _model_grads(self):
-        return [p.grad for p in self.model.parameters() if p.requires_grad]
-
-    def _flat_model_grads(self):
-        grads = self._model_grads()
-        if not hasattr(self, '_flat_grads'):
-            num_params = sum(g.data.numel() for g in grads)
-            self._flat_grads = grads[0].data.new(num_params)
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            self._flat_grads[offset:offset+numel].copy_(grad)
-            offset += numel
-        return self._flat_grads
-
-    def _set_model_grads_(self, flat_grads):
-        grads = self._model_grads()
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            grad.copy_(flat_grads[offset:offset+numel])
-            offset += numel
-        assert offset == flat_grads.numel()
-
-    def _clip_grads_(self, flat_grads, clipv):
-        """nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
-        norm = flat_grads.norm()
-        if clipv > 0 and norm > clipv:
-            coef = max(norm, 1e-6) / clipv
-            flat_grads.div_(coef)
-        return norm
+        return grad_norm, oom
+
+    def _all_reduce_and_rescale_grads(self, grad_denom, buffer_size=10485760):
+        """All-reduce and rescale gradients in chunks of the specified size."""
+        grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
+        buffer_t = grads[0].new(math.ceil(buffer_size / grads[0].element_size())).zero_()
+        buffer = []
+
+        def all_reduce_buffer():
+            # copy grads into buffer_t
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                buffer_t[offset:offset+numel].copy_(g.view(-1))
+                offset += numel
+            # all-reduce and rescale
+            nccl.all_reduce(buffer_t[:offset])
+            buffer_t.div_(grad_denom)
+            # copy all-reduced buffer back into grads
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                g.view(-1).copy_(buffer_t[offset:offset+numel])
+                offset += numel
+
+        filled = 0
+        for g in grads:
+            sz = g.numel() * g.element_size()
+            if sz > buffer_size:
+                # grad is bigger than buffer, all-reduce and rescale directly
+                nccl.all_reduce(g)
+                g.div_(grad_denom)
+            elif filled + sz > buffer_size:
+                # buffer is full, all-reduce and replace buffer with grad
+                all_reduce_buffer()
+                buffer = [g]
+                filled = sz
+            else:
+                # add grad to buffer
+                buffer.append(g)
+                filled += sz
+        if len(buffer) > 0:
+            all_reduce_buffer()

     def valid_step(self, samples):
         """Do forward pass in parallel."""
@@ -263,10 +285,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, volatile=True)

         # forward pass
-        _sample_sizes, logging_outputs = Future.gen_tuple_list([
+        _sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward', eval=True)
             for rank in range(self.num_replicas)
         ])
+        assert sum(ooms_fwd) == 0

         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
@@ -314,4 +337,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         if sample is None:
             self._sample = None
         else:
+            if hasattr(torch.cuda, 'empty_cache'):
+                # clear the caching allocator if this is the largest sample we've seen
+                if sample['target'].size(0) > self._max_bsz_seen:
+                    self._max_bsz_seen = sample['target'].size(0)
+                    torch.cuda.empty_cache()
+
             self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
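This last hunk implements the "periodically empty caching allocator" half of the commit message: whenever a sample arrives whose target batch dimension exceeds anything seen so far, cached blocks left over from smaller allocations are released before the new sample is moved to the GPU, which helps avoid fragmentation-induced OOMs on the largest batches. A compact sketch of the same heuristic, where prepare_sample stands in for whatever moves the batch onto the device (both names here are illustrative):

import torch

class SampleScatterer:
    """Track the largest batch seen and empty the CUDA caching allocator
    before preparing a sample that exceeds it."""

    def __init__(self):
        self.max_bsz_seen = 0

    def scatter(self, sample, prepare_sample):
        if sample is None:
            return None
        bsz = sample['target'].size(0)
        if hasattr(torch.cuda, 'empty_cache') and bsz > self.max_bsz_seen:
            self.max_bsz_seen = bsz
            torch.cuda.empty_cache()   # release cached blocks so the bigger batch fits
        return prepare_sample(sample)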