OpenDAS / Fairseq · Commit a233fceb
Authored Nov 17, 2017 by Myle Ott
Parent: be274623

Improve memory handling (recover from OOM and periodically empty caching allocator)
Showing 1 changed file with 89 additions and 60 deletions.

fairseq/multiprocessing_trainer.py  (+89, -60)
@@ -11,10 +11,11 @@ Train a network on multiple GPUs using multiprocessing.
 """
 
 from itertools import cycle, islice
+import math
 
 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
 
-from fairseq import nccl, utils
+from fairseq import meters, nccl, utils
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG
@@ -74,6 +75,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()
 
+        self._max_bsz_seen = 0
+
     def _build_optimizer(self):
         if self.args.optimizer == 'adagrad':
             return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr,
@@ -161,14 +164,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
 
         # forward pass
-        sample_sizes, logging_outputs = Future.gen_tuple_list([
+        sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward')
             for rank in range(self.num_replicas)
         ])
 
         # backward pass, all-reduce gradients and take an optimization step
         grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        grad_norms = Future.gen_list([
+        grad_norms, ooms_bwd = Future.gen_tuple_list([
            self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
         ])
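Each '_async_forward' call now returns a (sample_size, logging_output, oom) tuple per replica, so the collective helper has to hand back one list per field. The sketch below is an illustration only of what a gen_tuple_list-style helper is assumed to do; it is not fairseq's actual multiprocessing_event_loop code.

# Illustration only: assumed behaviour of a gen_tuple_list-style helper.
# Each future is assumed to expose a blocking gen() that returns one tuple per replica.
def gen_tuple_list(futures):
    results = [f.gen() for f in futures]      # wait for every replica to finish
    return tuple(map(list, zip(*results)))    # transpose tuples into parallel lists

# e.g. with N replicas each returning (sample_size, logging_output, oom):
#   sample_sizes, logging_outputs, ooms_fwd = gen_tuple_list(futures)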
@@ -176,6 +179,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
 
         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
         logging_output['gnorm'] = grad_norms[0]  # log the gradient norm
+        logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd)
 
         return logging_output
@@ -186,34 +190,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self.model.train()
         self.optimizer.zero_grad()
 
-        if self._sample is None:
-            return 0, {}
-
-        # calculate loss and sample size
-        self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
-
-        return sample_size, logging_output
+        sample_size, logging_output, oom = 0, {}, False
+        if self._sample is not None:
+            try:
+                # calculate loss and sample size
+                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+            except RuntimeError as e:
+                if not eval and 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    self.loss = None
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                else:
+                    raise e
+
+        return sample_size, logging_output, oom
 
     def _async_backward_and_opt(self, rank, device_id, grad_denom):
+        oom = False
         if self.loss is not None:
-            # backward pass
-            self.loss.backward()
-
-            # get model parameters as a flattened (contiguous) tensor
-            flat_grads = self._flat_model_grads()
-
-            # all-reduce grads
-            nccl.all_reduce(flat_grads)
-
-            # normalize grads
-            if grad_denom != 0:
-                flat_grads.div_(grad_denom)
-
-            # clip grads
-            grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
-
-            # copy reduced grads back
-            self._set_model_grads_(flat_grads)
+            try:
+                # backward pass
+                self.loss.backward()
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                    self.optimizer.zero_grad()
+                else:
+                    raise e
+
+        # all-reduce grads and rescale by grad_denom
+        self._all_reduce_and_rescale_grads(grad_denom)
+
+        # clip grads
+        grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
 
         # take an optimization step
         self.optimizer.step()
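The recovery pattern in this hunk is: catch RuntimeError, check the message for 'out of memory', warn, mark the batch as skipped, drop the loss and any partial gradients, and (where available) call torch.cuda.empty_cache() so cached blocks are returned to the driver. A minimal single-GPU sketch of the same pattern, using hypothetical model, criterion, and optimizer objects rather than this trainer's state, might look like:

# Minimal sketch of the OOM-skip pattern, assuming a single-GPU training loop
# with hypothetical `model`, `criterion`, and `optimizer` objects.
import torch

def train_step(model, criterion, optimizer, batch):
    optimizer.zero_grad()
    try:
        loss = criterion(model(batch['input']), batch['target'])
        loss.backward()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print('| WARNING: ran out of memory, skipping batch')
            optimizer.zero_grad()          # drop any partial gradients
            if hasattr(torch.cuda, 'empty_cache'):
                torch.cuda.empty_cache()   # return cached blocks to the driver
            return None                    # signal the caller that the batch was skipped
        raise                              # non-OOM errors still propagate
    optimizer.step()
    return loss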
@@ -221,41 +235,49 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # reset loss
         self.loss = None
 
-        return grad_norm
+        return grad_norm, oom
 
-    def _model_grads(self):
-        return [p.grad for p in self.model.parameters() if p.requires_grad]
-
-    def _flat_model_grads(self):
-        grads = self._model_grads()
-        if not hasattr(self, '_flat_grads'):
-            num_params = sum(g.data.numel() for g in grads)
-            self._flat_grads = grads[0].data.new(num_params)
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            self._flat_grads[offset:offset+numel].copy_(grad)
-            offset += numel
-        return self._flat_grads
-
-    def _set_model_grads_(self, flat_grads):
-        grads = self._model_grads()
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            grad.copy_(flat_grads[offset:offset+numel])
-            offset += numel
-        assert offset == flat_grads.numel()
-
-    def _clip_grads_(self, flat_grads, clipv):
-        """nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
-        norm = flat_grads.norm()
-        if clipv > 0 and norm > clipv:
-            coef = max(norm, 1e-6) / clipv
-            flat_grads.div_(coef)
-        return norm
+    def _all_reduce_and_rescale_grads(self, grad_denom, buffer_size=10485760):
+        """All-reduce and rescale gradients in chunks of the specified size."""
+        grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
+        buffer_t = grads[0].new(math.ceil(buffer_size / grads[0].element_size())).zero_()
+        buffer = []
+
+        def all_reduce_buffer():
+            # copy grads into buffer_t
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                buffer_t[offset:offset+numel].copy_(g.view(-1))
+                offset += numel
+            # all-reduce and rescale
+            nccl.all_reduce(buffer_t[:offset])
+            buffer_t.div_(grad_denom)
+            # copy all-reduced buffer back into grads
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                g.view(-1).copy_(buffer_t[offset:offset+numel])
+                offset += numel
+
+        filled = 0
+        for g in grads:
+            sz = g.numel() * g.element_size()
+            if sz > buffer_size:
+                # grad is bigger than buffer, all-reduce and rescale directly
+                nccl.all_reduce(g)
+                g.div_(grad_denom)
+            elif filled + sz > buffer_size:
+                # buffer is full, all-reduce and replace buffer with grad
+                all_reduce_buffer()
+                buffer = [g]
+                filled = sz
+            else:
+                # add grad to buffer
+                buffer.append(g)
+                filled += sz
+        if len(buffer) > 0:
+            all_reduce_buffer()
 
     def valid_step(self, samples):
         """Do forward pass in parallel."""
@@ -263,10 +285,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, volatile=True)
 
         # forward pass
-        _sample_sizes, logging_outputs = Future.gen_tuple_list([
+        _sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward', eval=True)
             for rank in range(self.num_replicas)
         ])
+        assert sum(ooms_fwd) == 0
 
         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
@@ -314,4 +337,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         if sample is None:
             self._sample = None
         else:
+            if hasattr(torch.cuda, 'empty_cache'):
+                # clear the caching allocator if this is the largest sample we've seen
+                if sample['target'].size(0) > self._max_bsz_seen:
+                    self._max_bsz_seen = sample['target'].size(0)
+                    torch.cuda.empty_cache()
+
             self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
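This final hunk empties the caching allocator only when a sample arrives whose batch size exceeds anything seen so far, so cached blocks sized for smaller batches are released before the allocator has to satisfy the largest allocation; the hasattr check guards older PyTorch builds that lack torch.cuda.empty_cache(). A standalone sketch of the same guard, using a hypothetical module-level counter in place of the trainer's _max_bsz_seen attribute:

# Hedged sketch of the "empty the cache only when the batch size grows" guard.
import torch

_max_bsz_seen = 0  # hypothetical stand-in for the trainer's _max_bsz_seen attribute

def maybe_empty_cache(batch_size):
    """Release cached allocator blocks before processing the largest batch seen so far."""
    global _max_bsz_seen
    if batch_size > _max_bsz_seen:
        _max_bsz_seen = batch_size
        if hasattr(torch.cuda, 'empty_cache'):  # not available in very old PyTorch builds
            torch.cuda.empty_cache()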