Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
24ade5b8
Commit
24ade5b8
authored
Sep 09, 2021
by
Vishnu Banna
Browse files
yolo_model update
parent
e528aa76
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
16 deletions
+42
-16
official/vision/beta/projects/yolo/modeling/yolo_model.py
official/vision/beta/projects/yolo/modeling/yolo_model.py
+42
-16
No files found.
official/vision/beta/projects/yolo/modeling/yolo_model.py
View file @
24ade5b8
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
#
S
tatic base Yolo Models that do not require configuration
#
s
tatic base Yolo Models that do not require configuration
# similar to a backbone model id.
# similar to a backbone model id.
# this is done greatly simplify the model config
# this is done greatly simplify the model config
...
@@ -80,31 +80,31 @@ class Yolo(tf.keras.Model):
...
@@ -80,31 +80,31 @@ class Yolo(tf.keras.Model):
backbone
=
None
,
backbone
=
None
,
decoder
=
None
,
decoder
=
None
,
head
=
None
,
head
=
None
,
detection_generato
r
=
None
,
filte
r
=
None
,
**
kwargs
):
**
kwargs
):
"""Detection initialization function.
"""Detection initialization function.
Args:
Args:
backbone: `tf.keras.Model`
,
a backbone network.
backbone: `tf.keras.Model` a backbone network.
decoder: `tf.keras.Model`
,
a decoder network.
decoder: `tf.keras.Model` a decoder network.
head: `
Yolo
Head`, the
YOLO
head.
head: `
RetinaNet
Head`, the
RetinaNet
head.
detection_generator: `tf.keras.Model`,
the detection generator.
filter:
the detection generator.
**kwargs: keyword arguments to be passed.
**kwargs: keyword arguments to be passed.
"""
"""
super
().
__init__
(
**
kwargs
)
super
(
Yolo
,
self
).
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
self
.
_config_dict
=
{
"
backbone
"
:
backbone
,
'
backbone
'
:
backbone
,
"
decoder
"
:
decoder
,
'
decoder
'
:
decoder
,
"
head
"
:
head
,
'
head
'
:
head
,
"detection_generator"
:
detection_generato
r
'filter'
:
filte
r
}
}
# model components
# model components
self
.
_backbone
=
backbone
self
.
_backbone
=
backbone
self
.
_decoder
=
decoder
self
.
_decoder
=
decoder
self
.
_head
=
head
self
.
_head
=
head
self
.
_detection_generator
=
detection_generator
self
.
_filter
=
filter
return
def
call
(
self
,
inputs
,
training
=
False
):
def
call
(
self
,
inputs
,
training
=
False
):
maps
=
self
.
_backbone
(
inputs
)
maps
=
self
.
_backbone
(
inputs
)
...
@@ -114,7 +114,7 @@ class Yolo(tf.keras.Model):
...
@@ -114,7 +114,7 @@ class Yolo(tf.keras.Model):
return
{
"raw_output"
:
raw_predictions
}
return
{
"raw_output"
:
raw_predictions
}
else
:
else
:
# Post-processing.
# Post-processing.
predictions
=
self
.
_
detection_generato
r
(
raw_predictions
)
predictions
=
self
.
_
filte
r
(
raw_predictions
)
predictions
.
update
({
"raw_output"
:
raw_predictions
})
predictions
.
update
({
"raw_output"
:
raw_predictions
})
return
predictions
return
predictions
...
@@ -131,8 +131,8 @@ class Yolo(tf.keras.Model):
...
@@ -131,8 +131,8 @@ class Yolo(tf.keras.Model):
return
self
.
_head
return
self
.
_head
@
property
@
property
def
detection_generato
r
(
self
):
def
filte
r
(
self
):
return
self
.
_
detection_generato
r
return
self
.
_
filte
r
def
get_config
(
self
):
def
get_config
(
self
):
return
self
.
_config_dict
return
self
.
_config_dict
...
@@ -140,3 +140,29 @@ class Yolo(tf.keras.Model):
...
@@ -140,3 +140,29 @@ class Yolo(tf.keras.Model):
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
):
def
from_config
(
cls
,
config
):
return
cls
(
**
config
)
return
cls
(
**
config
)
def
get_weight_groups
(
self
,
train_vars
):
"""Sort the list of trainable variables into groups for optimization.
Args:
train_vars: a list of tf.Variables that need to get sorted into their
respective groups.
Returns:
weights: a list of tf.Variables for the weights.
bias: a list of tf.Variables for the bias.
other: a list of tf.Variables for the other operations.
"""
bias
=
[]
weights
=
[]
other
=
[]
for
var
in
train_vars
:
if
"bias"
in
var
.
name
:
bias
.
append
(
var
)
elif
"beta"
in
var
.
name
:
bias
.
append
(
var
)
elif
"kernel"
in
var
.
name
or
"weight"
in
var
.
name
:
weights
.
append
(
var
)
else
:
other
.
append
(
var
)
return
weights
,
bias
,
other
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment