OpenDAS / apex

Commit f1e565f5
Modify fused_adam to take advantage of undo feature
Authored Mar 12, 2020 by Thor Johnsen
Parent: c659e564
Showing 1 changed file with 113 additions and 36 deletions (+113, -36).
apex/contrib/optimizers/fused_adam.py
@@ -61,6 +61,37 @@ class FusedAdam(torch.optim.Optimizer):
         super(FusedAdam, self).__init__(params, defaults)
         self.eps_mode = 0 if eps_inside_sqrt else 1
+        self._overflow_buf = torch.cuda.IntTensor([0])
+
+    @property
+    def has_overflow(self):
+        """Check if overflows were detected by any call to step(...) method.
+        Clears the overflow flag.
+        """
+        has_overflow = self._overflow_buf.item()
+        self._overflow_buf.zero_()
+        return has_overflow
+
+    @property
+    def peek_overflow(self):
+        """Check if overflows were detected by any call to step(...) method.
+        Does not clear overflow flag.
+        """
+        return self._overflow_buf.item()
+
+    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
+        """Strided check for overflow.
+        You can get status by calling has_overflow.
+        """
+        if start >= 0 and start < end:
+            out_p = output_params[start:end]
+        else:
+            out_p = output_params
+        fused_adam_cuda.strided_check_finite(self._overflow_buf,
+                out_p,
+                stride,
+                1 if clear else 0)
 
     def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
         """Performs a single optimization step.
...
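The only difference between the two properties above is whether reading the flag also clears it. Below is a CPU-only paraphrase of just that behaviour; the class is a stand-in written for this note (the CUDA IntTensor is swapped for a CPU tensor so the snippet runs anywhere), not apex code:

    import torch

    class OverflowFlagDemo:
        # Mirrors only the flag handling shown in the hunk above.
        def __init__(self):
            self._overflow_buf = torch.zeros(1, dtype=torch.int32)

        @property
        def has_overflow(self):
            # Read-and-clear: a second read returns 0 unless a new overflow lands.
            value = self._overflow_buf.item()
            self._overflow_buf.zero_()
            return value

        @property
        def peek_overflow(self):
            # Read-only: the flag survives for a later has_overflow check.
            return self._overflow_buf.item()

    d = OverflowFlagDemo()
    d._overflow_buf.fill_(1)      # pretend a kernel reported an overflow
    print(d.peek_overflow)        # 1 (flag untouched)
    print(d.has_overflow)         # 1 (flag cleared as a side effect)
    print(d.has_overflow)         # 0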
@@ -80,6 +111,13 @@ class FusedAdam(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
 
+        self._step(grads, output_params, scale, grad_norms, False)
+        self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
+        if self.peek_overflow:
+            self._step(grads, output_params, scale, grad_norms, True)
+        return loss
+
+    def _step(self, grads, output_params, scale, grad_norms, undo):
         if hasattr(self, "_amp_stash"):
             grads = self._amp_stash.grads
             output_params = self._amp_stash.output_params
...
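The restructured step() above always applies the update, then checks the written-out parameters for inf/nan, and reruns the same path in undo mode only if something overflowed. A self-contained toy of that control flow, assuming the overflow of interest shows up in a reduced-precision output copy (the fp16 tensor below plays the role of output_params; the helper is illustrative, not apex code):

    import torch

    def toy_step(master, grad, out_fp16, lr=0.1):
        # Apply the update on the fp32 master weights and mirror it into fp16.
        master -= lr * grad
        out_fp16.copy_(master)
        # Analogue of strided_check_finite: flag inf/nan in the output copy.
        overflow = not torch.isfinite(out_fp16).all().item()
        if overflow:
            # Analogue of the second _step(..., True) call: reverse the fp32
            # update. (The real undo kernel does not take out_p; refreshing
            # the fp16 copy here just keeps the toy consistent.)
            master += lr * grad
            out_fp16.copy_(master)
        return overflow

    master = torch.tensor([60000.0])
    out = torch.empty(1, dtype=torch.float16)
    print(toy_step(master, torch.tensor([-100000.0]), out))  # True: 70000.0 overflows fp16
    print(master, out)                                       # both back at 60000

Because the check runs after the update, the common case with no overflow pays for one fused kernel launch plus a cheap finiteness check, which is presumably the motivation for the speculative design.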
@@ -143,55 +181,94 @@ class FusedAdam(torch.optim.Optimizer):
                 state = self.state[p]
 
                 # State initialization
-                if len(state) == 0:
-                    state['step'] = 0
-                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                if undo:
+                    assert (len(state) > 0), "Adam undo called with empty optimizer state"
+                else:
+                    if len(state) == 0:
+                        state['step'] = 0
+                        # Exponential moving average of gradient values
+                        state['exp_avg'] = torch.zeros_like(p.data)
+                        # Exponential moving average of squared gradient values
+                        state['exp_avg_sq'] = torch.zeros_like(p.data)
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
 
-                state['step'] += 1
+                if undo:
+                    step = state['step']
+                    state['step'] -= 1
+                else:
+                    state['step'] += 1
+                    step = state['step']
 
-                out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param
+                if not undo:
+                    out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param
                 if self._use_multi_tensor:
                     pl = [p.data, exp_avg, exp_avg_sq, grad]
-                    if output_param is not None:
+                    if not undo and output_param is not None:
                         pl.append(out_p)
                     for tl, t in zip(tensorlists, pl):
                         tl.append(t)
                 else:
-                    fused_adam_cuda.adam(p.data,
-                                         out_p,
-                                         exp_avg,
-                                         exp_avg_sq,
-                                         grad,
-                                         group['lr'],
-                                         beta1,
-                                         beta2,
-                                         group['eps'],
-                                         combined_scale,
-                                         state['step'],
-                                         self.eps_mode,
-                                         bias_correction,
-                                         group['weight_decay'])
+                    if undo:
+                        fused_adam_cuda.adam_undo(p.data,
+                                                  exp_avg,
+                                                  exp_avg_sq,
+                                                  grad,
+                                                  group['lr'],
+                                                  beta1,
+                                                  beta2,
+                                                  group['eps'],
+                                                  combined_scale,
+                                                  step,
+                                                  self.eps_mode,
+                                                  bias_correction,
+                                                  group['weight_decay'])
+                    else:
+                        fused_adam_cuda.adam(p.data,
+                                             out_p,
+                                             exp_avg,
+                                             exp_avg_sq,
+                                             grad,
+                                             group['lr'],
+                                             beta1,
+                                             beta2,
+                                             group['eps'],
+                                             combined_scale,
+                                             step,
+                                             self.eps_mode,
+                                             bias_correction,
+                                             group['weight_decay'])
 
             if self._use_multi_tensor:
-                multi_tensor_applier(
-                    fused_adam_cuda.adam_mt,
-                    self._overflow_buf,
-                    tensorlists,
-                    group['lr'],
-                    beta1,
-                    beta2,
-                    group['eps'],
-                    combined_scale,
-                    state['step'],
-                    self.eps_mode,
-                    bias_correction,
-                    group['weight_decay'])
+                if undo:
+                    multi_tensor_applier(
+                        fused_adam_cuda.adam_undo_mt,
+                        self._overflow_buf,
+                        tensorlists,
+                        group['lr'],
+                        beta1,
+                        beta2,
+                        group['eps'],
+                        combined_scale,
+                        step,
+                        self.eps_mode,
+                        bias_correction,
+                        group['weight_decay'])
+                else:
+                    multi_tensor_applier(
+                        fused_adam_cuda.adam_mt,
+                        self._overflow_buf,
+                        tensorlists,
+                        group['lr'],
+                        beta1,
+                        beta2,
+                        group['eps'],
+                        combined_scale,
+                        step,
+                        self.eps_mode,
+                        bias_correction,
+                        group['weight_decay'])
 
         return loss
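For reference, here is a pure-Python sketch of what an Adam "undo" plausibly has to do: add the just-taken step back onto the parameters using the updated moments, then invert the two moment updates. This is derived from the standard Adam update, not from the fused_adam_cuda.adam_undo kernel, and it ignores eps_mode, bias correction, gradient scaling and weight decay for brevity:

    import torch

    def adam_apply(p, m, v, g, lr, beta1, beta2, eps):
        # Plain Adam update (no bias correction, scaling or weight decay).
        m.mul_(beta1).add_(g, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
        p.sub_(lr * m / (v.sqrt() + eps))

    def adam_undo(p, m, v, g, lr, beta1, beta2, eps):
        # Reverse the three updates in the opposite order. The step that was
        # applied came from the *updated* moments, so it is reconstructed from
        # them before they themselves are rolled back.
        p.add_(lr * m / (v.sqrt() + eps))
        m.sub_(g, alpha=1 - beta1).div_(beta1)
        v.sub_(g * g * (1 - beta2)).div_(beta2)

    p = torch.tensor([1.0, 2.0])
    m = torch.zeros(2)
    v = torch.zeros(2)
    g = torch.tensor([0.3, -0.1])
    p0, m0, v0 = p.clone(), m.clone(), v.clone()

    adam_apply(p, m, v, g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8)
    adam_undo(p, m, v, g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8)
    print(torch.allclose(p, p0), torch.allclose(m, m0), torch.allclose(v, v0))  # True True True

Reversing the parameter update before the moments matters: the step being retracted was computed from the updated exp_avg/exp_avg_sq, so the rollback must reconstruct it from those same values, which also matches the order of arguments passed to adam_undo in the diff above. The rollback restores state up to floating-point round-off rather than bit-exactly, which is sufficient for treating an overflowed step as skipped.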