OpenDAS / fairscale

Commit a0042113 (unverified), parent 58e97aa6
Authored Oct 20, 2020 by Min Xu, committed via GitHub on Oct 20, 2020

[cleanup] mypy adascale (#149)

- close #143

3 changed files, with 39 additions and 22 deletions (+39 / -22)
Files changed:
  fairscale/optim/__init__.py          +1  -1
  fairscale/optim/adascale.py          +33 -20
  stubs/torch/autograd/__init__.pyi    +5  -1
fairscale/optim/__init__.py
@@ -11,6 +11,6 @@ try:
     from .adam import Adam, Precision
 except ImportError:  # pragma: no cover
     pass  # pragma: no cover
-from .adascale import AdaScale  # type: ignore
+from .adascale import AdaScale
 from .grad_scaler import GradScaler
 from .oss import OSS
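
With adascale.py now annotated (next file), the blanket "# type: ignore" on this import is no longer needed, so mypy can check call sites that import AdaScale through the package root. A minimal sketch of such a call site, assuming only that torch and fairscale are installed; the helper function is hypothetical and not part of this commit:

# Hypothetical call site; only the import path comes from this commit.
import torch
from fairscale.optim import AdaScale

def make_adascale(optimizer: torch.optim.Optimizer, world_size: int) -> AdaScale:
    # With the "# type: ignore" removed, mypy checks this call against
    # AdaScale's annotated __init__ instead of treating the class as Any.
    return AdaScale(optimizer, world_size=world_size)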
fairscale/optim/adascale.py
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright 2020 Petuum, Inc. All Rights Reserved.
 #
 # Redistribution and use in source and binary forms, with or without
...
@@ -26,9 +31,8 @@
 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 # POSSIBILITY OF SUCH DAMAGE.

-# type: ignore
-
 import functools
+from typing import Any, Dict, Optional

 import numpy as np
 from torch.autograd import Variable
...
@@ -67,11 +71,18 @@ class AdaScale(object):
     .. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf
     """

-    def __init__(self, optimizer, world_size=None, scale=None, smoothing=0.999, patch_optimizer=False):
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        world_size: Optional[int] = None,
+        scale: Optional[float] = None,
+        smoothing: float = 0.999,
+        patch_optimizer: bool = False,
+    ):
         self._optimizer = optimizer
         self._optimizer_step = optimizer.step
-        self._local_grad_sqr = None
-        self._world_size = world_size if world_size is not None else torch.distributed.get_world_size()
+        self._local_grad_sqr: Optional[torch.Tensor] = None
+        self._world_size: int = world_size if world_size is not None else torch.distributed.get_world_size()

         if self._world_size <= 1:
             raise RuntimeError("AdaScale does not support a single worker.")
...
@@ -96,11 +107,11 @@ class AdaScale(object):
         self._smoothing = smoothing

     @property
-    def state(self):
+    def state(self) -> Dict[str, np.ndarray]:
         return self._optimizer.state["adascale"]

     @property
-    def scale(self):
+    def scale(self) -> float:
         """
         The scaling factor of the current batch size, relative to the baseline
         batch size when training with a single worker. For example, if the
...
@@ -109,7 +120,7 @@ class AdaScale(object):
         """
         return self._scale

-    def set_scale(self, scale):
+    def set_scale(self, scale: float) -> None:
         """
         Set the scaling factor of the current batch size. It is up to the
         application to invoke this function to make sure that AdaScale's
...
@@ -120,7 +131,7 @@ class AdaScale(object):
         """
         self._scale = scale

-    def grad_sqr_avg(self):
+    def grad_sqr_avg(self) -> float:
         """
         Current estimate of the squared l2-norm of the true gradient (sigma
         squared in the AdaScale paper).
...
@@ -129,7 +140,7 @@ class AdaScale(object):
         """
         return np.sum(self.state["grad_sqr_avg"])

-    def grad_var_avg(self):
+    def grad_var_avg(self) -> float:
         """
         Current estimate of the trace of the covariance of the true gradient
         (mu squared in the AdaScale paper).
...
@@ -138,7 +149,7 @@ class AdaScale(object):
         """
         return np.sum(self.state["grad_var_avg"])

-    def gain(self, scale=None):
+    def gain(self, scale: Optional[float] = None) -> float:
         """
         Current estimate of the AdaScale gain ratio (r_t).
...
@@ -152,7 +163,7 @@ class AdaScale(object):
         sqr = self.grad_sqr_avg()
         return (var + sqr) / (var / scale + sqr)

-    def _update_avg(self, name, value, factor):
+    def _update_avg(self, name: str, value: float, factor: float) -> None:
         biased = self.state.get(name + "_biased", 0.0)
         unbias = self.state.get(name + "_unbias", 0.0)
         biased = factor * biased + (1.0 - factor) * value
...
@@ -161,7 +172,7 @@ class AdaScale(object):
         self.state[name + "_unbias"] = unbias
         self.state[name] = biased / unbias

-    def _backward_hook(self, idx, grad):
+    def _backward_hook(self, idx: int, grad: torch.Tensor) -> None:
         # This method should be invoked once for each parameter during the
         # backward pass, before gradients are synchronized between world_size.
         if self._local_grad_sqr is None:
...
@@ -170,7 +181,7 @@ class AdaScale(object):
         self._final_callback_queued = False
         Variable._execution_engine.queue_callback(self._queue_callback)

-    def _queue_callback(self):
+    def _queue_callback(self) -> None:
         # This method should be invoked after the entire backward pass. We want
         # to make sure self._final_callback is invoked once, only after all
         # gradients have been synchronized between each worker. However, the
...
@@ -183,10 +194,11 @@ class AdaScale(object):
         self._final_callback_queued = True
         Variable._execution_engine.queue_callback(self._final_callback)

-    def _final_callback(self):
+    def _final_callback(self) -> None:
         # This method should be invoked once for each backward pass, after
         # gradients have been synchronized between each worker.
         self._final_callback_queued = False
+        assert isinstance(self._local_grad_sqr, torch.Tensor)
         torch.distributed.all_reduce(self._local_grad_sqr / self._world_size)
         local_grad_sqr = self._local_grad_sqr.cpu().numpy()
         total_grad_sqr = np.array(
...
@@ -201,7 +213,7 @@ class AdaScale(object):
         self._update_avg("grad_var_avg", grad_var, theta)
         self._local_grad_sqr = None

-    def step(self, *args, **kwargs):
+    def step(self, *args: Any, **kwargs: Any) -> Optional[float]:
         """
         Run one optimizer step using Adascale. Essentially just invokes
         ``optimizer.step(*args, **kwargs)`` with a scaled learning rate.
...
@@ -216,17 +228,18 @@ class AdaScale(object):
             grad_var = float(self.state["grad_var_avg"][idx])
             gain = (grad_var + grad_sqr) / (grad_var / self._scale + grad_sqr)
             param_group["lr"] = gain * param_group["lr"]
-        self._optimizer_step(*args, **kwargs)
+        res = self._optimizer_step(*args, **kwargs)
         for lr, param_group in zip(initial_lr, self._optimizer.param_groups):
             param_group["lr"] = lr
+        return res

-    def patch_optimizer(self):
+    def patch_optimizer(self) -> None:
         """
         Monkey-patch the optimizer's step function with :meth:`AdaScale.step`.
         """

         @functools.wraps(self._optimizer.step)
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: Any, **kwargs: Any) -> Optional[float]:
             return self.step(*args, **kwargs)

-        self._optimizer.step = wrapper
+        setattr(self._optimizer, "step", wrapper)
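
These changes are type-only apart from two small tweaks visible above: step() now propagates the return value of the wrapped optimizer.step() call, and patch_optimizer() rebinds the method through setattr, which sidesteps mypy's "cannot assign to a method" complaint. gain() still computes (grad_var_avg + grad_sqr_avg) / (grad_var_avg / scale + grad_sqr_avg), the gain ratio r_t from the AdaScale paper. A rough usage sketch against the annotated constructor, assuming torch.distributed is already initialized with at least two workers; the model, data loader, and loss below are placeholders, not part of this commit:

# Hypothetical training-loop sketch; only the AdaScale calls
# (constructor, step()) follow the code in this commit.
from typing import Iterable, Tuple

import torch
import torch.nn.functional as F
from fairscale.optim import AdaScale

def train(
    model: torch.nn.Module,
    loader: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    epochs: int = 1,
) -> None:
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # world_size is read from torch.distributed when not given explicitly;
    # AdaScale raises RuntimeError if it ends up <= 1.
    ada = AdaScale(optimizer)
    for _ in range(epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(model(inputs), targets)
            loss.backward()  # AdaScale's backward hooks accumulate gradient statistics here
            ada.step()       # scales each param group's lr by the gain, steps, then restores lr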
stubs/torch/autograd/__init__.pyi
@@ -6,9 +6,13 @@ from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \
     set_grad_enabled as set_grad_enabled
 from .profiler import record_function

+# This is defined in CPP in PyTorch source
+class ImperativeEngine:
+    def queue_callback(self, callback: Callable[..., None]): ...
 # TODO make Variable and Function more precise
 class Variable:
     ...
+    _execution_engine: ImperativeEngine

 class Function:
     @staticmethod
     ...
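
The stub additions exist because adascale.py calls Variable._execution_engine.queue_callback(...) from a gradient hook; without a typed ImperativeEngine, mypy cannot resolve that attribute. A minimal, self-contained illustration of the mechanism being typed, assuming a recent PyTorch build (queue_callback is an internal autograd API and must be called while a backward pass is running, which is why it is registered from a tensor hook here); this mirrors how adascale.py uses the API but is not part of the commit:

# Standalone illustration of the autograd callback the stub describes.
import torch
from torch.autograd import Variable

def on_backward_done() -> None:
    print("all backward work for this pass has finished")

x = torch.randn(4, requires_grad=True)
y = (x * x).sum()

# The hook fires while backward is executing, so queue_callback is legal here;
# the queued callback then runs once the whole backward pass completes.
x.register_hook(lambda grad: Variable._execution_engine.queue_callback(on_backward_done))
y.backward()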