OpenDAS / Fairseq / Commits
Commit a233fceb, authored Nov 17, 2017 by Myle Ott

Improve memory handling (recover from OOM and periodically empty caching allocator)

parent be274623
Showing 1 changed file with 89 additions and 60 deletions

fairseq/multiprocessing_trainer.py  (+89, −60)
@@ -11,10 +11,11 @@ Train a network on multiple GPUs using multiprocessing.
 """

 from itertools import cycle, islice
+import math

 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

-from fairseq import nccl, utils
+from fairseq import meters, nccl, utils
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG
@@ -74,6 +75,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()

+        self._max_bsz_seen = 0
+
     def _build_optimizer(self):
         if self.args.optimizer == 'adagrad':
             return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr,
@@ -161,14 +164,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)

         # forward pass
-        sample_sizes, logging_outputs = Future.gen_tuple_list([
+        sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward')
             for rank in range(self.num_replicas)
         ])

         # backward pass, all-reduce gradients and take an optimization step
         grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        grad_norms = Future.gen_list([
+        grad_norms, ooms_bwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
         ])
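Both dispatch sites now collect a per-replica OOM flag alongside the existing results, which is why the backward call switches from Future.gen_list to Future.gen_tuple_list: each replica returns a tuple, and the trainer needs those tuples transposed into parallel lists. A minimal sketch of that transposition, using plain values instead of fairseq's Future objects (the gen_tuple_list below is illustrative, not the real multiprocessing_event_loop implementation):

    # Illustrative only: transpose per-replica result tuples into parallel lists,
    # e.g. [(norm0, oom0), (norm1, oom1)] -> ([norm0, norm1], [oom0, oom1]).
    def gen_tuple_list(results):
        return tuple(map(list, zip(*results)))

    grad_norms, ooms_bwd = gen_tuple_list([(1.7, False), (2.1, True)])
    assert grad_norms == [1.7, 2.1]
    assert ooms_bwd == [False, True]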
@@ -176,6 +179,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
         logging_output['gnorm'] = grad_norms[0]  # log the gradient norm
+        logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd)

         return logging_output
@@ -186,34 +190,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self.model.train()
         self.optimizer.zero_grad()

-        if self._sample is None:
-            return 0, {}
-
-        # calculate loss and sample size
-        self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+        sample_size, logging_output, oom = 0, {}, False
+        if self._sample is not None:
+            try:
+                # calculate loss and sample size
+                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+            except RuntimeError as e:
+                if not eval and 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    self.loss = None
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                else:
+                    raise e

-        return sample_size, logging_output
+        return sample_size, logging_output, oom

     def _async_backward_and_opt(self, rank, device_id, grad_denom):
+        oom = False
         if self.loss is not None:
-            # backward pass
-            self.loss.backward()
-
-            # get model parameters as a flattened (contiguous) tensor
-            flat_grads = self._flat_model_grads()
-
-            # all-reduce grads
-            nccl.all_reduce(flat_grads)
-
-            # normalize grads
-            if grad_denom != 0:
-                flat_grads.div_(grad_denom)
-
-            # clip grads
-            grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
-
-            # copy reduced grads back
-            self._set_model_grads_(flat_grads)
+            try:
+                # backward pass
+                self.loss.backward()
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                    self.optimizer.zero_grad()
+                else:
+                    raise e
+
+        # all-reduce grads and rescale by grad_denom
+        self._all_reduce_and_rescale_grads(grad_denom)
+
+        # clip grads
+        grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)

         # take an optimization step
         self.optimizer.step()
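The recovery pattern is the same in both methods: catch RuntimeError, check the message for 'out of memory', warn, drop any partial state, and report the skipped batch to the caller. A condensed, self-contained sketch of that pattern under assumed names (model, criterion, optimizer, and sample are placeholders; the hasattr guard mirrors the commit, since torch.cuda.empty_cache only exists in newer PyTorch releases):

    import torch

    def try_train_step(model, criterion, optimizer, sample, device_id=0):
        """One optimization step that skips the batch on CUDA OOM (sketch)."""
        optimizer.zero_grad()
        try:
            loss = criterion(model(sample['input']), sample['target'])
            loss.backward()
            optimizer.step()
            return False  # no OOM
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise  # not an OOM: propagate unchanged
            print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
            optimizer.zero_grad()  # discard any partially accumulated gradients
            if hasattr(torch.cuda, 'empty_cache'):
                torch.cuda.empty_cache()  # return cached blocks to the GPU driver
            return True  # OOM: batch skipped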
@@ -221,41 +235,49 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # reset loss
         self.loss = None

-        return grad_norm
-
-    def _model_grads(self):
-        return [p.grad for p in self.model.parameters() if p.requires_grad]
-
-    def _flat_model_grads(self):
-        grads = self._model_grads()
-        if not hasattr(self, '_flat_grads'):
-            num_params = sum(g.data.numel() for g in grads)
-            self._flat_grads = grads[0].data.new(num_params)
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            self._flat_grads[offset:offset+numel].copy_(grad)
-            offset += numel
-        return self._flat_grads
-
-    def _set_model_grads_(self, flat_grads):
-        grads = self._model_grads()
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            grad.copy_(flat_grads[offset:offset+numel])
-            offset += numel
-        assert offset == flat_grads.numel()
-
-    def _clip_grads_(self, flat_grads, clipv):
-        """nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
-        norm = flat_grads.norm()
-        if clipv > 0 and norm > clipv:
-            coef = max(norm, 1e-6) / clipv
-            flat_grads.div_(coef)
-        return norm
+        return grad_norm, oom
+
+    def _all_reduce_and_rescale_grads(self, grad_denom, buffer_size=10485760):
+        """All-reduce and rescale gradients in chunks of the specified size."""
+        grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
+        buffer_t = grads[0].new(math.ceil(buffer_size / grads[0].element_size())).zero_()
+        buffer = []
+
+        def all_reduce_buffer():
+            # copy grads into buffer_t
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                buffer_t[offset:offset+numel].copy_(g.view(-1))
+                offset += numel
+            # all-reduce and rescale
+            nccl.all_reduce(buffer_t[:offset])
+            buffer_t.div_(grad_denom)
+            # copy all-reduced buffer back into grads
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                g.view(-1).copy_(buffer_t[offset:offset+numel])
+                offset += numel
+
+        filled = 0
+        for g in grads:
+            sz = g.numel() * g.element_size()
+            if sz > buffer_size:
+                # grad is bigger than buffer, all-reduce and rescale directly
+                nccl.all_reduce(g)
+                g.div_(grad_denom)
+            elif filled + sz > buffer_size:
+                # buffer is full, all-reduce and replace buffer with grad
+                all_reduce_buffer()
+                buffer = [g]
+                filled = sz
+            else:
+                # add grad to buffer
+                buffer.append(g)
+                filled += sz
+        if len(buffer) > 0:
+            all_reduce_buffer()

     def valid_step(self, samples):
         """Do forward pass in parallel."""
@@ -263,10 +285,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, volatile=True)

         # forward pass
-        _sample_sizes, logging_outputs = Future.gen_tuple_list([
+        _sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward', eval=True)
             for rank in range(self.num_replicas)
         ])
+        assert sum(ooms_fwd) == 0

         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
@@ -314,4 +337,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         if sample is None:
             self._sample = None
         else:
+            if hasattr(torch.cuda, 'empty_cache'):
+                # clear the caching allocator if this is the largest sample we've seen
+                if sample['target'].size(0) > self._max_bsz_seen:
+                    self._max_bsz_seen = sample['target'].size(0)
+                    torch.cuda.empty_cache()
+
             self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
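Finally, _scatter_samples empties the caching allocator only when a new largest batch arrives, since that is the moment previously cached blocks are most likely to be the wrong size and leave memory fragmented; emptying on every batch would throw away the cache's speed benefit. A standalone sketch of the same heuristic (the SampleLoader class and its names are hypothetical, not part of fairseq):

    import torch

    class SampleLoader(object):
        """Hypothetical helper mirroring the _max_bsz_seen heuristic."""

        def __init__(self):
            self.max_bsz_seen = 0  # largest batch size observed so far

        def prepare(self, sample):
            bsz = sample['target'].size(0)
            if bsz > self.max_bsz_seen:
                self.max_bsz_seen = bsz
                if hasattr(torch.cuda, 'empty_cache'):
                    # free cached blocks so larger activations can be allocated
                    torch.cuda.empty_cache()
            return sample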