Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
dbe39272
Commit
dbe39272
authored
Sep 20, 2021
by
Xianzhi Du
Committed by
A. Unique TensorFlower
Sep 20, 2021
Browse files
Internal change
PiperOrigin-RevId: 397809846
parent
abc4fc08
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
63 additions
and
38 deletions
+63
-38
official/vision/beta/configs/maskrcnn.py
official/vision/beta/configs/maskrcnn.py
+1
-1
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+1
-1
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+2
-2
official/vision/beta/modeling/layers/detection_generator.py
official/vision/beta/modeling/layers/detection_generator.py
+42
-16
official/vision/beta/modeling/layers/detection_generator_test.py
...l/vision/beta/modeling/layers/detection_generator_test.py
+13
-14
official/vision/beta/modeling/retinanet_model_test.py
official/vision/beta/modeling/retinanet_model_test.py
+1
-1
official/vision/beta/projects/deepmac_maskrcnn/serving/detection.py
...ision/beta/projects/deepmac_maskrcnn/serving/detection.py
+1
-1
official/vision/beta/projects/deepmac_maskrcnn/tasks/deep_mask_head_rcnn.py
...ta/projects/deepmac_maskrcnn/tasks/deep_mask_head_rcnn.py
+1
-1
official/vision/beta/serving/detection_test.py
official/vision/beta/serving/detection_test.py
+1
-1
No files found.
official/vision/beta/configs/maskrcnn.py
View file @
dbe39272
...
@@ -131,7 +131,7 @@ class DetectionGenerator(hyperparams.Config):
...
@@ -131,7 +131,7 @@ class DetectionGenerator(hyperparams.Config):
pre_nms_score_threshold
:
float
=
0.05
pre_nms_score_threshold
:
float
=
0.05
nms_iou_threshold
:
float
=
0.5
nms_iou_threshold
:
float
=
0.5
max_num_detections
:
int
=
100
max_num_detections
:
int
=
100
use_batched_nms
:
bool
=
False
nms_version
:
str
=
'v2'
# `v2`, `v1`, `batched`
use_cpu_nms
:
bool
=
False
use_cpu_nms
:
bool
=
False
...
...
official/vision/beta/configs/retinanet.py
View file @
dbe39272
...
@@ -112,7 +112,7 @@ class DetectionGenerator(hyperparams.Config):
...
@@ -112,7 +112,7 @@ class DetectionGenerator(hyperparams.Config):
pre_nms_score_threshold
:
float
=
0.05
pre_nms_score_threshold
:
float
=
0.05
nms_iou_threshold
:
float
=
0.5
nms_iou_threshold
:
float
=
0.5
max_num_detections
:
int
=
100
max_num_detections
:
int
=
100
use_batched_nms
:
bool
=
False
nms_version
:
str
=
'v2'
# `v2`, `v1`, `batched`.
use_cpu_nms
:
bool
=
False
use_cpu_nms
:
bool
=
False
...
...
official/vision/beta/modeling/factory.py
View file @
dbe39272
...
@@ -197,7 +197,7 @@ def build_maskrcnn(
...
@@ -197,7 +197,7 @@ def build_maskrcnn(
pre_nms_score_threshold
=
generator_config
.
pre_nms_score_threshold
,
pre_nms_score_threshold
=
generator_config
.
pre_nms_score_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
max_num_detections
=
generator_config
.
max_num_detections
,
max_num_detections
=
generator_config
.
max_num_detections
,
use_batched_nms
=
generator_config
.
use_batched_nms
,
nms_version
=
generator_config
.
nms_version
,
use_cpu_nms
=
generator_config
.
use_cpu_nms
)
use_cpu_nms
=
generator_config
.
use_cpu_nms
)
if
model_config
.
include_mask
:
if
model_config
.
include_mask
:
...
@@ -300,7 +300,7 @@ def build_retinanet(
...
@@ -300,7 +300,7 @@ def build_retinanet(
pre_nms_score_threshold
=
generator_config
.
pre_nms_score_threshold
,
pre_nms_score_threshold
=
generator_config
.
pre_nms_score_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
max_num_detections
=
generator_config
.
max_num_detections
,
max_num_detections
=
generator_config
.
max_num_detections
,
use_batched_nms
=
generator_config
.
use_batched_nms
,
nms_version
=
generator_config
.
nms_version
,
use_cpu_nms
=
generator_config
.
use_cpu_nms
)
use_cpu_nms
=
generator_config
.
use_cpu_nms
)
model
=
retinanet_model
.
RetinaNetModel
(
model
=
retinanet_model
.
RetinaNetModel
(
...
...
official/vision/beta/modeling/layers/detection_generator.py
View file @
dbe39272
...
@@ -404,7 +404,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
...
@@ -404,7 +404,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
pre_nms_score_threshold
:
float
=
0.05
,
pre_nms_score_threshold
:
float
=
0.05
,
nms_iou_threshold
:
float
=
0.5
,
nms_iou_threshold
:
float
=
0.5
,
max_num_detections
:
int
=
100
,
max_num_detections
:
int
=
100
,
use_batched_nms
:
bool
=
False
,
nms_version
:
str
=
'v2'
,
use_cpu_nms
:
bool
=
False
,
use_cpu_nms
:
bool
=
False
,
**
kwargs
):
**
kwargs
):
"""Initializes a detection generator.
"""Initializes a detection generator.
...
@@ -420,8 +420,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
...
@@ -420,8 +420,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
max_num_detections: An `int` of the final number of total detections to
max_num_detections: An `int` of the final number of total detections to
generate.
generate.
use_batched_nms: A `bool` of whether or not use
nms_version: A string of `batched`, `v1` or `v2` specifies NMS version.
`tf.image.combined_non_max_suppression`.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
**kwargs: Additional keyword arguments passed to Layer.
**kwargs: Additional keyword arguments passed to Layer.
"""
"""
...
@@ -431,7 +430,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
...
@@ -431,7 +430,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
'pre_nms_score_threshold'
:
pre_nms_score_threshold
,
'pre_nms_score_threshold'
:
pre_nms_score_threshold
,
'nms_iou_threshold'
:
nms_iou_threshold
,
'nms_iou_threshold'
:
nms_iou_threshold
,
'max_num_detections'
:
max_num_detections
,
'max_num_detections'
:
max_num_detections
,
'
use_batched_nms'
:
use_batched_nms
,
'
nms_version'
:
nms_version
,
'use_cpu_nms'
:
use_cpu_nms
,
'use_cpu_nms'
:
use_cpu_nms
,
}
}
super
(
DetectionGenerator
,
self
).
__init__
(
**
kwargs
)
super
(
DetectionGenerator
,
self
).
__init__
(
**
kwargs
)
...
@@ -524,14 +523,14 @@ class DetectionGenerator(tf.keras.layers.Layer):
...
@@ -524,14 +523,14 @@ class DetectionGenerator(tf.keras.layers.Layer):
nms_context
=
contextlib
.
nullcontext
()
nms_context
=
contextlib
.
nullcontext
()
with
nms_context
:
with
nms_context
:
if
self
.
_config_dict
[
'
use_
batched
_nms'
]
:
if
self
.
_config_dict
[
'
nms_version'
]
==
'
batched
'
:
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
(
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
(
_generate_detections_batched
(
_generate_detections_batched
(
decoded_boxes
,
box_scores
,
decoded_boxes
,
box_scores
,
self
.
_config_dict
[
'pre_nms_score_threshold'
],
self
.
_config_dict
[
'pre_nms_score_threshold'
],
self
.
_config_dict
[
'nms_iou_threshold'
],
self
.
_config_dict
[
'nms_iou_threshold'
],
self
.
_config_dict
[
'max_num_detections'
]))
self
.
_config_dict
[
'max_num_detections'
]))
el
se
:
el
if
self
.
_config_dict
[
'nms_version'
]
==
'v1'
:
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
,
_
)
=
(
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
,
_
)
=
(
_generate_detections_v1
(
_generate_detections_v1
(
decoded_boxes
,
decoded_boxes
,
...
@@ -541,6 +540,19 @@ class DetectionGenerator(tf.keras.layers.Layer):
...
@@ -541,6 +540,19 @@ class DetectionGenerator(tf.keras.layers.Layer):
.
_config_dict
[
'pre_nms_score_threshold'
],
.
_config_dict
[
'pre_nms_score_threshold'
],
nms_iou_threshold
=
self
.
_config_dict
[
'nms_iou_threshold'
],
nms_iou_threshold
=
self
.
_config_dict
[
'nms_iou_threshold'
],
max_num_detections
=
self
.
_config_dict
[
'max_num_detections'
]))
max_num_detections
=
self
.
_config_dict
[
'max_num_detections'
]))
elif
self
.
_config_dict
[
'nms_version'
]
==
'v2'
:
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
(
_generate_detections_v2
(
decoded_boxes
,
box_scores
,
pre_nms_top_k
=
self
.
_config_dict
[
'pre_nms_top_k'
],
pre_nms_score_threshold
=
self
.
_config_dict
[
'pre_nms_score_threshold'
],
nms_iou_threshold
=
self
.
_config_dict
[
'nms_iou_threshold'
],
max_num_detections
=
self
.
_config_dict
[
'max_num_detections'
]))
else
:
raise
ValueError
(
'NMS version {} not supported.'
.
format
(
self
.
_config_dict
[
'nms_version'
]))
# Adds 1 to offset the background class which has index 0.
# Adds 1 to offset the background class which has index 0.
nmsed_classes
+=
1
nmsed_classes
+=
1
...
@@ -570,7 +582,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
...
@@ -570,7 +582,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
pre_nms_score_threshold
:
float
=
0.05
,
pre_nms_score_threshold
:
float
=
0.05
,
nms_iou_threshold
:
float
=
0.5
,
nms_iou_threshold
:
float
=
0.5
,
max_num_detections
:
int
=
100
,
max_num_detections
:
int
=
100
,
use_batched_nms
:
bool
=
False
,
nms_version
:
str
=
'v1'
,
use_cpu_nms
:
bool
=
False
,
use_cpu_nms
:
bool
=
False
,
**
kwargs
):
**
kwargs
):
"""Initializes a multi-level detection generator.
"""Initializes a multi-level detection generator.
...
@@ -586,8 +598,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
...
@@ -586,8 +598,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
max_num_detections: An `int` of the final number of total detections to
max_num_detections: An `int` of the final number of total detections to
generate.
generate.
use_batched_nms: A `bool` of whether or not use
nms_version: A string of `batched`, `v1` or `v2` specifies NMS version
`tf.image.combined_non_max_suppression`.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
**kwargs: Additional keyword arguments passed to Layer.
**kwargs: Additional keyword arguments passed to Layer.
"""
"""
...
@@ -597,7 +608,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
...
@@ -597,7 +608,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
'pre_nms_score_threshold'
:
pre_nms_score_threshold
,
'pre_nms_score_threshold'
:
pre_nms_score_threshold
,
'nms_iou_threshold'
:
nms_iou_threshold
,
'nms_iou_threshold'
:
nms_iou_threshold
,
'max_num_detections'
:
max_num_detections
,
'max_num_detections'
:
max_num_detections
,
'
use_batched_nms'
:
use_batched_nms
,
'
nms_version'
:
nms_version
,
'use_cpu_nms'
:
use_cpu_nms
,
'use_cpu_nms'
:
use_cpu_nms
,
}
}
super
(
MultilevelDetectionGenerator
,
self
).
__init__
(
**
kwargs
)
super
(
MultilevelDetectionGenerator
,
self
).
__init__
(
**
kwargs
)
...
@@ -731,11 +742,11 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
...
@@ -731,11 +742,11 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
nms_context
=
contextlib
.
nullcontext
()
nms_context
=
contextlib
.
nullcontext
()
with
nms_context
:
with
nms_context
:
if
self
.
_config_dict
[
'use_batched_nms'
]
:
if
raw_attributes
and
(
self
.
_config_dict
[
'nms_version'
]
!=
'v1'
)
:
if
raw_attributes
:
raise
ValueError
(
raise
ValueError
(
'Attribute learning is only supported for NMSv1 but NMS {} is used.'
'Attribute learning is not supported for batched NMS.'
)
.
format
(
self
.
_config_dict
[
'nms_version'
])
)
if
self
.
_config_dict
[
'nms_version'
]
==
'batched'
:
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
(
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
(
_generate_detections_batched
(
_generate_detections_batched
(
boxes
,
scores
,
self
.
_config_dict
[
'pre_nms_score_threshold'
],
boxes
,
scores
,
self
.
_config_dict
[
'pre_nms_score_threshold'
],
...
@@ -743,7 +754,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
...
@@ -743,7 +754,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
self
.
_config_dict
[
'max_num_detections'
]))
self
.
_config_dict
[
'max_num_detections'
]))
# Set `nmsed_attributes` to None for batched NMS.
# Set `nmsed_attributes` to None for batched NMS.
nmsed_attributes
=
{}
nmsed_attributes
=
{}
el
se
:
el
if
self
.
_config_dict
[
'nms_version'
]
==
'v1'
:
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
,
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
,
nmsed_attributes
)
=
(
nmsed_attributes
)
=
(
_generate_detections_v1
(
_generate_detections_v1
(
...
@@ -755,6 +766,21 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
...
@@ -755,6 +766,21 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
.
_config_dict
[
'pre_nms_score_threshold'
],
.
_config_dict
[
'pre_nms_score_threshold'
],
nms_iou_threshold
=
self
.
_config_dict
[
'nms_iou_threshold'
],
nms_iou_threshold
=
self
.
_config_dict
[
'nms_iou_threshold'
],
max_num_detections
=
self
.
_config_dict
[
'max_num_detections'
]))
max_num_detections
=
self
.
_config_dict
[
'max_num_detections'
]))
elif
self
.
_config_dict
[
'nms_version'
]
==
'v2'
:
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
valid_detections
)
=
(
_generate_detections_v2
(
boxes
,
scores
,
pre_nms_top_k
=
self
.
_config_dict
[
'pre_nms_top_k'
],
pre_nms_score_threshold
=
self
.
_config_dict
[
'pre_nms_score_threshold'
],
nms_iou_threshold
=
self
.
_config_dict
[
'nms_iou_threshold'
],
max_num_detections
=
self
.
_config_dict
[
'max_num_detections'
]))
# Set `nmsed_attributes` to None for v2.
nmsed_attributes
=
{}
else
:
raise
ValueError
(
'NMS version {} not supported.'
.
format
(
self
.
_config_dict
[
'nms_version'
]))
# Adds 1 to offset the background class which has index 0.
# Adds 1 to offset the background class which has index 0.
nmsed_classes
+=
1
nmsed_classes
+=
1
...
...
official/vision/beta/modeling/layers/detection_generator_test.py
View file @
dbe39272
...
@@ -44,8 +44,8 @@ class DetectionGeneratorTest(
...
@@ -44,8 +44,8 @@ class DetectionGeneratorTest(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
product
(
@
parameterized
.
product
(
use_batched_nms
=
[
True
,
False
],
use_cpu_nms
=
[
True
,
False
])
nms_version
=
[
'batched'
,
'v1'
,
'v2'
],
use_cpu_nms
=
[
True
,
False
])
def
testDetectionsOutputShape
(
self
,
use_batched_nms
,
use_cpu_nms
):
def
testDetectionsOutputShape
(
self
,
nms_version
,
use_cpu_nms
):
max_num_detections
=
100
max_num_detections
=
100
num_classes
=
4
num_classes
=
4
pre_nms_top_k
=
5000
pre_nms_top_k
=
5000
...
@@ -57,7 +57,7 @@ class DetectionGeneratorTest(
...
@@ -57,7 +57,7 @@ class DetectionGeneratorTest(
'pre_nms_score_threshold'
:
pre_nms_score_threshold
,
'pre_nms_score_threshold'
:
pre_nms_score_threshold
,
'nms_iou_threshold'
:
0.5
,
'nms_iou_threshold'
:
0.5
,
'max_num_detections'
:
max_num_detections
,
'max_num_detections'
:
max_num_detections
,
'
use_batched_nms'
:
use_batched_nms
,
'
nms_version'
:
nms_version
,
'use_cpu_nms'
:
use_cpu_nms
,
'use_cpu_nms'
:
use_cpu_nms
,
}
}
generator
=
detection_generator
.
DetectionGenerator
(
**
kwargs
)
generator
=
detection_generator
.
DetectionGenerator
(
**
kwargs
)
...
@@ -97,7 +97,7 @@ class DetectionGeneratorTest(
...
@@ -97,7 +97,7 @@ class DetectionGeneratorTest(
'pre_nms_score_threshold'
:
0.1
,
'pre_nms_score_threshold'
:
0.1
,
'nms_iou_threshold'
:
0.5
,
'nms_iou_threshold'
:
0.5
,
'max_num_detections'
:
10
,
'max_num_detections'
:
10
,
'
use_batched_nms'
:
False
,
'
nms_version'
:
'v2'
,
'use_cpu_nms'
:
False
,
'use_cpu_nms'
:
False
,
}
}
generator
=
detection_generator
.
DetectionGenerator
(
**
kwargs
)
generator
=
detection_generator
.
DetectionGenerator
(
**
kwargs
)
...
@@ -116,15 +116,14 @@ class MultilevelDetectionGeneratorTest(
...
@@ -116,15 +116,14 @@ class MultilevelDetectionGeneratorTest(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
@
parameterized
.
parameters
(
(
True
,
False
,
True
),
(
'batched'
,
False
,
True
),
(
True
,
False
,
False
),
(
'batched'
,
False
,
False
),
(
False
,
False
,
True
),
(
'v2'
,
False
,
True
),
(
False
,
False
,
False
),
(
'v2'
,
False
,
False
),
(
False
,
True
,
True
),
(
'v1'
,
True
,
True
),
(
False
,
True
,
False
),
(
'v1'
,
True
,
False
),
)
)
def
testDetectionsOutputShape
(
self
,
use_batched_nms
,
has_att_heads
,
def
testDetectionsOutputShape
(
self
,
nms_version
,
has_att_heads
,
use_cpu_nms
):
use_cpu_nms
):
min_level
=
4
min_level
=
4
max_level
=
6
max_level
=
6
num_scales
=
2
num_scales
=
2
...
@@ -142,7 +141,7 @@ class MultilevelDetectionGeneratorTest(
...
@@ -142,7 +141,7 @@ class MultilevelDetectionGeneratorTest(
'pre_nms_score_threshold'
:
pre_nms_score_threshold
,
'pre_nms_score_threshold'
:
pre_nms_score_threshold
,
'nms_iou_threshold'
:
0.5
,
'nms_iou_threshold'
:
0.5
,
'max_num_detections'
:
max_num_detections
,
'max_num_detections'
:
max_num_detections
,
'
use_batched_nms'
:
use_batched_nms
,
'
nms_version'
:
nms_version
,
'use_cpu_nms'
:
use_cpu_nms
,
'use_cpu_nms'
:
use_cpu_nms
,
}
}
...
@@ -223,7 +222,7 @@ class MultilevelDetectionGeneratorTest(
...
@@ -223,7 +222,7 @@ class MultilevelDetectionGeneratorTest(
'pre_nms_score_threshold'
:
0.1
,
'pre_nms_score_threshold'
:
0.1
,
'nms_iou_threshold'
:
0.5
,
'nms_iou_threshold'
:
0.5
,
'max_num_detections'
:
10
,
'max_num_detections'
:
10
,
'
use_batched_nms'
:
False
,
'
nms_version'
:
'v2'
,
'use_cpu_nms'
:
False
,
'use_cpu_nms'
:
False
,
}
}
generator
=
detection_generator
.
MultilevelDetectionGenerator
(
**
kwargs
)
generator
=
detection_generator
.
MultilevelDetectionGenerator
(
**
kwargs
)
...
...
official/vision/beta/modeling/retinanet_model_test.py
View file @
dbe39272
...
@@ -193,7 +193,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -193,7 +193,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
attribute_heads
=
attribute_heads
,
attribute_heads
=
attribute_heads
,
num_anchors_per_location
=
num_anchors_per_location
)
num_anchors_per_location
=
num_anchors_per_location
)
generator
=
detection_generator
.
MultilevelDetectionGenerator
(
generator
=
detection_generator
.
MultilevelDetectionGenerator
(
max_num_detections
=
10
)
max_num_detections
=
10
,
nms_version
=
'v1'
)
model
=
retinanet_model
.
RetinaNetModel
(
model
=
retinanet_model
.
RetinaNetModel
(
backbone
=
backbone
,
backbone
=
backbone
,
decoder
=
decoder
,
decoder
=
decoder
,
...
...
official/vision/beta/projects/deepmac_maskrcnn/serving/detection.py
View file @
dbe39272
...
@@ -28,7 +28,7 @@ class DetectionModule(detection.DetectionModule):
...
@@ -28,7 +28,7 @@ class DetectionModule(detection.DetectionModule):
if
self
.
_batch_size
is
None
:
if
self
.
_batch_size
is
None
:
ValueError
(
"batch_size can't be None for detection models"
)
ValueError
(
"batch_size can't be None for detection models"
)
if
not
self
.
params
.
task
.
model
.
detection_generator
.
use_
batched
_nms
:
if
self
.
params
.
task
.
model
.
detection_generator
.
nms_version
!=
'
batched
'
:
ValueError
(
'Only batched_nms is supported.'
)
ValueError
(
'Only batched_nms is supported.'
)
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
self
.
_batch_size
]
+
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
self
.
_batch_size
]
+
self
.
_input_image_size
+
[
3
])
self
.
_input_image_size
+
[
3
])
...
...
official/vision/beta/projects/deepmac_maskrcnn/tasks/deep_mask_head_rcnn.py
View file @
dbe39272
...
@@ -120,7 +120,7 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
...
@@ -120,7 +120,7 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
pre_nms_score_threshold
=
generator_config
.
pre_nms_score_threshold
,
pre_nms_score_threshold
=
generator_config
.
pre_nms_score_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
max_num_detections
=
generator_config
.
max_num_detections
,
max_num_detections
=
generator_config
.
max_num_detections
,
use_batched_nms
=
generator_config
.
use_batched_nms
)
nms_version
=
generator_config
.
nms_version
)
if
model_config
.
include_mask
:
if
model_config
.
include_mask
:
mask_head
=
deep_instance_heads
.
DeepMaskHead
(
mask_head
=
deep_instance_heads
.
DeepMaskHead
(
...
...
official/vision/beta/serving/detection_test.py
View file @
dbe39272
...
@@ -33,7 +33,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -33,7 +33,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def
_get_detection_module
(
self
,
experiment_name
):
def
_get_detection_module
(
self
,
experiment_name
):
params
=
exp_factory
.
get_exp_config
(
experiment_name
)
params
=
exp_factory
.
get_exp_config
(
experiment_name
)
params
.
task
.
model
.
backbone
.
resnet
.
model_id
=
18
params
.
task
.
model
.
backbone
.
resnet
.
model_id
=
18
params
.
task
.
model
.
detection_generator
.
use_batched_nms
=
True
params
.
task
.
model
.
detection_generator
.
nms_version
=
'batched'
detection_module
=
detection
.
DetectionModule
(
detection_module
=
detection
.
DetectionModule
(
params
,
batch_size
=
1
,
input_image_size
=
[
640
,
640
])
params
,
batch_size
=
1
,
input_image_size
=
[
640
,
640
])
return
detection_module
return
detection_module
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment