OpenDAS / apex · Commits

Commit 8599b854, authored Aug 12, 2019 by Deyu Fu

converge fused_sgd and sgd code(dtype support, fused kernel, wd_after)

parent adad5996
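For context, a minimal usage sketch of the converged optimizer after this commit. The model and tensors below are placeholders, it assumes apex was built with its CUDA extensions, and the import path simply follows this repo's apex/optimizers/sgd.py:

    import torch
    from apex.optimizers.sgd import SGD  # the file changed in this commit

    # fp16 parameters exercise the new dtype bucketing in step()
    model = torch.nn.Linear(1024, 1024).cuda().half()
    opt = SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4,
              wd_after_momentum=False,  # new arg: apply weight decay after the momentum update if True
              set_grad_none=True)       # new arg: zero_grad() drops grads instead of zeroing them

    out = model(torch.randn(8, 1024, device='cuda', dtype=torch.float16))
    out.float().sum().backward()
    opt.step()
    opt.zero_grad()  # sets p.grad = None because set_grad_none=True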
Changes: 1 changed file with 97 additions and 54 deletions

apex/optimizers/sgd.py  +97 −54
 import torch
 from torch.optim import Optimizer
-from amp_C import multi_tensor_axpby
 from apex.multi_tensor_apply import multi_tensor_applier


 class SGD(Optimizer):
...
...
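The top-level amp_C import is gone because __init__ now imports amp_C behind a multi_tensor_applier.available check (see the constructor hunk below). For readers following the step() rewrite, multi_tensor_axpby computes out = a*x + b*y elementwise across a list of tensor triples; a plain-PyTorch sketch of those semantics (axpby_reference is our illustrative name, not an apex function):

    import torch

    def axpby_reference(x_list, y_list, out_list, a, b):
        # out = a * x + b * y for each tensor triple; out may alias x or y
        # (e.g. [grad_list, param_list, grad_list] in the calls below), so
        # materialize the right-hand side before writing.
        for x, y, out in zip(x_list, y_list, out_list):
            out.copy_(a * x + b * y)

The trailing integer in the real calls (e.g. `1., weight_decay, 2`) appears to select which operand the kernel checks for inf/NaN; it does not affect the arithmetic sketched above.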
@@ -16,6 +15,8 @@ class SGD(Optimizer):
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         dampening (float, optional): dampening for momentum (default: 0)
         nesterov (bool, optional): enables Nesterov momentum (default: False)
+        set_grad_none (bool, optional): whether set grad to None when zero_grad()
+            method is called. (default: True)

     Example:
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
         >>> optimizer.zero_grad()
...
...
@@ -40,7 +41,8 @@ class SGD(Optimizer):
"""
def
__init__
(
self
,
params
,
lr
=
0.1
,
momentum
=
0.
,
dampening
=
0.
,
weight_decay
=
0.
,
nesterov
=
False
):
weight_decay
=
0.
,
nesterov
=
False
,
wd_after_momentum
=
False
,
set_grad_none
=
True
):
if
lr
<
0.0
:
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
if
momentum
<
0.0
:
...
...
@@ -50,6 +52,18 @@ class SGD(Optimizer):
         defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                         weight_decay=weight_decay, nesterov=nesterov)
+        self.wd_after_momentum = wd_after_momentum
+        self.set_grad_none = set_grad_none
+
+        if multi_tensor_applier.available:
+            import amp_C
+            # Skip buffer
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_axpby = amp_C.multi_tensor_axpby
+            self.multi_tensor_sgd = amp_C.multi_tensor_sgd
+        else:
+            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
+
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super(SGD, self).__init__(params, defaults)
...
...
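Every kernel launch in step() below goes through multi_tensor_applier, whose Python side is essentially a thin forwarding wrapper. A simplified sketch of the calling convention used here (a model of the helper, not the exact apex source):

    class MultiTensorApplySketch:
        """Simplified model of apex.multi_tensor_apply.multi_tensor_applier."""

        def __init__(self, chunk_size=2048 * 32):
            self.chunk_size = chunk_size  # max elements each kernel chunk processes

        def __call__(self, op, noop_flag_buf, tensor_lists, *args):
            # op: a fused kernel such as amp_C.multi_tensor_sgd
            # noop_flag_buf: 1-element int tensor; nonzero tells the kernel to skip work
            # tensor_lists: equal-length lists, e.g. [grad_list, param_list, momentum_list]
            # *args: per-op scalars (weight_decay, momentum, lr, ...)
            return op(self.chunk_size, noop_flag_buf, tensor_lists, *args)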
@@ -60,10 +74,29 @@ class SGD(Optimizer):
             group.setdefault('nesterov', False)

     def zero_grad(self):
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    p.grad.zero_()
+        if self.set_grad_none:
+            for group in self.param_groups:
+                for p in group['params']:
+                    p.grad = None
+        else:
+            super(SGD, self).zero_grad()
+
+    def get_momentums(self, params):
+        momentums = []
+        first_run = True
+        for p in params:
+            param_state = self.state[p]
+            # torch.optim.SGD initializes momentum in the main loop, we have
+            # to do it here, and track whether or not we've done so, so that
+            # momentum application can be skipped in the main kernel.
+            if 'momentum_buffer' not in param_state:
+                first_run = True
+                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
+                momentums.append(buf)
+            else:
+                first_run = False
+                momentums.append(param_state['momentum_buffer'])
+        return momentums, first_run

     def step(self, closure=None):
         """Performs a single optimization step.
...
...
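A note on the zero_grad() change above: with set_grad_none=True the optimizer frees gradient tensors instead of launching a fill kernel per tensor, and autograd re-allocates them on the next backward(). A toy illustration of the two paths (CPU tensors for portability):

    import torch

    p = torch.nn.Parameter(torch.randn(4))
    p.grad = torch.ones_like(p)

    # set_grad_none=True path: drop the tensor; no memset, and the memory is
    # returned to the allocator until the next backward() recreates p.grad.
    p.grad = None

    # set_grad_none=False path (stock Optimizer.zero_grad): keep the tensor
    # and zero it in place, paying one fill kernel per parameter.
    p.grad = torch.ones_like(p)
    p.grad.zero_()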
@@ -81,59 +114,69 @@ class SGD(Optimizer):
             dampening = group['dampening']
             nesterov = group['nesterov']

-            param_list, grad_list, momentum_list = [], [], []
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                # create lists for multi tensor apply
-                param_list.append(p.data)
-                grad_list.append(p.grad.data)
-                if momentum != 0:
-                    param_state = self.state[p]
-                    if 'momentum_buffer' not in param_state:
-                        buf = param_state['momentum_buffer'] = torch.clone(p.grad.data).detach()
-                        group['init'] = True
-                    else:
-                        buf = param_state['momentum_buffer']
-                        group['init'] = False
-                    momentum_list.append(buf)
-
-            if weight_decay != 0:
-                multi_tensor_applier(
-                    multi_tensor_axpby,
-                    torch.cuda.IntTensor([0]), #dummy_overflow_buf,
-                    [grad_list, param_list, grad_list],
-                    1., weight_decay, 2)
-
-            if momentum != 0:
-                if not group['init']:
-                    multi_tensor_applier(
-                        multi_tensor_axpby,
-                        torch.cuda.IntTensor([0]), #dummy_overflow_buf,
-                        [momentum_list, grad_list, momentum_list],
-                        momentum, 1. - dampening, 2)
-                if nesterov:
-                    multi_tensor_applier(
-                        multi_tensor_axpby,
-                        torch.cuda.IntTensor([0]), #dummy_overflow_buf,
-                        [grad_list, momentum_list, grad_list],
-                        1., momentum, 2)
-                else:
-                    grad_list = momentum_list
-
-            multi_tensor_applier(
-                multi_tensor_axpby,
-                torch.cuda.IntTensor([0]), #dummy_overflow_buf,
-                [param_list, grad_list, param_list],
-                1., -group['lr'], 2)
+            for group_dtype in [torch.float16, torch.float32]:
+                grad_list = [p.grad for p in group['params'] if (p.dtype == group_dtype and p.grad is not None)]
+                if len(grad_list) == 0:
+                    continue
+                param_list = [p for p in group['params'] if (p.dtype == group_dtype and p.grad is not None)]
+
+                if momentum != 0:
+                    momentum_list, first_run = self.get_momentums(param_list)
+                    multi_tensor_applier(
+                        self.multi_tensor_sgd,
+                        self._dummy_overflow_buf,
+                        [grad_list, param_list, momentum_list],
+                        weight_decay,
+                        momentum,
+                        dampening,
+                        group['lr'],
+                        nesterov,
+                        first_run,
+                        self.wd_after_momentum,
+                        1.0)
+                else:
+                    # show how to implement SGD using axpby, without writing new multi_tensor kernel
+                    # only enabled now in no momentum case, since it saves creating momentum for us
+                    # keep momentum != 0 code below for completeness
+                    if weight_decay != 0 and not self.wd_after_momentum:
+                        multi_tensor_applier(
+                            self.multi_tensor_axpby,
+                            self._dummy_overflow_buf,
+                            [grad_list, param_list, grad_list],
+                            1., weight_decay, 2)
+                    if momentum != 0:
+                        # always False
+                        if not first_run:
+                            multi_tensor_applier(
+                                self.multi_tensor_axpby,
+                                self._dummy_overflow_buf,
+                                [momentum_list, grad_list, momentum_list],
+                                momentum, 1. - dampening, 2)
+                        if nesterov:
+                            multi_tensor_applier(
+                                self.multi_tensor_axpby,
+                                self._dummy_overflow_buf,
+                                [grad_list, momentum_list, grad_list],
+                                1., momentum, 2)
+                        else:
+                            grad_list = momentum_list
+                    if weight_decay != 0 and self.wd_after_momentum:
+                        multi_tensor_applier(
+                            self.multi_tensor_axpby,
+                            self._dummy_overflow_buf,
+                            [grad_list, param_list, grad_list],
+                            1., weight_decay, 2)
+                    multi_tensor_applier(
+                        self.multi_tensor_axpby,
+                        self._dummy_overflow_buf,
+                        [param_list, grad_list, param_list],
+                        1., -group['lr'], 2)

         return loss
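To cross-check what the fused multi_tensor_sgd call computes per tensor, here is a single-tensor reference in plain PyTorch, reconstructed from torch.optim.SGD plus this commit's wd_after_momentum and first_run arguments (our reconstruction, not the apex kernel itself; the trailing 1.0 in the fused call is a gradient scale, fixed to 1 here):

    import torch

    def sgd_reference_step(p, grad, buf, lr, weight_decay, momentum,
                           dampening, nesterov, first_run, wd_after_momentum):
        # Default path: fold weight decay into the gradient before momentum.
        if weight_decay != 0 and not wd_after_momentum:
            grad = grad + weight_decay * p
        if momentum != 0:
            if first_run:
                buf.copy_(grad)  # fresh buffer: no dampening on the first step
            else:
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
            if nesterov:
                grad = grad + momentum * buf
            else:
                grad = buf
        # wd_after_momentum=True: decay is applied to the post-momentum update.
        if weight_decay != 0 and wd_after_momentum:
            grad = grad + weight_decay * p
        p.add_(grad, alpha=-lr)  # in-place parameter update, as the kernel does

    # e.g.: with momentum and first_run=True this reduces to p -= lr * grad,
    # while seeding buf with grad for the next step.
    p, g, buf = torch.randn(4), torch.randn(4), torch.zeros(4)
    sgd_reference_step(p, g, buf, lr=0.1, weight_decay=1e-4, momentum=0.9,
                       dampening=0., nesterov=False, first_run=True,
                       wd_after_momentum=False)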