Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
33582301
Commit
33582301
authored
Oct 27, 2021
by
Xianzhi Du
Committed by
A. Unique TensorFlower
Oct 27, 2021
Browse files
Internal change
PiperOrigin-RevId: 405975992
parent
520ebe14
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
89 additions
and
35 deletions
+89
-35
official/vision/beta/configs/maskrcnn.py
official/vision/beta/configs/maskrcnn.py
+1
-0
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+1
-0
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+4
-2
official/vision/beta/modeling/layers/detection_generator.py
official/vision/beta/modeling/layers/detection_generator.py
+58
-20
official/vision/beta/modeling/layers/detection_generator_test.py
...l/vision/beta/modeling/layers/detection_generator_test.py
+19
-11
official/vision/beta/modeling/retinanet_model_test.py
official/vision/beta/modeling/retinanet_model_test.py
+6
-2
No files found.
official/vision/beta/configs/maskrcnn.py
View file @
33582301
...
@@ -133,6 +133,7 @@ class DetectionGenerator(hyperparams.Config):
...
@@ -133,6 +133,7 @@ class DetectionGenerator(hyperparams.Config):
max_num_detections
:
int
=
100
max_num_detections
:
int
=
100
nms_version
:
str
=
'v2'
# `v2`, `v1`, `batched`
nms_version
:
str
=
'v2'
# `v2`, `v1`, `batched`
use_cpu_nms
:
bool
=
False
use_cpu_nms
:
bool
=
False
soft_nms_sigma
:
Optional
[
float
]
=
None
# Only works when nms_version='v1'.
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/vision/beta/configs/retinanet.py
View file @
33582301
...
@@ -114,6 +114,7 @@ class DetectionGenerator(hyperparams.Config):
...
@@ -114,6 +114,7 @@ class DetectionGenerator(hyperparams.Config):
max_num_detections
:
int
=
100
max_num_detections
:
int
=
100
nms_version
:
str
=
'v2'
# `v2`, `v1`, `batched`.
nms_version
:
str
=
'v2'
# `v2`, `v1`, `batched`.
use_cpu_nms
:
bool
=
False
use_cpu_nms
:
bool
=
False
soft_nms_sigma
:
Optional
[
float
]
=
None
# Only works when nms_version='v1'.
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/vision/beta/modeling/factory.py
View file @
33582301
...
@@ -198,7 +198,8 @@ def build_maskrcnn(
...
@@ -198,7 +198,8 @@ def build_maskrcnn(
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
max_num_detections
=
generator_config
.
max_num_detections
,
max_num_detections
=
generator_config
.
max_num_detections
,
nms_version
=
generator_config
.
nms_version
,
nms_version
=
generator_config
.
nms_version
,
use_cpu_nms
=
generator_config
.
use_cpu_nms
)
use_cpu_nms
=
generator_config
.
use_cpu_nms
,
soft_nms_sigma
=
generator_config
.
soft_nms_sigma
)
if
model_config
.
include_mask
:
if
model_config
.
include_mask
:
mask_head
=
instance_heads
.
MaskHead
(
mask_head
=
instance_heads
.
MaskHead
(
...
@@ -301,7 +302,8 @@ def build_retinanet(
...
@@ -301,7 +302,8 @@ def build_retinanet(
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
max_num_detections
=
generator_config
.
max_num_detections
,
max_num_detections
=
generator_config
.
max_num_detections
,
nms_version
=
generator_config
.
nms_version
,
nms_version
=
generator_config
.
nms_version
,
use_cpu_nms
=
generator_config
.
use_cpu_nms
)
use_cpu_nms
=
generator_config
.
use_cpu_nms
,
soft_nms_sigma
=
generator_config
.
soft_nms_sigma
)
model
=
retinanet_model
.
RetinaNetModel
(
model
=
retinanet_model
.
RetinaNetModel
(
backbone
,
backbone
,
...
...
official/vision/beta/modeling/layers/detection_generator.py
View file @
33582301
...
@@ -20,6 +20,7 @@ import tensorflow as tf
...
@@ -20,6 +20,7 @@ import tensorflow as tf
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
nms
from
official.vision.beta.ops
import
nms
from
official.vision.beta.ops
import
preprocess_ops
def
_generate_detections_v1
(
boxes
:
tf
.
Tensor
,
def
_generate_detections_v1
(
boxes
:
tf
.
Tensor
,
...
@@ -29,7 +30,8 @@ def _generate_detections_v1(boxes: tf.Tensor,
...
@@ -29,7 +30,8 @@ def _generate_detections_v1(boxes: tf.Tensor,
pre_nms_top_k
:
int
=
5000
,
pre_nms_top_k
:
int
=
5000
,
pre_nms_score_threshold
:
float
=
0.05
,
pre_nms_score_threshold
:
float
=
0.05
,
nms_iou_threshold
:
float
=
0.5
,
nms_iou_threshold
:
float
=
0.5
,
max_num_detections
:
int
=
100
):
max_num_detections
:
int
=
100
,
soft_nms_sigma
:
Optional
[
float
]
=
None
):
"""Generates the final detections given the model outputs.
"""Generates the final detections given the model outputs.
The implementation unrolls the batch dimension and process images one by one.
The implementation unrolls the batch dimension and process images one by one.
...
@@ -58,6 +60,8 @@ def _generate_detections_v1(boxes: tf.Tensor,
...
@@ -58,6 +60,8 @@ def _generate_detections_v1(boxes: tf.Tensor,
boxes overlap too much with respect to IOU.
boxes overlap too much with respect to IOU.
max_num_detections: A scalar representing maximum number of boxes retained
max_num_detections: A scalar representing maximum number of boxes retained
over all classes.
over all classes.
soft_nms_sigma: A `float` representing the sigma parameter for Soft NMS.
When soft_nms_sigma=0.0 (which is default), we fall back to standard NMS.
Returns:
Returns:
nms_boxes: A `float` type `tf.Tensor` of shape
nms_boxes: A `float` type `tf.Tensor` of shape
...
@@ -99,7 +103,8 @@ def _generate_detections_v1(boxes: tf.Tensor,
...
@@ -99,7 +103,8 @@ def _generate_detections_v1(boxes: tf.Tensor,
pre_nms_top_k
=
pre_nms_top_k
,
pre_nms_top_k
=
pre_nms_top_k
,
pre_nms_score_threshold
=
pre_nms_score_threshold
,
pre_nms_score_threshold
=
pre_nms_score_threshold
,
nms_iou_threshold
=
nms_iou_threshold
,
nms_iou_threshold
=
nms_iou_threshold
,
max_num_detections
=
max_num_detections
)
max_num_detections
=
max_num_detections
,
soft_nms_sigma
=
soft_nms_sigma
)
nmsed_boxes
.
append
(
nmsed_boxes_i
)
nmsed_boxes
.
append
(
nmsed_boxes_i
)
nmsed_scores
.
append
(
nmsed_scores_i
)
nmsed_scores
.
append
(
nmsed_scores_i
)
nmsed_classes
.
append
(
nmsed_classes_i
)
nmsed_classes
.
append
(
nmsed_classes_i
)
...
@@ -126,7 +131,8 @@ def _generate_detections_per_image(
...
@@ -126,7 +131,8 @@ def _generate_detections_per_image(
pre_nms_top_k
:
int
=
5000
,
pre_nms_top_k
:
int
=
5000
,
pre_nms_score_threshold
:
float
=
0.05
,
pre_nms_score_threshold
:
float
=
0.05
,
nms_iou_threshold
:
float
=
0.5
,
nms_iou_threshold
:
float
=
0.5
,
max_num_detections
:
int
=
100
):
max_num_detections
:
int
=
100
,
soft_nms_sigma
:
Optional
[
float
]
=
None
):
"""Generates the final detections per image given the model outputs.
"""Generates the final detections per image given the model outputs.
Args:
Args:
...
@@ -149,6 +155,9 @@ def _generate_detections_per_image(
...
@@ -149,6 +155,9 @@ def _generate_detections_per_image(
boxes overlap too much with respect to IOU.
boxes overlap too much with respect to IOU.
max_num_detections: A `scalar` representing maximum number of boxes retained
max_num_detections: A `scalar` representing maximum number of boxes retained
over all classes.
over all classes.
soft_nms_sigma: A `float` representing the sigma parameter for Soft NMS.
When soft_nms_sigma=0.0, we fall back to standard NMS.
If set to None, `tf.image.non_max_suppression_padded` is called instead.
Returns:
Returns:
nms_boxes: A `float` tf.Tensor of shape `[max_num_detections, 4]`
nms_boxes: A `float` tf.Tensor of shape `[max_num_detections, 4]`
...
@@ -182,21 +191,38 @@ def _generate_detections_per_image(
...
@@ -182,21 +191,38 @@ def _generate_detections_per_image(
scores_i
,
k
=
tf
.
minimum
(
tf
.
shape
(
scores_i
)[
-
1
],
pre_nms_top_k
))
scores_i
,
k
=
tf
.
minimum
(
tf
.
shape
(
scores_i
)[
-
1
],
pre_nms_top_k
))
boxes_i
=
tf
.
gather
(
boxes_i
,
indices
)
boxes_i
=
tf
.
gather
(
boxes_i
,
indices
)
(
nmsed_indices_i
,
if
soft_nms_sigma
is
not
None
:
nmsed_num_valid_i
)
=
tf
.
image
.
non_max_suppression_padded
(
(
nmsed_indices_i
,
tf
.
cast
(
boxes_i
,
tf
.
float32
),
nmsed_scores_i
)
=
tf
.
image
.
non_max_suppression_with_scores
(
tf
.
cast
(
scores_i
,
tf
.
float32
),
tf
.
cast
(
boxes_i
,
tf
.
float32
),
max_num_detections
,
tf
.
cast
(
scores_i
,
tf
.
float32
),
iou_threshold
=
nms_iou_threshold
,
max_num_detections
,
score_threshold
=
pre_nms_score_threshold
,
iou_threshold
=
nms_iou_threshold
,
pad_to_max_output_size
=
True
,
score_threshold
=
pre_nms_score_threshold
,
name
=
'nms_detections_'
+
str
(
i
))
soft_nms_sigma
=
soft_nms_sigma
,
nmsed_boxes_i
=
tf
.
gather
(
boxes_i
,
nmsed_indices_i
)
name
=
'nms_detections_'
+
str
(
i
))
nmsed_scores_i
=
tf
.
gather
(
scores_i
,
nmsed_indices_i
)
nmsed_boxes_i
=
tf
.
gather
(
boxes_i
,
nmsed_indices_i
)
# Sets scores of invalid boxes to -1.
nmsed_boxes_i
=
preprocess_ops
.
clip_or_pad_to_fixed_size
(
nmsed_scores_i
=
tf
.
where
(
nmsed_boxes_i
,
max_num_detections
,
0.0
)
tf
.
less
(
tf
.
range
(
max_num_detections
),
[
nmsed_num_valid_i
]),
nmsed_scores_i
=
preprocess_ops
.
clip_or_pad_to_fixed_size
(
nmsed_scores_i
,
-
tf
.
ones_like
(
nmsed_scores_i
))
nmsed_scores_i
,
max_num_detections
,
-
1.0
)
else
:
(
nmsed_indices_i
,
nmsed_num_valid_i
)
=
tf
.
image
.
non_max_suppression_padded
(
tf
.
cast
(
boxes_i
,
tf
.
float32
),
tf
.
cast
(
scores_i
,
tf
.
float32
),
max_num_detections
,
iou_threshold
=
nms_iou_threshold
,
score_threshold
=
pre_nms_score_threshold
,
pad_to_max_output_size
=
True
,
name
=
'nms_detections_'
+
str
(
i
))
nmsed_boxes_i
=
tf
.
gather
(
boxes_i
,
nmsed_indices_i
)
nmsed_scores_i
=
tf
.
gather
(
scores_i
,
nmsed_indices_i
)
# Sets scores of invalid boxes to -1.
nmsed_scores_i
=
tf
.
where
(
tf
.
less
(
tf
.
range
(
max_num_detections
),
[
nmsed_num_valid_i
]),
nmsed_scores_i
,
-
tf
.
ones_like
(
nmsed_scores_i
))
nmsed_classes_i
=
tf
.
fill
([
max_num_detections
],
i
)
nmsed_classes_i
=
tf
.
fill
([
max_num_detections
],
i
)
nmsed_boxes
.
append
(
nmsed_boxes_i
)
nmsed_boxes
.
append
(
nmsed_boxes_i
)
nmsed_scores
.
append
(
nmsed_scores_i
)
nmsed_scores
.
append
(
nmsed_scores_i
)
...
@@ -207,6 +233,8 @@ def _generate_detections_per_image(
...
@@ -207,6 +233,8 @@ def _generate_detections_per_image(
att_i
=
att
[:,
min
(
num_classes_for_attr
-
1
,
i
)]
att_i
=
att
[:,
min
(
num_classes_for_attr
-
1
,
i
)]
att_i
=
tf
.
gather
(
att_i
,
indices
)
att_i
=
tf
.
gather
(
att_i
,
indices
)
nmsed_att_i
=
tf
.
gather
(
att_i
,
nmsed_indices_i
)
nmsed_att_i
=
tf
.
gather
(
att_i
,
nmsed_indices_i
)
nmsed_att_i
=
preprocess_ops
.
clip_or_pad_to_fixed_size
(
nmsed_att_i
,
max_num_detections
,
0.0
)
nmsed_attributes
[
att_name
].
append
(
nmsed_att_i
)
nmsed_attributes
[
att_name
].
append
(
nmsed_att_i
)
# Concats results from all classes and sort them.
# Concats results from all classes and sort them.
...
@@ -407,6 +435,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
...
@@ -407,6 +435,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
max_num_detections
:
int
=
100
,
max_num_detections
:
int
=
100
,
nms_version
:
str
=
'v2'
,
nms_version
:
str
=
'v2'
,
use_cpu_nms
:
bool
=
False
,
use_cpu_nms
:
bool
=
False
,
soft_nms_sigma
:
Optional
[
float
]
=
None
,
**
kwargs
):
**
kwargs
):
"""Initializes a detection generator.
"""Initializes a detection generator.
...
@@ -423,6 +452,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
...
@@ -423,6 +452,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
generate.
generate.
nms_version: A string of `batched`, `v1` or `v2` specifies NMS version.
nms_version: A string of `batched`, `v1` or `v2` specifies NMS version.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
soft_nms_sigma: A `float` representing the sigma parameter for Soft NMS.
When soft_nms_sigma=0.0, we fall back to standard NMS.
**kwargs: Additional keyword arguments passed to Layer.
**kwargs: Additional keyword arguments passed to Layer.
"""
"""
self
.
_config_dict
=
{
self
.
_config_dict
=
{
...
@@ -433,6 +464,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
...
@@ -433,6 +464,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
'max_num_detections'
:
max_num_detections
,
'max_num_detections'
:
max_num_detections
,
'nms_version'
:
nms_version
,
'nms_version'
:
nms_version
,
'use_cpu_nms'
:
use_cpu_nms
,
'use_cpu_nms'
:
use_cpu_nms
,
'soft_nms_sigma'
:
soft_nms_sigma
,
}
}
super
(
DetectionGenerator
,
self
).
__init__
(
**
kwargs
)
super
(
DetectionGenerator
,
self
).
__init__
(
**
kwargs
)
...
@@ -540,7 +572,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
...
@@ -540,7 +572,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
pre_nms_score_threshold
=
self
pre_nms_score_threshold
=
self
.
_config_dict
[
'pre_nms_score_threshold'
],
.
_config_dict
[
'pre_nms_score_threshold'
],
nms_iou_threshold
=
self
.
_config_dict
[
'nms_iou_threshold'
],
nms_iou_threshold
=
self
.
_config_dict
[
'nms_iou_threshold'
],
max_num_detections
=
self
.
_config_dict
[
'max_num_detections'
]))
max_num_detections
=
self
.
_config_dict
[
'max_num_detections'
],
soft_nms_sigma
=
self
.
_config_dict
[
'soft_nms_sigma'
]))
elif
self
.
_config_dict
[
'nms_version'
]
==
'v2'
:
elif
self
.
_config_dict
[
'nms_version'
]
==
'v2'
:
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
(
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
(
_generate_detections_v2
(
_generate_detections_v2
(
...
@@ -585,6 +618,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
...
@@ -585,6 +618,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
max_num_detections
:
int
=
100
,
max_num_detections
:
int
=
100
,
nms_version
:
str
=
'v1'
,
nms_version
:
str
=
'v1'
,
use_cpu_nms
:
bool
=
False
,
use_cpu_nms
:
bool
=
False
,
soft_nms_sigma
:
Optional
[
float
]
=
None
,
**
kwargs
):
**
kwargs
):
"""Initializes a multi-level detection generator.
"""Initializes a multi-level detection generator.
...
@@ -601,6 +635,8 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
...
@@ -601,6 +635,8 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
generate.
generate.
nms_version: A string of `batched`, `v1` or `v2` specifies NMS version
nms_version: A string of `batched`, `v1` or `v2` specifies NMS version
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
soft_nms_sigma: A `float` representing the sigma parameter for Soft NMS.
When soft_nms_sigma=0.0, we fall back to standard NMS.
**kwargs: Additional keyword arguments passed to Layer.
**kwargs: Additional keyword arguments passed to Layer.
"""
"""
self
.
_config_dict
=
{
self
.
_config_dict
=
{
...
@@ -611,6 +647,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
...
@@ -611,6 +647,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
'max_num_detections'
:
max_num_detections
,
'max_num_detections'
:
max_num_detections
,
'nms_version'
:
nms_version
,
'nms_version'
:
nms_version
,
'use_cpu_nms'
:
use_cpu_nms
,
'use_cpu_nms'
:
use_cpu_nms
,
'soft_nms_sigma'
:
soft_nms_sigma
,
}
}
super
(
MultilevelDetectionGenerator
,
self
).
__init__
(
**
kwargs
)
super
(
MultilevelDetectionGenerator
,
self
).
__init__
(
**
kwargs
)
...
@@ -778,7 +815,8 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
...
@@ -778,7 +815,8 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
pre_nms_score_threshold
=
self
pre_nms_score_threshold
=
self
.
_config_dict
[
'pre_nms_score_threshold'
],
.
_config_dict
[
'pre_nms_score_threshold'
],
nms_iou_threshold
=
self
.
_config_dict
[
'nms_iou_threshold'
],
nms_iou_threshold
=
self
.
_config_dict
[
'nms_iou_threshold'
],
max_num_detections
=
self
.
_config_dict
[
'max_num_detections'
]))
max_num_detections
=
self
.
_config_dict
[
'max_num_detections'
],
soft_nms_sigma
=
self
.
_config_dict
[
'soft_nms_sigma'
]))
elif
self
.
_config_dict
[
'nms_version'
]
==
'v2'
:
elif
self
.
_config_dict
[
'nms_version'
]
==
'v2'
:
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
(
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
(
_generate_detections_v2
(
_generate_detections_v2
(
...
...
official/vision/beta/modeling/layers/detection_generator_test.py
View file @
33582301
...
@@ -44,9 +44,11 @@ class DetectionGeneratorTest(
...
@@ -44,9 +44,11 @@ class DetectionGeneratorTest(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
product
(
@
parameterized
.
product
(
nms_version
=
[
'batched'
,
'v1'
,
'v2'
],
use_cpu_nms
=
[
True
,
False
])
nms_version
=
[
'batched'
,
'v1'
,
'v2'
],
def
testDetectionsOutputShape
(
self
,
nms_version
,
use_cpu_nms
):
use_cpu_nms
=
[
True
,
False
],
max_num_detections
=
100
soft_nms_sigma
=
[
None
,
0.1
])
def
testDetectionsOutputShape
(
self
,
nms_version
,
use_cpu_nms
,
soft_nms_sigma
):
max_num_detections
=
10
num_classes
=
4
num_classes
=
4
pre_nms_top_k
=
5000
pre_nms_top_k
=
5000
pre_nms_score_threshold
=
0.01
pre_nms_score_threshold
=
0.01
...
@@ -59,6 +61,7 @@ class DetectionGeneratorTest(
...
@@ -59,6 +61,7 @@ class DetectionGeneratorTest(
'max_num_detections'
:
max_num_detections
,
'max_num_detections'
:
max_num_detections
,
'nms_version'
:
nms_version
,
'nms_version'
:
nms_version
,
'use_cpu_nms'
:
use_cpu_nms
,
'use_cpu_nms'
:
use_cpu_nms
,
'soft_nms_sigma'
:
soft_nms_sigma
,
}
}
generator
=
detection_generator
.
DetectionGenerator
(
**
kwargs
)
generator
=
detection_generator
.
DetectionGenerator
(
**
kwargs
)
...
@@ -99,6 +102,7 @@ class DetectionGeneratorTest(
...
@@ -99,6 +102,7 @@ class DetectionGeneratorTest(
'max_num_detections'
:
10
,
'max_num_detections'
:
10
,
'nms_version'
:
'v2'
,
'nms_version'
:
'v2'
,
'use_cpu_nms'
:
False
,
'use_cpu_nms'
:
False
,
'soft_nms_sigma'
:
None
,
}
}
generator
=
detection_generator
.
DetectionGenerator
(
**
kwargs
)
generator
=
detection_generator
.
DetectionGenerator
(
**
kwargs
)
...
@@ -116,18 +120,20 @@ class MultilevelDetectionGeneratorTest(
...
@@ -116,18 +120,20 @@ class MultilevelDetectionGeneratorTest(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
@
parameterized
.
parameters
(
(
'batched'
,
False
,
True
),
(
'batched'
,
False
,
True
,
None
),
(
'batched'
,
False
,
False
),
(
'batched'
,
False
,
False
,
None
),
(
'v2'
,
False
,
True
),
(
'v2'
,
False
,
True
,
None
),
(
'v2'
,
False
,
False
),
(
'v2'
,
False
,
False
,
None
),
(
'v1'
,
True
,
True
),
(
'v1'
,
True
,
True
,
0.0
),
(
'v1'
,
True
,
False
),
(
'v1'
,
True
,
False
,
0.1
),
(
'v1'
,
True
,
False
,
None
),
)
)
def
testDetectionsOutputShape
(
self
,
nms_version
,
has_att_heads
,
use_cpu_nms
):
def
testDetectionsOutputShape
(
self
,
nms_version
,
has_att_heads
,
use_cpu_nms
,
soft_nms_sigma
):
min_level
=
4
min_level
=
4
max_level
=
6
max_level
=
6
num_scales
=
2
num_scales
=
2
max_num_detections
=
10
0
max_num_detections
=
10
aspect_ratios
=
[
1.0
,
2.0
]
aspect_ratios
=
[
1.0
,
2.0
]
anchor_scale
=
2.0
anchor_scale
=
2.0
output_size
=
[
64
,
64
]
output_size
=
[
64
,
64
]
...
@@ -143,6 +149,7 @@ class MultilevelDetectionGeneratorTest(
...
@@ -143,6 +149,7 @@ class MultilevelDetectionGeneratorTest(
'max_num_detections'
:
max_num_detections
,
'max_num_detections'
:
max_num_detections
,
'nms_version'
:
nms_version
,
'nms_version'
:
nms_version
,
'use_cpu_nms'
:
use_cpu_nms
,
'use_cpu_nms'
:
use_cpu_nms
,
'soft_nms_sigma'
:
soft_nms_sigma
,
}
}
input_anchor
=
anchor
.
build_anchor_generator
(
min_level
,
max_level
,
input_anchor
=
anchor
.
build_anchor_generator
(
min_level
,
max_level
,
...
@@ -224,6 +231,7 @@ class MultilevelDetectionGeneratorTest(
...
@@ -224,6 +231,7 @@ class MultilevelDetectionGeneratorTest(
'max_num_detections'
:
10
,
'max_num_detections'
:
10
,
'nms_version'
:
'v2'
,
'nms_version'
:
'v2'
,
'use_cpu_nms'
:
False
,
'use_cpu_nms'
:
False
,
'soft_nms_sigma'
:
None
,
}
}
generator
=
detection_generator
.
MultilevelDetectionGenerator
(
**
kwargs
)
generator
=
detection_generator
.
MultilevelDetectionGenerator
(
**
kwargs
)
...
...
official/vision/beta/modeling/retinanet_model_test.py
View file @
33582301
...
@@ -148,9 +148,10 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -148,9 +148,10 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
training
=
[
True
,
False
],
training
=
[
True
,
False
],
has_att_heads
=
[
True
,
False
],
has_att_heads
=
[
True
,
False
],
output_intermediate_features
=
[
True
,
False
],
output_intermediate_features
=
[
True
,
False
],
soft_nms_sigma
=
[
None
,
0.0
,
0.1
],
))
))
def
test_forward
(
self
,
strategy
,
image_size
,
training
,
has_att_heads
,
def
test_forward
(
self
,
strategy
,
image_size
,
training
,
has_att_heads
,
output_intermediate_features
):
output_intermediate_features
,
soft_nms_sigma
):
"""Test for creation of a R50-FPN RetinaNet."""
"""Test for creation of a R50-FPN RetinaNet."""
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
num_classes
=
3
num_classes
=
3
...
@@ -193,7 +194,10 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -193,7 +194,10 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
attribute_heads
=
attribute_heads
,
attribute_heads
=
attribute_heads
,
num_anchors_per_location
=
num_anchors_per_location
)
num_anchors_per_location
=
num_anchors_per_location
)
generator
=
detection_generator
.
MultilevelDetectionGenerator
(
generator
=
detection_generator
.
MultilevelDetectionGenerator
(
max_num_detections
=
10
,
nms_version
=
'v1'
)
max_num_detections
=
10
,
nms_version
=
'v1'
,
use_cpu_nms
=
soft_nms_sigma
is
not
None
,
soft_nms_sigma
=
soft_nms_sigma
)
model
=
retinanet_model
.
RetinaNetModel
(
model
=
retinanet_model
.
RetinaNetModel
(
backbone
=
backbone
,
backbone
=
backbone
,
decoder
=
decoder
,
decoder
=
decoder
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment