Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
beeeed17
Commit
beeeed17
authored
Oct 01, 2021
by
Vishnu Banna
Browse files
optimization package pr comments
parent
4682f5c8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
51 deletions
+55
-51
official/vision/beta/projects/yolo/optimization/sgd_torch.py
official/vision/beta/projects/yolo/optimization/sgd_torch.py
+52
-48
official/vision/beta/projects/yolo/tasks/yolo.py
official/vision/beta/projects/yolo/tasks/yolo.py
+3
-3
No files found.
official/vision/beta/projects/yolo/optimization/sgd_torch.py
View file @
beeeed17
...
...
@@ -7,10 +7,10 @@ import tensorflow as tf
import
re
import
logging
__all__
=
[
'SGDTorch'
]
def
_var_key
(
var
):
try
:
from
keras.optimizer_v2.optimizer_v2
import
_var_key
except
:
def
_var_key
(
var
):
"""Key for representing a primary variable, for looking up slots.
In graph mode the name is derived from the var shared name.
In eager mode the name is derived from the var unique id.
...
...
@@ -113,6 +113,8 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
self
.
_wset
=
set
()
self
.
_bset
=
set
()
self
.
_oset
=
set
()
if
self
.
sim_torch
:
logging
.
info
(
f
"Pytorch SGD simulation: "
)
logging
.
info
(
f
"Weight Decay:
{
weight_decay
}
"
)
...
...
@@ -122,8 +124,15 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
def
set_other_lr
(
self
,
lr
):
self
.
_set_hyper
(
"other_learning_rate"
,
lr
)
def
_search
(
self
,
var
,
keys
):
"""Search all all keys for matches. Return True on match."""
for
r
in
keys
:
if
re
.
search
(
r
,
var
.
name
)
is
not
None
:
return
True
return
False
def
search_and_set_variable_groups
(
self
,
variables
):
"""Search all variable for matches
o
t each group.
"""Search all variable for matches
a
t each group.
Args:
variables: List[tf.Variable] from model.trainable_variables
...
...
@@ -132,27 +141,20 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
biases
=
[]
others
=
[]
def
search
(
var
,
keys
):
"""Search all all keys for matches. Return True on match."""
for
r
in
keys
:
if
re
.
search
(
r
,
var
.
name
)
is
not
None
:
return
True
return
False
for
var
in
variables
:
# search for weights
if
search
(
var
,
self
.
_weight_keys
):
if
self
.
search
(
var
,
self
.
_weight_keys
):
weights
.
append
(
var
)
continue
# search for biases
if
search
(
var
,
self
.
_bias_keys
):
if
self
.
search
(
var
,
self
.
_bias_keys
):
biases
.
append
(
var
)
continue
# if all searches fail, add to other group
others
.
append
(
var
)
self
.
set_variable_groups
(
weights
,
biases
,
others
)
return
return
weights
,
biases
,
others
def
set_variable_groups
(
self
,
weights
,
biases
,
others
):
"""Alterantive to search and set allowing user to manually set each group.
...
...
@@ -181,6 +183,21 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
self
.
_variables_set
=
True
return
def
_get_variable_group
(
self
,
var
,
coefficients
):
if
self
.
_variables_set
:
# check which groups hold which varaibles, preset.
if
(
_var_key
(
var
)
in
self
.
_wset
):
return
True
,
False
,
False
elif
(
_var_key
(
var
)
in
self
.
_bset
):
return
False
,
True
,
False
else
:
# search the variables at run time.
if
self
.
_search
(
var
,
self
.
_weight_keys
):
return
True
,
False
,
False
elif
self
.
_search
(
var
,
self
.
_bias_keys
):
return
False
,
True
,
False
return
False
,
False
,
True
def
_create_slots
(
self
,
var_list
):
"""Create a momentum variable for each variable."""
if
self
.
_momentum
:
...
...
@@ -189,11 +206,6 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
if
var
.
trainable
:
self
.
add_slot
(
var
,
"momentum"
)
if
not
self
.
_variables_set
:
# Fall back to automatically set the variables in case the user did not.
self
.
search_and_set_variable_groups
(
var_list
)
self
.
_variables_set
=
False
def
_get_momentum
(
self
,
iteration
):
"""Get the momentum value."""
momentum
=
self
.
_get_hyper
(
"momentum"
)
...
...
@@ -239,7 +251,7 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
return
apply_state
[(
var_device
,
var_dtype
)]
def
_apply_tf
(
self
,
grad
,
var
,
weight_decay
,
momentum
,
lr
):
"""Uses Tensorflow Optimizer with Weight decay SGDW."""
def
decay_op
(
var
,
learning_rate
,
wd
):
if
self
.
_weight_decay
and
wd
>
0
:
return
var
.
assign_sub
(
...
...
@@ -263,6 +275,7 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
var
=
var
.
handle
,
alpha
=
lr
,
delta
=
grad
,
use_locking
=
self
.
_use_locking
)
def
_apply
(
self
,
grad
,
var
,
weight_decay
,
momentum
,
lr
):
"""Uses Pytorch Optimizer with Weight decay SGDW."""
dparams
=
grad
groups
=
[]
...
...
@@ -288,19 +301,12 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
groups
.
append
(
weight_update
)
return
tf
.
group
(
*
groups
)
def
_get_vartype
(
self
,
var
,
coefficients
):
if
(
_var_key
(
var
)
in
self
.
_wset
):
return
True
,
False
,
False
elif
(
_var_key
(
var
)
in
self
.
_bset
):
return
False
,
True
,
False
return
False
,
False
,
True
def
_run_sgd
(
self
,
grad
,
var
,
apply_state
=
None
):
var_device
,
var_dtype
=
var
.
device
,
var
.
dtype
.
base_dtype
coefficients
=
((
apply_state
or
{}).
get
((
var_device
,
var_dtype
))
or
self
.
_fallback_apply_state
(
var_device
,
var_dtype
))
weights
,
bias
,
others
=
self
.
_get_var
type
(
var
,
coefficients
)
weights
,
bias
,
others
=
self
.
_get_var
iable_group
(
var
,
coefficients
)
weight_decay
=
tf
.
zeros_like
(
coefficients
[
"weight_decay"
])
lr
=
coefficients
[
"lr_t"
]
if
weights
:
...
...
@@ -314,8 +320,6 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
lr
=
coefficients
[
"other_lr_t"
]
momentum
=
coefficients
[
"momentum"
]
tf
.
print
(
lr
)
if
self
.
sim_torch
:
return
self
.
_apply
(
grad
,
var
,
weight_decay
,
momentum
,
lr
)
else
:
...
...
official/vision/beta/projects/yolo/tasks/yolo.py
View file @
beeeed17
...
...
@@ -388,9 +388,9 @@ class YoloTask(base_task.Task):
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
optimizer
.
set_bias_lr
(
opt_factory
.
get_bias_lr_schedule
(
self
.
_task_config
.
smart_bias_lr
))
weights
,
biases
,
others
=
self
.
_model
.
get_weight_groups
(
self
.
_model
.
trainable_variables
)
optimizer
.
set_variable_groups
(
weights
,
biases
,
others
)
#
weights, biases, others = self._model.get_weight_groups(
#
self._model.trainable_variables)
#
optimizer.set_variable_groups(weights, biases, others)
else
:
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
opt_factory
.
_use_ema
=
ema
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment