OpenDAS / apex
"vscode:/vscode.git/clone" did not exist on "5c9e1e285e50c7be6cbcec04c47b4f0b929ede85"
Commit 9bb71066 authored May 06, 2020 by Thor Johnsen
Revert regular contrib fused adam optimizer
parent 7e3536dd
Showing 1 changed file with 22 additions and 98 deletions
apex/contrib/optimizers/fused_adam.py (+22, -98)
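Note on what is being reverted: the hunks below strip the experimental undo path out of the contrib FusedAdam — the allow_undo flag on step(), the has_overflow / peek_overflow / strided_check_finite helpers, and the fused_adam_cuda.adam_undo / adam_undo_mt call sites that rolled the optimizer's internal buffers back when a gradient overflow was detected. The snippet below is a small, self-contained toy sketch of that apply-then-undo-on-overflow pattern; it does not use apex at all, and names such as toy_adam_step and toy_adam_undo are hypothetical, not part of the library.

import torch

def toy_adam_step(p, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Plain (unfused) Adam update on one fp32 tensor; mutates p and state in place.
    state['step'] += 1
    state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
    state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = state['exp_avg_sq'].sqrt() + eps
    p.add_(state['exp_avg'] / denom, alpha=-lr)

def toy_adam_undo(p, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Roll p and the moments back to their pre-step values (restore p first,
    # because the weight update was computed from the *post-step* moments).
    denom = state['exp_avg_sq'].sqrt() + eps
    p.add_(state['exp_avg'] / denom, alpha=lr)
    state['exp_avg_sq'].sub_(grad * grad, alpha=1 - beta2).div_(beta2)
    state['exp_avg'].sub_(grad, alpha=1 - beta1).div_(beta1)
    state['step'] -= 1

# Speculative step: apply the update, then undo it if the gradient overflowed.
# (apex tracked overflow with a CUDA-side flag instead of torch.isfinite.)
p = torch.randn(4)
grad = torch.randn(4)
state = {'step': 0, 'exp_avg': torch.zeros_like(p), 'exp_avg_sq': torch.zeros_like(p)}

toy_adam_step(p, grad, state)
if not torch.isfinite(grad).all():
    toy_adam_undo(p, grad, state)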
...
@@ -61,38 +61,7 @@ class FusedAdam(torch.optim.Optimizer):
         super(FusedAdam, self).__init__(params, defaults)
         self.eps_mode = 0 if eps_inside_sqrt else 1
-        self._overflow_buf = torch.cuda.IntTensor([0])
-
-    @property
-    def has_overflow(self):
-        """Check if overflows were detected by any call to step(...) method.
-        Clears the overflow flag.
-        """
-        has_overflow = self._overflow_buf.item()
-        self._overflow_buf.zero_()
-        return has_overflow
-
-    @property
-    def peek_overflow(self):
-        """Check if overflows were detected by any call to step(...) method.
-        Does not clear overflow flag.
-        """
-        return self._overflow_buf.item()
-
-    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
-        """Strided check for overflow.
-        You can get status by calling has_overflow.
-        """
-        if start >= 0 and start < end:
-            out_p = output_params[start:end]
-        else:
-            out_p = output_params
-        fused_adam_cuda.strided_check_finite(self._overflow_buf, out_p, stride, 1 if clear else 0)

-    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None, allow_undo=False):
+    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
         """Performs a single optimization step.

         Arguments:
...
@@ -106,22 +75,11 @@ class FusedAdam(torch.optim.Optimizer):
                 updated weights. Have to be of same type as gradients. (default: None)
             scale (float, optional): factor to divide gradient tensor values
                 by before applying to weights. (default: 1)
-            allow_undo (bool, optional): allow use of undo feature. Internal buffers
-                will be restored to pre-step state if overflow is detected in gradient.
         """
         loss = None
         if closure is not None:
             loss = closure()
-        self._step(grads, output_params, scale, grad_norms, allow_undo, False)
-        if allow_undo and self.peek_overflow:
-            self._step(grads, output_params, scale, grad_norms, False, True)
-        return loss
-
-    def _step(self, grads, output_params, scale, grad_norms, check_overflow, undo):
-        if check_overflow:
-            modified_params = []
+
         if hasattr(self, "_amp_stash"):
             grads = self._amp_stash.grads
             output_params = self._amp_stash.output_params
...
@@ -172,6 +130,7 @@ class FusedAdam(torch.optim.Optimizer):
                     tensorlists = [[],[],[],[],[]]
                 else:
                     tensorlists = [[],[],[],[]]
+                tensordevice = None

             for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                 #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
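Aside on the multi-tensor path: when use_mt is enabled, the loop above only collects each parameter's tensors into the parallel lists in tensorlists (all on one device, hence the tensordevice bookkeeping), and a single fused kernel is later launched over the whole batch through apex's multi_tensor_applier. Below is a rough illustration of that dispatch mechanism, using the stock amp_C.multi_tensor_scale kernel rather than the adam_mt / adam_undo_mt kernels from this file; it assumes apex was built with its CUDA extensions and that a GPU is available.

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# A few parameter-sized gradients of different shapes, all on the same device.
grads = [torch.randn(n, device='cuda') for n in (1024, 4096, 333)]
scaled = [torch.empty_like(g) for g in grads]

overflow_buf = torch.cuda.IntTensor([0])   # the kernel sets this flag on inf/nan
tensorlists = [grads, scaled]              # one list per kernel operand, kept in lockstep

# One kernel launch scales every tensor in the batch by 1/128.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, tensorlists, 1.0 / 128)
print('overflow detected:', overflow_buf.item() != 0)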
...
@@ -185,53 +144,34 @@ class FusedAdam(torch.optim.Optimizer):
                 state = self.state[p]
                 # State initialization
-                if undo:
-                    assert (len(state) > 0), "Adam undo called with empty optimizer state"
-                else:
-                    if len(state) == 0:
-                        state['step'] = 0
-                        # Exponential moving average of gradient values
-                        state['exp_avg'] = torch.zeros_like(p.data)
-                        # Exponential moving average of squared gradient values
-                        state['exp_avg_sq'] = torch.zeros_like(p.data)
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
-                if undo:
-                    step = state['step']
-                    state['step'] -= 1
-                else:
-                    state['step'] += 1
-                    step = state['step']
-                if not undo:
-                    out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param
-                    if check_overflow:
-                        modified_params.append(out_p)
+                state['step'] += 1
+                out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param
                 if self._use_multi_tensor:
                     pl = [p.data, exp_avg, exp_avg_sq, grad]
-                    if not undo and output_param is not None:
+                    if output_param is not None:
                         pl.append(out_p)
                     for tl, t in zip(tensorlists, pl):
                         tl.append(t)
+                    if tensordevice is None:
+                        tensordevice = p.device
+                    elif tensordevice != p.device:
+                        raise RuntimeError('FusedAdam does not support use_mt with tensors on multiple device')
                 else:
-                    if undo:
-                        with torch.cuda.device(p.device):
-                            fused_adam_cuda.adam_undo(p.data,
-                                                      exp_avg,
-                                                      exp_avg_sq,
-                                                      grad,
-                                                      group['lr'],
-                                                      beta1,
-                                                      beta2,
-                                                      group['eps'],
-                                                      combined_scale,
-                                                      step,
-                                                      self.eps_mode,
-                                                      bias_correction,
-                                                      group['weight_decay'])
-                    else:
-                        fused_adam_cuda.adam(p.data,
-                                             out_p,
-                                             exp_avg,
+                    fused_adam_cuda.adam(p.data,
+                                         out_p,
+                                         exp_avg,
...
@@ -242,27 +182,13 @@ class FusedAdam(torch.optim.Optimizer):
-                                             beta2,
-                                             group['eps'],
-                                             combined_scale,
-                                             step,
-                                             self.eps_mode,
-                                             bias_correction,
-                                             group['weight_decay'])
+                                         beta2,
+                                         group['eps'],
+                                         combined_scale,
+                                         state['step'],
+                                         self.eps_mode,
+                                         bias_correction,
+                                         group['weight_decay'])

             if self._use_multi_tensor:
-                if undo:
-                    with torch.cuda.device(tensordevice):
-                        multi_tensor_applier(fused_adam_cuda.adam_undo_mt,
-                                             self._overflow_buf,
-                                             tensorlists,
-                                             group['lr'],
-                                             beta1,
-                                             beta2,
-                                             group['eps'],
-                                             combined_scale,
-                                             step,
-                                             self.eps_mode,
-                                             bias_correction,
-                                             group['weight_decay'])
-                else:
-                    multi_tensor_applier(fused_adam_cuda.adam_mt,
-                                         self._overflow_buf,
+                multi_tensor_applier(fused_adam_cuda.adam_mt,
+                                     self._overflow_buf,
...
@@ -272,11 +198,9 @@ class FusedAdam(torch.optim.Optimizer):
-                                         beta2,
-                                         group['eps'],
-                                         combined_scale,
-                                         step,
-                                         self.eps_mode,
-                                         bias_correction,
-                                         group['weight_decay'])
+                                     beta2,
+                                     group['eps'],
+                                     combined_scale,
+                                     state['step'],
+                                     self.eps_mode,
+                                     bias_correction,
+                                     group['weight_decay'])

-        if check_overflow:
-            for i, out_p in enumerate(modified_params):
-                self.strided_check_finite(out_p, stride=out_p.numel(), start=0, end=out_p.numel(), clear=True if i == 0 else False)
+        return loss
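For reference, the scalar arguments passed to the CUDA kernels above (group['lr'], beta1, beta2, group['eps'], combined_scale, step or state['step'], self.eps_mode, bias_correction, group['weight_decay']) map onto an ordinary Adam update. The sketch below is a rough eager-mode equivalent of what fused_adam_cuda.adam is believed to compute for a single fp32 tensor; it is illustrative only (the real kernel is fused, optionally writes an fp16 copy via out_p, and records overflows in _overflow_buf), and adam_reference is a hypothetical name, not part of apex.

import math
import torch

def adam_reference(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps,
                   combined_scale, step, eps_inside_sqrt, bias_correction, weight_decay):
    # "scale (float, optional): factor to divide gradient tensor values by" -> unscale first.
    grad = grad / combined_scale

    # Standard Adam moment updates.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Fold bias correction into the step size, as fused Adam kernels typically do.
    if bias_correction:
        step_size = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
    else:
        step_size = lr

    # eps_mode = 0 puts eps inside the sqrt, eps_mode = 1 outside (see __init__ above).
    if eps_inside_sqrt:
        denom = (exp_avg_sq + eps).sqrt()
    else:
        denom = exp_avg_sq.sqrt() + eps

    update = exp_avg / denom
    if weight_decay != 0.0:
        update = update + weight_decay * p
    p.add_(update, alpha=-step_size)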