Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c86a93db
Commit
c86a93db
authored
Oct 01, 2021
by
Vishnu Banna
Browse files
optimization package pr comments
parent
17f4ae11
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1 addition
and
29 deletions
+1
-29
official/vision/beta/projects/yolo/modeling/yolo_model.py
official/vision/beta/projects/yolo/modeling/yolo_model.py
+0
-26
official/vision/beta/projects/yolo/tasks/yolo.py
official/vision/beta/projects/yolo/tasks/yolo.py
+1
-3
No files found.
official/vision/beta/projects/yolo/modeling/yolo_model.py
View file @
c86a93db
...
@@ -86,32 +86,6 @@ class Yolo(tf.keras.Model):
...
@@ -86,32 +86,6 @@ class Yolo(tf.keras.Model):
def
from_config
(
cls
,
config
):
def
from_config
(
cls
,
config
):
return
cls
(
**
config
)
return
cls
(
**
config
)
def
get_weight_groups
(
self
,
train_vars
):
"""Sort the list of trainable variables into groups for optimization.
Args:
train_vars: a list of tf.Variables that need to get sorted into their
respective groups.
Returns:
weights: a list of tf.Variables for the weights.
bias: a list of tf.Variables for the bias.
other: a list of tf.Variables for the other operations.
"""
bias
=
[]
weights
=
[]
other
=
[]
for
var
in
train_vars
:
if
"bias"
in
var
.
name
:
bias
.
append
(
var
)
elif
"beta"
in
var
.
name
:
bias
.
append
(
var
)
elif
"kernel"
in
var
.
name
or
"weight"
in
var
.
name
:
weights
.
append
(
var
)
else
:
other
.
append
(
var
)
return
weights
,
bias
,
other
def
fuse
(
self
):
def
fuse
(
self
):
"""Fuses all Convolution and Batchnorm layers to get better latency."""
"""Fuses all Convolution and Batchnorm layers to get better latency."""
print
(
"Fusing Conv Batch Norm Layers."
)
print
(
"Fusing Conv Batch Norm Layers."
)
...
...
official/vision/beta/projects/yolo/tasks/yolo.py
View file @
c86a93db
...
@@ -388,9 +388,7 @@ class YoloTask(base_task.Task):
...
@@ -388,9 +388,7 @@ class YoloTask(base_task.Task):
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
optimizer
.
set_bias_lr
(
optimizer
.
set_bias_lr
(
opt_factory
.
get_bias_lr_schedule
(
self
.
_task_config
.
smart_bias_lr
))
opt_factory
.
get_bias_lr_schedule
(
self
.
_task_config
.
smart_bias_lr
))
weights
,
biases
,
others
=
self
.
_model
.
get_weight_groups
(
optimizer
.
search_and_set_variable_groups
(
self
.
_model
.
trainable_variables
)
self
.
_model
.
trainable_variables
)
optimizer
.
set_variable_groups
(
weights
,
biases
,
others
)
else
:
else
:
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
opt_factory
.
_use_ema
=
ema
opt_factory
.
_use_ema
=
ema
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment