ModelZoo / ResNet50_tensorflow

Commit a8ae1619
authored Apr 05, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 366939051
parent fab47e9e
Showing 3 changed files with 128 additions and 29 deletions (+128, -29)
official/vision/beta/modeling/factory.py               +9   -1
official/vision/beta/modeling/retinanet_model.py       +39  -0
official/vision/beta/modeling/retinanet_model_test.py  +80  -28
official/vision/beta/modeling/factory.py

@@ -238,7 +238,15 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
       use_batched_nms=generator_config.use_batched_nms)
   model = retinanet_model.RetinaNetModel(
-      backbone, decoder, head, detection_generator_obj)
+      backbone,
+      decoder,
+      head,
+      detection_generator_obj,
+      min_level=model_config.min_level,
+      max_level=model_config.max_level,
+      num_scales=model_config.anchor.num_scales,
+      aspect_ratios=model_config.anchor.aspect_ratios,
+      anchor_size=model_config.anchor.anchor_size)
   return model
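With this change the anchor settings defined on the task config flow from build_retinanet into the model itself. Below is a minimal sketch of how a caller might exercise the updated factory path; the config module path, the RetinaNet config defaults, and the optional l2_regularizer argument are assumptions that do not appear in this diff.

# Sketch only: build a RetinaNet through the updated factory so that the
# anchor settings on the config reach RetinaNetModel. Names outside this
# diff (configs module path, l2_regularizer) are assumptions.
import tensorflow as tf

from official.vision.beta.configs import retinanet as retinanet_cfg  # assumed path
from official.vision.beta.modeling import factory

model_config = retinanet_cfg.RetinaNet(num_classes=3)  # anchor settings live on model_config.anchor
input_specs = tf.keras.layers.InputSpec(shape=[None, 384, 384, 3])

model = factory.build_retinanet(
    input_specs=input_specs,
    model_config=model_config,
    l2_regularizer=None)  # assumed to be optional

# The anchor settings now travel with the model itself (see retinanet_model.py).
print(model.get_config()['anchor_size'])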
official/vision/beta/modeling/retinanet_model.py

@@ -13,10 +13,13 @@
 # limitations under the License.
 """RetinaNet."""
+from typing import List, Optional

 # Import libraries
 import tensorflow as tf

+from official.vision.beta.ops import anchor

 @tf.keras.utils.register_keras_serializable(package='Vision')
 class RetinaNetModel(tf.keras.Model):

@@ -27,6 +30,11 @@ class RetinaNetModel(tf.keras.Model):
                decoder,
                head,
                detection_generator,
+               min_level: Optional[int] = None,
+               max_level: Optional[int] = None,
+               num_scales: Optional[int] = None,
+               aspect_ratios: Optional[List[float]] = None,
+               anchor_size: Optional[float] = None,
                **kwargs):
     """Classification initialization function.
@@ -35,6 +43,17 @@ class RetinaNetModel(tf.keras.Model):
       decoder: `tf.keras.Model` a decoder network.
       head: `RetinaNetHead`, the RetinaNet head.
       detection_generator: the detection generator.
+      min_level: Minimum level in output feature maps.
+      max_level: Maximum level in output feature maps.
+      num_scales: A number representing intermediate scales added
+        on each level. For instance, num_scales=2 adds one additional
+        intermediate anchor scale [2^0, 2^0.5] on each level.
+      aspect_ratios: A list representing the aspect ratios of
+        anchors added on each level. Each number is the ratio of width to
+        height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
+        on each scale level.
+      anchor_size: A number representing the scale of size of the base
+        anchor relative to the feature stride 2^level.
       **kwargs: keyword arguments to be passed.
     """
     super(RetinaNetModel, self).__init__(**kwargs)
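The docstring above boils down to a simple relationship: each feature-map location carries num_scales * len(aspect_ratios) anchors, and the base anchor edge at pyramid level l is anchor_size * 2^l pixels. A small illustrative computation follows; the concrete values are arbitrary examples, not defaults taken from this code.

# Illustrative only: how the new constructor arguments translate into anchor
# counts and sizes; the numbers below are arbitrary example values.
num_scales = 3                   # octave scales 2^0, 2^(1/3), 2^(2/3) per level
aspect_ratios = [0.5, 1.0, 2.0]  # width-to-height ratios
anchor_size = 4.0                # base anchor edge = anchor_size * feature stride

anchors_per_location = num_scales * len(aspect_ratios)  # 3 * 3 = 9
for level in range(3, 8):        # e.g. min_level=3 .. max_level=7
  stride = 2 ** level
  print(f'level {level}: stride {stride}px, '
        f'base anchor {anchor_size * stride:.0f}px, '
        f'{anchors_per_location} anchors per location')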
@@ -43,6 +62,11 @@ class RetinaNetModel(tf.keras.Model):
         'decoder': decoder,
         'head': head,
         'detection_generator': detection_generator,
+        'min_level': min_level,
+        'max_level': max_level,
+        'num_scales': num_scales,
+        'aspect_ratios': aspect_ratios,
+        'anchor_size': anchor_size,
     }
     self._backbone = backbone
     self._decoder = decoder
@@ -105,6 +129,21 @@ class RetinaNetModel(tf.keras.Model):
       outputs.update({'att_outputs': raw_attributes})
       return outputs
     else:
+      # Generate anchor boxes for this batch if not provided.
+      if anchor_boxes is None:
+        _, image_height, image_width, _ = images.get_shape().as_list()
+        anchor_boxes = anchor.Anchor(
+            min_level=self._config_dict['min_level'],
+            max_level=self._config_dict['max_level'],
+            num_scales=self._config_dict['num_scales'],
+            aspect_ratios=self._config_dict['aspect_ratios'],
+            anchor_size=self._config_dict['anchor_size'],
+            image_size=(image_height, image_width)).multilevel_boxes
+        for l in anchor_boxes:
+          anchor_boxes[l] = tf.tile(
+              tf.expand_dims(anchor_boxes[l], axis=0),
+              [tf.shape(images)[0], 1, 1, 1])
+
       # Post-processing.
       final_results = self.detection_generator(
           raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes)
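The net effect on the call path: at inference time the caller no longer has to precompute and batch-tile the multilevel anchors, because passing anchor_boxes=None makes the model derive them from its own config and the input size. A sketch of the calling convention, assuming model is a RetinaNetModel built with the new anchor arguments (for example via the factory sketch above):

# Sketch only: inference without precomputed anchors. `model` is assumed to be
# a RetinaNetModel constructed with min_level/max_level/num_scales/
# aspect_ratios/anchor_size, e.g. via the factory sketch above.
import numpy as np
import tensorflow as tf

images = tf.convert_to_tensor(np.random.rand(2, 384, 384, 3), dtype=tf.float32)
image_shape = np.array([[384, 384], [384, 384]])  # per-image (height, width)

# Before this commit the caller had to build anchor.Anchor(...).multilevel_boxes
# and tile it to the batch size; now None triggers the in-model path above.
outputs = model(images, image_shape, anchor_boxes=None, training=False)
print(sorted(outputs.keys()))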
official/vision/beta/modeling/retinanet_model_test.py

@@ -33,37 +33,79 @@ from official.vision.beta.ops import anchor
 class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):

   @parameterized.parameters(
-      (3, 3, 7, 3, [1.0], 50, False, 256, 4, 256, 32244949),
+      {'use_separable_conv': True, 'build_anchor_boxes': True,
+       'is_training': False, 'has_att_heads': False},
+      {'use_separable_conv': False, 'build_anchor_boxes': True,
+       'is_training': False, 'has_att_heads': False},
+      {'use_separable_conv': False, 'build_anchor_boxes': False,
+       'is_training': False, 'has_att_heads': False},
+      {'use_separable_conv': False, 'build_anchor_boxes': False,
+       'is_training': True, 'has_att_heads': False},
+      {'use_separable_conv': False, 'build_anchor_boxes': True,
+       'is_training': True, 'has_att_heads': True},
+      {'use_separable_conv': False, 'build_anchor_boxes': True,
+       'is_training': False, 'has_att_heads': True},
   )
-  def test_num_params(self,
-                      num_classes,
-                      min_level,
-                      max_level,
-                      num_scales,
-                      aspect_ratios,
-                      resnet_model_id,
-                      use_separable_conv,
-                      fpn_num_filters,
-                      head_num_convs,
-                      head_num_filters,
-                      expected_num_params):
+  def test_build_model(self,
+                       use_separable_conv,
+                       build_anchor_boxes,
+                       is_training,
+                       has_att_heads):
+    num_classes = 3
+    min_level = 3
+    max_level = 7
+    num_scales = 3
+    aspect_ratios = [1.0]
+    anchor_size = 3
+    fpn_num_filters = 256
+    head_num_convs = 4
+    head_num_filters = 256
     num_anchors_per_location = num_scales * len(aspect_ratios)
     image_size = 384
     images = np.random.rand(2, image_size, image_size, 3)
     image_shape = np.array([[image_size, image_size], [image_size, image_size]])
-    anchor_boxes = anchor.Anchor(
-        min_level=min_level,
-        max_level=max_level,
-        num_scales=num_scales,
-        aspect_ratios=aspect_ratios,
-        anchor_size=3,
-        image_size=(image_size, image_size)).multilevel_boxes
-    for l in anchor_boxes:
-      anchor_boxes[l] = tf.tile(
-          tf.expand_dims(anchor_boxes[l], axis=0), [2, 1, 1, 1])
-    backbone = resnet.ResNet(model_id=resnet_model_id)
+    if build_anchor_boxes:
+      anchor_boxes = anchor.Anchor(
+          min_level=min_level,
+          max_level=max_level,
+          num_scales=num_scales,
+          aspect_ratios=aspect_ratios,
+          anchor_size=anchor_size,
+          image_size=(image_size, image_size)).multilevel_boxes
+      for l in anchor_boxes:
+        anchor_boxes[l] = tf.tile(
+            tf.expand_dims(anchor_boxes[l], axis=0), [2, 1, 1, 1])
+    else:
+      anchor_boxes = None
+    if has_att_heads:
+      attribute_heads = {'depth': ('regression', 1)}
+    else:
+      attribute_heads = None
+    backbone = resnet.ResNet(model_id=50)
     decoder = fpn.FPN(
         input_specs=backbone.output_specs,
         min_level=min_level,
@@ -74,6 +116,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
         min_level=min_level,
         max_level=max_level,
         num_classes=num_classes,
+        attribute_heads=attribute_heads,
         num_anchors_per_location=num_anchors_per_location,
         use_separable_conv=use_separable_conv,
         num_convs=head_num_convs,
@@ -84,10 +127,14 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
         backbone=backbone,
         decoder=decoder,
         head=head,
-        detection_generator=generator)
+        detection_generator=generator,
+        min_level=min_level,
+        max_level=max_level,
+        num_scales=num_scales,
+        aspect_ratios=aspect_ratios,
+        anchor_size=anchor_size)

-    _ = model(images, image_shape, anchor_boxes, training=True)
-    self.assertEqual(expected_num_params, model.count_params())
+    _ = model(images, image_shape, anchor_boxes, training=is_training)

   @combinations.generate(
       combinations.combine(
@@ -226,7 +273,12 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
         backbone=backbone,
         decoder=decoder,
         head=head,
-        detection_generator=generator)
+        detection_generator=generator,
+        min_level=min_level,
+        max_level=max_level,
+        num_scales=num_scales,
+        aspect_ratios=aspect_ratios,
+        anchor_size=3)

     config = model.get_config()
     new_model = retinanet_model.RetinaNetModel.from_config(config)
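One consequence of threading the anchor arguments through _config_dict is that they survive the Keras get_config / from_config round trip, which is what the last hunk exercises. A short sketch of that round trip, continuing the earlier sketches (model, images and image_shape as defined above):

# Sketch only: the anchor settings are now part of the serialized config, so a
# model rebuilt from its config can regenerate its own anchors. Reuses `model`,
# `images` and `image_shape` from the earlier sketches.
from official.vision.beta.modeling import retinanet_model

config = model.get_config()
print({k: config[k] for k in
       ('min_level', 'max_level', 'num_scales', 'aspect_ratios', 'anchor_size')})

new_model = retinanet_model.RetinaNetModel.from_config(config)
_ = new_model(images, image_shape, anchor_boxes=None, training=False)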