OpenDAS / apex
"...git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "4726d5be84307a094f0e557272f9e0530ae84395"
Commit 889d871b (unverified)

Create LARC.py

Authored May 22, 2018 by Raul Puri; committed via GitHub on May 22, 2018.
Parent: 2d5b71bd
Showing 1 changed file with 53 additions and 0 deletions.

apex/parallel/LARC.py (new file, mode 100644)
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

class LARC(object):
    def __init__(self, optimizer, trust_coefficient=0.02, epsilon=1e-8):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = epsilon

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    def step(self):
        with torch.no_grad():
            weight_decays = []
            for group in self.optim.param_groups:
                # absorb weight decay control from optimizer
                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)
                group['weight_decay'] = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)

                    # calculate adaptive lr + weight decay
                    adaptive_lr = (param_norm + self.eps) / (torch.norm(p.grad.data) + param_norm * weight_decay + self.eps)
                    p.grad.data += weight_decay * p.data
                    p.grad.data *= self.trust_coefficient * adaptive_lr

        self.optim.step()
        # return weight decay control to optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]
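For orientation, here is a minimal usage sketch of the wrapper added in this commit. The model, data, and optimizer names below are illustrative assumptions, not part of the commit, and the import assumes the file is reachable as apex.parallel.LARC. In step(), LARC rescales each parameter's gradient by roughly trust_coefficient * ||w|| / (||grad w|| + weight_decay * ||w||), with epsilon added for numerical stability, and then delegates to the wrapped optimizer's step().

import torch
from apex.parallel.LARC import LARC

# Illustrative model and base optimizer (hypothetical example, not from the commit).
model = torch.nn.Linear(128, 10)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)

# Wrap the base optimizer; LARC absorbs the group's weight decay and applies
# the per-parameter adaptive scaling before SGD performs its update.
optimizer = LARC(base_optimizer, trust_coefficient=0.02)

criterion = torch.nn.CrossEntropyLoss()
inputs = torch.randn(32, 128)
targets = torch.randint(0, 10, (32,))

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()  # gradients are rescaled in place, then base_optimizer.step() runs

Because LARC forwards state_dict(), load_state_dict(), zero_grad(), and add_param_group() to the wrapped optimizer, it can generally stand in wherever the base optimizer object was used.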