OpenDAS / apex / Commits / 61b8a0fd

Commit 61b8a0fd, authored Apr 11, 2019 by Michael Carilli
Rough cut, control flow should work for scaleout testing
parent dda59354
Showing 2 changed files with 64 additions and 9 deletions:

  apex/amp/_process_optimizer.py   +59  -9
  apex/optimizers/fused_adam.py     +5  -0
apex/amp/_process_optimizer.py
...
@@ -217,6 +217,55 @@ def post_backward_no_master_weights(self, scaler):
     post_backward_models_are_masters(scaler, params, stashed_grads)
 
 
+def prepare_backward_with_master_weights_fused(self, scaler):
+    stash = self._amp_stash
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+
+def post_backward_with_master_weights_fused(self, scaler):
+    stash = self._amp_stash
+
+    stash.scale = scaler.loss_scale()
+    stash.grads = [[param.grad.data for param in group] for group in self.fp16_groups]
+    stash.output_params = [[param for param in group] for group in self.fp16_groups]
+
+    norm_groups = []
+    skip = False
+    for grad_group in stash.grads:
+        norm = multi_tensor_applier(stash.multi_tensor_l2norm,
+                                    stash.dummy_overflow_buf,
+                                    [grad_group])
+        # Still syncing here for now.
+        norm = float(norm)
+        norm_groups.append(norm)
+        if norm == float('inf') or norm == -float('inf') or norm != norm:
+            skip = True
+
+    if skip:
+        scaler._overflow_buf.fill_(1.)
+        scaler._has_overflow = True
+
+    self._amp_stash.grad_norms = norm_groups
+
+
+def prepare_backward_no_master_weights_fused(self, scaler):
+    stash = self._amp_stash
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+
+def post_backward_no_master_weights_fused(self, scaler):
+    stash = self._amp_stash
+    stash.scale = scaler.loss_scale()
+    stash.grads = None
+    stash.output_params = None
+    stash.grad_norms = None
+
+
 def _master_params_to_model_params(self):
     stash = self._amp_stash
     if multi_tensor_applier.available:
...
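In the fused post-backward path added above, the fp16 gradients are left scaled; only the loss scale, the gradient lists, and the per-group norms are recorded on _amp_stash. The "norm != norm" comparison is the usual NaN test, and together with the inf checks it decides whether the whole step should be skipped. A small self-contained sketch of that skip decision (plain Python, not the apex code path; should_skip_step is an illustrative name):

def should_skip_step(norm_groups):
    # A step is skipped if any group's gradient norm overflowed to +/-inf
    # or became NaN (NaN is the only float that compares unequal to itself).
    skip = False
    for norm in norm_groups:
        if norm == float('inf') or norm == -float('inf') or norm != norm:
            skip = True
    return skip

print(should_skip_step([3.2, 0.7]))            # False
print(should_skip_step([3.2, float('inf')]))   # True
print(should_skip_step([float('nan')]))        # True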
@@ -252,6 +301,7 @@ def _process_optimizer(optimizer, properties):
     if multi_tensor_applier.available:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
+        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
         optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);
 
     if properties.master_weights:
...
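The single added registration line exposes amp_C.multi_tensor_l2norm through the stash so the fused post-backward above can compute one L2 norm per gradient group in a single kernel launch. That reduction presumably leans on the usual blockwise identity, sketched here in plain PyTorch (illustrative only, not the amp_C kernel):

import torch

# ||[g1; g2; ...]||_2 equals the norm of the per-tensor norms, so a list of
# gradients can be reduced chunk by chunk and combined at the end.
grads = [torch.randn(10), torch.randn(3, 4), torch.randn(7)]

flat_norm = torch.cat([g.reshape(-1) for g in grads]).norm()
blocked_norm = torch.stack([g.norm() for g in grads]).norm()

assert torch.allclose(flat_norm, blocked_norm)
print(float(blocked_norm))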
@@ -261,16 +311,16 @@ def _process_optimizer(optimizer, properties):
         optimizer._master_params_to_model_params = types.MethodType(
             _master_params_to_model_params, optimizer)
 
-        if not isinstance(optimizer, FusedAdam):
-            old_step = optimizer.step
-            def new_step(self):
-                retval = old_step()
+        old_step = optimizer.step
+        def new_step(self):
+            retval = old_step()
+            if not isinstance(self, FusedAdam):
                 self._master_params_to_model_params()
-                # Clear the master grads that wouldn't be zeroed by model.zero_grad()
-                for param in self._amp_stash.all_fp32_from_fp16_params:
-                    param.grad = None
-                return retval
-            optimizer.step = types.MethodType(new_step, optimizer)
+            # Clear the master grads that wouldn't be zeroed by model.zero_grad()
+            for param in self._amp_stash.all_fp32_from_fp16_params:
+                param.grad = None
+            return retval
+        optimizer.step = types.MethodType(new_step, optimizer)
 
         old_zero_grad = optimizer.zero_grad
         def new_zero_grad(self):
...
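The reworked block above patches step on every optimizer instance and moves the FusedAdam test inside the wrapper, so master-grad clearing still runs for the fused path while _master_params_to_model_params is skipped for it. A minimal sketch of that instance-level patching pattern with types.MethodType, using a generic SGD optimizer rather than apex's actual wrapper:

import types
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

old_step = optimizer.step
def new_step(self):
    # Run the original update first, then do any post-step bookkeeping,
    # e.g. syncing master weights back to the model or clearing grads.
    retval = old_step()
    for group in self.param_groups:
        for p in group['params']:
            p.grad = None
    return retval

# Bind the wrapper to this optimizer instance only; other optimizers keep
# their original step.
optimizer.step = types.MethodType(new_step, optimizer)

loss = model(torch.randn(3, 4)).sum()
loss.backward()
optimizer.step()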
apex/optimizers/fused_adam.py

...
@@ -78,6 +78,11 @@ class FusedAdam(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
 
+        grads = self._amp_stash.grads
+        output_params = self._amp_stash.output_params
+        scale = self._amp_stash.scale
+        grad_norms = self._amp_stash.grad_norms
+
         if grads is None:
             grads_group = [None]*len(self.param_groups)
         # backward compatibility
...
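With the added lines, step takes its gradients, output params, loss scale, and grad norms from the _amp_stash that amp's fused post-backward hooks populated, rather than from explicit keyword arguments. A toy sketch of that hand-off, with a plain SGD-style update standing in for the fused Adam kernel (ToyStash and toy_step are illustrative names, not apex API):

import torch

class ToyStash:
    """Illustrative stand-in for optimizer._amp_stash."""
    def __init__(self, grads, scale):
        self.grads = grads          # still multiplied by the loss scale
        self.output_params = None
        self.scale = scale
        self.grad_norms = None

def toy_step(params, stash, lr=0.1):
    # Consume whatever the amp backward hooks stashed; unscale inside the
    # "kernel", as the fused optimizer does, instead of mutating p.grad.
    for p, g in zip(params, stash.grads):
        p.data -= lr * (g.float() / stash.scale)

scale = 1024.0
p = torch.tensor([1.0, 1.0])
scaled_grad = torch.tensor([0.5, -0.25]) * scale   # grad as produced under loss scaling
toy_step([p], ToyStash([scaled_grad], scale))
print(p)  # tensor([0.9500, 1.0250])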