OpenDAS / Fairseq / Commits / 7c07e87c
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e4b8f173b97731686e290b2eb98e7f5df2b1b322"
Commit 7c07e87c, authored May 23, 2018 by Myle Ott

All-reduce in FP16

parent fc312d28
Showing 1 changed file with 21 additions and 12 deletions (+21 / -12)
fairseq/fp16_trainer.py (view file @ 7c07e87c)
@@ -11,7 +11,7 @@ Train a network on multiple GPUs.
 import torch
 
-from fairseq import optim
+from fairseq import optim, utils
 from fairseq.meters import AverageMeter
 from fairseq.optim import lr_scheduler
 from fairseq.trainer import Trainer
@@ -105,8 +105,26 @@ class FP16Trainer(Trainer):
         # undo effect of dynamic loss scaling on gradients
         grad_denom *= self.scaler.loss_scale
 
-        # all-reduce and rescale gradients
-        grad_norm = super()._all_reduce_and_rescale(grad_denom)
+        if self.args.distributed_world_size > 1:
+            # flatten grads into a single buffer
+            flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
+
+            # scale gradients to avoid overflow in all-reduce
+            flat_grads.div_(self.args.distributed_world_size)
+            grad_denom /= self.args.distributed_world_size
+
+            # all-reduce flat grads
+            torch.distributed.all_reduce(flat_grads)
+
+            # copy grads back to FP32
+            self.fp32_params.grad.data.copy_(flat_grads)
+        else:
+            # single worker: copy grads directly to FP32
+            self._get_flat_grads(out=self.fp32_params.grad.data)
+
+        # rescale and clip grads
+        self.fp32_params.grad.data.div_(grad_denom)
+        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)
 
         # detect overflow and adjust loss scale
         overflow = DynamicLossScaler.has_overflow(grad_norm)
@@ -116,15 +134,6 @@ class FP16Trainer(Trainer):
 
         return grad_norm
 
-    def _get_flat_grads(self, out=None):
-        if out is None:
-            out = self.fp32_params.grad
-        return super()._get_flat_grads(out)
-
-    def _set_flat_grads(self, new_grads):
-        # no-op
-        assert new_grads.data_ptr() == self.fp32_params.grad.data.data_ptr()
-
     def _opt(self):
         # take an optimization step using the FP32 params and grads
         super()._opt()
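For context, the change replaces the generic FP32 all-reduce path with a single all-reduce over the flattened FP16 gradient buffer: the gradients are pre-divided by the world size so the FP16 sum cannot overflow, and only afterwards copied into the FP32 master gradients, unscaled by the loss-scale denominator, and clipped. The sketch below restates that pattern outside fairseq using plain torch.distributed; the function and argument names (sync_fp16_grads, fp16_params, fp32_grad_buffer, grad_denom, clip_norm) are illustrative and not part of the fairseq API.

# Minimal sketch of the FP16 all-reduce pattern from this commit (names are
# illustrative, not fairseq's API). Assumes torch.distributed is initialized
# and each parameter already has an FP16 .grad from backward().
import torch
import torch.distributed as dist


def sync_fp16_grads(fp16_params, fp32_grad_buffer, grad_denom, clip_norm, world_size):
    """All-reduce FP16 grads, then unscale and clip on the FP32 master copy.

    grad_denom is assumed to already include the dynamic loss scale (and any
    sample-count normalization), mirroring grad_denom in the diff above.
    """
    # flatten grads into a single contiguous buffer -> one all-reduce call
    flat_grads = torch.cat([p.grad.data.view(-1) for p in fp16_params])

    if world_size > 1:
        # pre-divide so the summed FP16 values stay in range during all-reduce
        flat_grads.div_(world_size)
        grad_denom /= world_size
        # all-reduce (sum) the flat FP16 buffer across workers
        dist.all_reduce(flat_grads)

    # copy into the FP32 master gradient, then rescale and clip in FP32
    fp32_grad_buffer.copy_(flat_grads)
    fp32_grad_buffer.div_(grad_denom)
    grad_norm = fp32_grad_buffer.norm().item()
    if clip_norm > 0 and grad_norm > clip_norm:
        fp32_grad_buffer.mul_(clip_norm / (grad_norm + 1e-6))
    return grad_norm

Note that the division by world_size before the all-reduce cancels against the matching division of grad_denom, so the final FP32 gradient is unchanged; its only purpose is to keep the reduced FP16 values from overflowing.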