Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d1988e3e
Commit
d1988e3e
authored
Oct 01, 2021
by
Vishnu Banna
Browse files
optimization package pr comments
parent
beeeed17
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
8 deletions
+10
-8
official/vision/beta/projects/yolo/optimization/sgd_torch.py
official/vision/beta/projects/yolo/optimization/sgd_torch.py
+7
-5
official/vision/beta/projects/yolo/tasks/yolo.py
official/vision/beta/projects/yolo/tasks/yolo.py
+3
-3
No files found.
official/vision/beta/projects/yolo/optimization/sgd_torch.py
View file @
d1988e3e
...
@@ -126,9 +126,11 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
...
@@ -126,9 +126,11 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
def
_search
(
self
,
var
,
keys
):
def
_search
(
self
,
var
,
keys
):
"""Search all all keys for matches. Return True on match."""
"""Search all all keys for matches. Return True on match."""
for
r
in
keys
:
if
keys
is
not
None
:
if
re
.
search
(
r
,
var
.
name
)
is
not
None
:
# variable group is not ignored so search for the keys.
return
True
for
r
in
keys
:
if
re
.
search
(
r
,
var
.
name
)
is
not
None
:
return
True
return
False
return
False
def
search_and_set_variable_groups
(
self
,
variables
):
def
search_and_set_variable_groups
(
self
,
variables
):
...
@@ -143,11 +145,11 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
...
@@ -143,11 +145,11 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
for
var
in
variables
:
for
var
in
variables
:
# search for weights
# search for weights
if
self
.
search
(
var
,
self
.
_weight_keys
):
if
self
.
_
search
(
var
,
self
.
_weight_keys
):
weights
.
append
(
var
)
weights
.
append
(
var
)
continue
continue
# search for biases
# search for biases
if
self
.
search
(
var
,
self
.
_bias_keys
):
if
self
.
_
search
(
var
,
self
.
_bias_keys
):
biases
.
append
(
var
)
biases
.
append
(
var
)
continue
continue
# if all searches fail, add to other group
# if all searches fail, add to other group
...
...
official/vision/beta/projects/yolo/tasks/yolo.py
View file @
d1988e3e
...
@@ -388,9 +388,9 @@ class YoloTask(base_task.Task):
...
@@ -388,9 +388,9 @@ class YoloTask(base_task.Task):
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
optimizer
.
set_bias_lr
(
optimizer
.
set_bias_lr
(
opt_factory
.
get_bias_lr_schedule
(
self
.
_task_config
.
smart_bias_lr
))
opt_factory
.
get_bias_lr_schedule
(
self
.
_task_config
.
smart_bias_lr
))
#
weights, biases, others = self._model.get_weight_groups(
weights
,
biases
,
others
=
self
.
_model
.
get_weight_groups
(
#
self._model.trainable_variables)
self
.
_model
.
trainable_variables
)
#
optimizer.set_variable_groups(weights, biases, others)
optimizer
.
set_variable_groups
(
weights
,
biases
,
others
)
else
:
else
:
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
opt_factory
.
_use_ema
=
ema
opt_factory
.
_use_ema
=
ema
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment