OpenDAS / apex

Commit 00dbe4b4, authored May 02, 2019 by Michael Carilli
Parent commit: 72bce160

    test_fused_sgd.py passing
Changes: 2 changed files with 33 additions and 20 deletions

    apex/amp/_process_optimizer.py    +32  -19
    apex/optimizers/fused_sgd.py       +1   -1
apex/amp/_process_optimizer.py

@@ -75,6 +75,8 @@ def lazy_init_with_master_weights(self):
     for group in stash.fp32_from_fp32_groups:
         stash.all_fp32_from_fp32_params += group

+    # all_fp16_grad_stash is only needed for fused optimizers.
+    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
     # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
     stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
@@ -125,9 +127,7 @@ def post_backward_models_are_masters(scaler, params, stashed_grads):
 def prepare_backward_with_master_weights(self):
     stash = self._amp_stash

-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()

     for i, param in enumerate(stash.all_fp16_params):
         # Set up to leverage grad copy elision:

@@ -145,6 +145,8 @@ def prepare_backward_with_master_weights(self):
 def post_backward_with_master_weights(self, scaler):
     stash = self._amp_stash

+    self._amp_lazy_init()
+
     # This is a lot of python overhead...
     fp16_grads_needing_unscale = []
     new_fp32_grads = []

@@ -206,9 +208,7 @@ def lazy_init_no_master_weights(self):
 def prepare_backward_no_master_weights(self):
     stash = self._amp_stash

-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()

     for i, param in enumerate(stash.all_fp16_params):
         stash.all_fp16_grad_stash[i] = param.grad
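The prepare/post hooks above rely on per-parameter grad stashing: the existing .grad is set aside before a scaled backward pass and folded back in afterwards, so gradients accumulate correctly across multiple backward calls. The following is a minimal, hypothetical sketch of that pattern; the names prepare_backward, grad_stash, and loss_scale are illustrative assumptions, not apex's API.

# Simplified illustration of grad stashing (not apex's actual implementation).
def prepare_backward(params, grad_stash):
    # Save each param's current grad and clear it, so the upcoming backward
    # pass produces fresh (scaled) gradients in param.grad.
    for i, param in enumerate(params):
        grad_stash[i] = param.grad
        param.grad = None

def post_backward(params, grad_stash, loss_scale):
    # Unscale the fresh gradients and add back anything that was stashed.
    for i, param in enumerate(params):
        if param.grad is not None:
            param.grad.div_(loss_scale)
            if grad_stash[i] is not None:
                param.grad.add_(grad_stash[i])
        else:
            param.grad = grad_stash[i]
        grad_stash[i] = None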
@@ -224,6 +224,8 @@ def prepare_backward_no_master_weights(self):
 def post_backward_no_master_weights(self, scaler):
     stash = self._amp_stash

+    self._amp_lazy_init()
+
     split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                    (stash.all_fp32_params, stash.all_fp32_grad_stash))

@@ -238,13 +240,14 @@ def post_backward_no_master_weights(self, scaler):
 def prepare_backward_with_master_weights_FusedAdam(self):
     stash = self._amp_stash

-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()

 def post_backward_with_master_weights_FusedAdam(self, scaler):
     stash = self._amp_stash

+    self._amp_lazy_init()
+
     stash.scale = scaler.loss_scale()
     stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
     stash.output_params = [[param for param in group] for group in stash.fp16_groups]

@@ -271,13 +274,14 @@ def post_backward_with_master_weights_FusedAdam(self, scaler):
 def prepare_backward_no_master_weights_FusedAdam(self):
     stash = self._amp_stash

-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()

 def post_backward_no_master_weights_FusedAdam(self, scaler):
     stash = self._amp_stash

+    self._amp_lazy_init()
+
     stash.scale = scaler.loss_scale()
     stash.grads = None
     stash.output_params = None

@@ -296,9 +300,7 @@ def post_backward_no_master_weights_FusedAdam(self, scaler):
 def prepare_backward_with_master_weights_FusedSGD(self):
     stash = self._amp_stash

-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()

     for i, param in enumerate(stash.all_fp16_params):
         stash.all_fp16_grad_stash[i] = param.grad

@@ -314,6 +316,8 @@ def prepare_backward_with_master_weights_FusedSGD(self):
 def post_backward_with_master_weights_FusedSGD(self, scaler):
     stash = self._amp_stash

+    self._amp_lazy_init()
+
     split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                    (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))

@@ -329,6 +333,14 @@ def post_backward_no_master_weights_FusedSGD(self, scaler):
     post_backward_no_master_weights(self, scaler)

+def _amp_lazy_init(self):
+    stash = self._amp_stash
+
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
 def _process_optimizer(optimizer, properties):
     if hasattr(optimizer, "_amp_stash"):
         raise RuntimeError("A given optimizer should only be passed through amp.initialize once.")
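The new _amp_lazy_init helper is written as a free function taking self, in the same style as the other hooks in this file, and is attached to each optimizer instance with types.MethodType (see the last hunk below). A minimal sketch of that binding pattern, using stand-in classes rather than a real optimizer (FakeStash and FakeOptimizer are illustrative, not apex's code):

import types

def _amp_lazy_init(self):
    # Run the deferred initialization at most once per optimizer instance.
    stash = self._amp_stash
    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True

class FakeStash:
    lazy_init_called = False

class FakeOptimizer:
    # Stand-in for a torch optimizer that amp is patching.
    def __init__(self):
        self._amp_stash = FakeStash()
    def _lazy_init_maybe_master_weights(self):
        print("lazy init ran")

opt = FakeOptimizer()
# Bind the free function as an instance method, as _process_optimizer does.
opt._amp_lazy_init = types.MethodType(_amp_lazy_init, opt)
opt._amp_lazy_init()   # prints "lazy init ran"
opt._amp_lazy_init()   # no-op: the guard flag is already set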
@@ -342,7 +354,8 @@ def _process_optimizer(optimizer, properties):
     for name in ("_lazy_init_maybe_master_weights",
                  "_master_params_to_model_params",
                  "_prepare_amp_backward",
-                 "_post_amp_backward"):
+                 "_post_amp_backward",
+                 "_amp_lazy_init"):
         if hasattr(optimizer, name):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))

@@ -374,9 +387,7 @@ def _process_optimizer(optimizer, properties):
     old_zero_grad = optimizer.zero_grad
     def new_zero_grad(self):
         stash = self._amp_stash
-        if not stash.lazy_init_called:
-            self._lazy_init_maybe_master_weights()
-            stash.lazy_init_called = True
+        self._amp_lazy_init()
         # Zero the model grads.
         for param in stash.all_fp16_params:
             if param.grad is not None:

@@ -426,4 +437,6 @@ def _process_optimizer(optimizer, properties):
         optimizer._post_amp_backward = types.MethodType(
             post_backward_no_master_weights, optimizer)

+    optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)
+
     return optimizer
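With this change, every prepare/post backward hook and new_zero_grad funnels through self._amp_lazy_init(), so the deferred setup runs exactly once regardless of which hook fires first. For context, these patched hooks are exercised through the normal amp workflow; the fragment below is an illustrative training-loop sketch, assuming apex is installed with its CUDA extensions, a GPU is available, and opt_level "O2" is a reasonable choice for this setup.

import torch
from apex import amp
from apex.optimizers import FusedSGD

model = torch.nn.Linear(128, 10).cuda()
optimizer = FusedSGD(model.parameters(), lr=0.1)

# amp.initialize calls _process_optimizer, which attaches _amp_lazy_init
# and the _prepare/_post_amp_backward hooks to the optimizer.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

x = torch.randn(32, 128, device="cuda")
loss = model(x).sum()

# scale_loss invokes the prepare/post hooks around the backward pass;
# the first call also triggers the lazy initialization shown above.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()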
apex/optimizers/fused_sgd.py

@@ -76,7 +76,7 @@ class FusedSGD(Optimizer):
             raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')

     def __setstate__(self, state):
-        super(SGD, self).__setstate__(state)
+        super(FusedSGD, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('nesterov', False)
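This one-class fix in __setstate__ matters because FusedSGD inherits from Optimizer directly, not from torch.optim.SGD, so super(SGD, self) would raise a TypeError the first time the optimizer was unpickled. A minimal sketch with stand-in classes (illustrative only, not apex's definitions) shows the failure mode:

# Stand-in classes demonstrating why the super() target must be the class itself.
class Optimizer:
    def __setstate__(self, state):
        self.__dict__.update(state)

class SGD(Optimizer):
    pass

class FusedSGD(Optimizer):          # FusedSGD does not inherit from SGD
    def __setstate__(self, state):
        # super(SGD, self) would raise TypeError here, because self is not
        # an instance of SGD; naming the defining class is the correct form.
        super(FusedSGD, self).__setstate__(state)

opt = FusedSGD()
opt.__setstate__({"defaults": {"nesterov": False}})
print(opt.defaults)                 # {'nesterov': False}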