Unverified Commit ca552843 authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'panoptic-segmentation' into panoptic-segmentation

parents 7e2f7a35 6b90e134
@@ -42,6 +42,7 @@ class FPN(tf.keras.Model):
min_level: int = 3,
max_level: int = 7,
num_filters: int = 256,
fusion_type: str = 'sum',
use_separable_conv: bool = False,
activation: str = 'relu',
use_sync_bn: bool = False,
@@ -59,6 +60,8 @@ class FPN(tf.keras.Model):
min_level: An `int` of minimum level in FPN output feature maps.
max_level: An `int` of maximum level in FPN output feature maps.
num_filters: An `int` number of filters in FPN layers.
fusion_type: A `str` of `sum` or `concat`. Whether performing sum or
concat for feature fusion.
use_separable_conv: A `bool`. If True use separable convolution for
convolution in FPN layers.
activation: A `str` name of the activation function.
@@ -77,6 +80,7 @@ class FPN(tf.keras.Model):
'min_level': min_level,
'max_level': max_level,
'num_filters': num_filters,
'fusion_type': fusion_type,
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
@@ -122,8 +126,16 @@ class FPN(tf.keras.Model):
# Build top-down path.
feats = {str(backbone_max_level): feats_lateral[str(backbone_max_level)]}
for level in range(backbone_max_level - 1, min_level - 1, -1):
- feats[str(level)] = spatial_transform_ops.nearest_upsampling(
- feats[str(level + 1)], 2) + feats_lateral[str(level)]
feat_a = spatial_transform_ops.nearest_upsampling(
feats[str(level + 1)], 2)
feat_b = feats_lateral[str(level)]
if fusion_type == 'sum':
feats[str(level)] = feat_a + feat_b
elif fusion_type == 'concat':
feats[str(level)] = tf.concat([feat_a, feat_b], axis=-1)
else:
raise ValueError('Fusion type {} not supported.'.format(fusion_type))
# TODO(xianzhi): consider to remove bias in conv2d.
# Build post-hoc 3x3 convolution kernel.
@@ -224,6 +236,7 @@ def build_fpn_decoder(
min_level=model_config.min_level,
max_level=model_config.max_level,
num_filters=decoder_cfg.num_filters,
fusion_type=decoder_cfg.fusion_type,
use_separable_conv=decoder_cfg.use_separable_conv,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
...
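A minimal sketch, not part of the commit, of what the new `fusion_type` option changes during the top-down merge: `sum` keeps the lateral channel count, while `concat` doubles it before the post-hoc 3x3 convolution. Shapes below are illustrative only.

```python
import tensorflow as tf

# Upsampled feature from level + 1 and the lateral feature at the current
# level, both with 256 channels (illustrative shapes).
feat_a = tf.random.uniform([1, 32, 32, 256])
feat_b = tf.random.uniform([1, 32, 32, 256])

sum_fused = feat_a + feat_b                           # [1, 32, 32, 256]
concat_fused = tf.concat([feat_a, feat_b], axis=-1)   # [1, 32, 32, 512]
print(sum_fused.shape, concat_fused.shape)
```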
@@ -27,11 +27,11 @@ from official.vision.beta.modeling.decoders import fpn
class FPNTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
- (256, 3, 7, False),
- (256, 3, 7, True),
(256, 3, 7, False, 'sum'),
(256, 3, 7, True, 'concat'),
)
def test_network_creation(self, input_size, min_level, max_level,
- use_separable_conv):
use_separable_conv, fusion_type):
"""Test creation of FPN."""
tf.keras.backend.set_image_data_format('channels_last')
@@ -42,6 +42,7 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
input_specs=backbone.output_specs,
min_level=min_level,
max_level=max_level,
fusion_type=fusion_type,
use_separable_conv=use_separable_conv)
endpoints = backbone(inputs)
@@ -87,6 +88,7 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
min_level=3,
max_level=7,
num_filters=256,
fusion_type='sum',
use_separable_conv=False,
use_sync_bn=False,
activation='relu',
...
@@ -76,7 +76,7 @@ def build_maskrcnn(
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
- backbone(tf.keras.Input(input_specs.shape[1:]))
backbone_features = backbone(tf.keras.Input(input_specs.shape[1:]))
decoder = decoders.factory.build_decoder(
input_specs=backbone.output_specs,
@@ -119,6 +119,13 @@ def build_maskrcnn(
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer,
name='detection_head')
# Build backbone, decoder and region proposal network:
if decoder:
decoder_features = decoder(backbone_features)
rpn_head(decoder_features)
if roi_sampler_config.cascade_iou_thresholds:
detection_head_cascade = [detection_head]
for cascade_num in range(len(roi_sampler_config.cascade_iou_thresholds)):
@@ -189,7 +196,8 @@ def build_maskrcnn(
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold,
max_num_detections=generator_config.max_num_detections,
- use_batched_nms=generator_config.use_batched_nms)
use_batched_nms=generator_config.use_batched_nms,
use_cpu_nms=generator_config.use_cpu_nms)
if model_config.include_mask:
mask_head = instance_heads.MaskHead(
@@ -286,7 +294,8 @@ def build_retinanet(
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold,
max_num_detections=generator_config.max_num_detections,
- use_batched_nms=generator_config.use_batched_nms)
use_batched_nms=generator_config.use_batched_nms,
use_cpu_nms=generator_config.use_cpu_nms)
model = retinanet_model.RetinaNetModel(
backbone,
@@ -326,6 +335,7 @@ def build_segmentation_model(
num_convs=head_config.num_convs,
prediction_kernel_size=head_config.prediction_kernel_size,
num_filters=head_config.num_filters,
use_depthwise_convolution=head_config.use_depthwise_convolution,
upsample_factor=head_config.upsample_factor,
feature_fusion=head_config.feature_fusion,
low_level=head_config.low_level,
...
@@ -31,6 +31,7 @@ class SegmentationHead(tf.keras.layers.Layer):
level: Union[int, str],
num_convs: int = 2,
num_filters: int = 256,
use_depthwise_convolution: bool = False,
prediction_kernel_size: int = 1,
upsample_factor: int = 1,
feature_fusion: Optional[str] = None,
@@ -53,6 +54,8 @@ class SegmentationHead(tf.keras.layers.Layer):
prediction layer.
num_filters: An `int` number to specify the number of filters used.
Default is 256.
use_depthwise_convolution: A bool to specify if use depthwise separable
convolutions.
prediction_kernel_size: An `int` number to specify the kernel size of the
prediction layer.
upsample_factor: An `int` number to specify the upsampling factor to
@@ -84,6 +87,7 @@ class SegmentationHead(tf.keras.layers.Layer):
'level': level,
'num_convs': num_convs,
'num_filters': num_filters,
'use_depthwise_convolution': use_depthwise_convolution,
'prediction_kernel_size': prediction_kernel_size,
'upsample_factor': upsample_factor,
'feature_fusion': feature_fusion,
@@ -104,12 +108,14 @@ class SegmentationHead(tf.keras.layers.Layer):
def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the segmentation head."""
use_depthwise_convolution = self._config_dict['use_depthwise_convolution']
random_initializer = tf.keras.initializers.RandomNormal(stddev=0.01)
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
- 'kernel_size': 3,
'kernel_size': 3 if not use_depthwise_convolution else 1,
'padding': 'same',
'use_bias': False,
- 'kernel_initializer': tf.keras.initializers.RandomNormal(stddev=0.01),
'kernel_initializer': random_initializer,
'kernel_regularizer': self._config_dict['kernel_regularizer'],
}
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
@@ -139,6 +145,16 @@ class SegmentationHead(tf.keras.layers.Layer):
self._convs = []
self._norms = []
for i in range(self._config_dict['num_convs']):
if use_depthwise_convolution:
self._convs.append(
tf.keras.layers.DepthwiseConv2D(
name='segmentation_head_depthwise_conv_{}'.format(i),
kernel_size=3,
padding='same',
use_bias=False,
depthwise_initializer=random_initializer,
depthwise_regularizer=self._config_dict['kernel_regularizer'],
depth_multiplier=1))
conv_name = 'segmentation_head_conv_{}'.format(i)
self._convs.append(
conv_op(
...
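A rough sketch, under assumed filter counts, of the convolution stack the head builds when `use_depthwise_convolution` is enabled: a 3x3 `DepthwiseConv2D` does the spatial mixing and the regular convolution shrinks to a 1x1 pointwise layer. The real head also interleaves normalization and activation, which are omitted here.

```python
import tensorflow as tf

num_convs, num_filters = 2, 256
initializer = tf.keras.initializers.RandomNormal(stddev=0.01)

layers = []
for i in range(num_convs):
  # Spatial mixing via a 3x3 depthwise convolution.
  layers.append(tf.keras.layers.DepthwiseConv2D(
      kernel_size=3, padding='same', use_bias=False,
      depthwise_initializer=initializer, depth_multiplier=1))
  # Channel mixing via the (now 1x1) pointwise convolution.
  layers.append(tf.keras.layers.Conv2D(
      filters=num_filters, kernel_size=1, padding='same',
      use_bias=False, kernel_initializer=initializer))

x = tf.random.uniform([1, 64, 64, 256])
for layer in layers:
  x = layer(x)
print(x.shape)  # (1, 64, 64, 256)
```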
@@ -13,6 +13,7 @@
# limitations under the License.
"""Contains definitions of generators to generate the final detections."""
import contextlib
from typing import List, Optional, Mapping
# Import libraries
import tensorflow as tf
@@ -404,6 +405,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
nms_iou_threshold: float = 0.5,
max_num_detections: int = 100,
use_batched_nms: bool = False,
use_cpu_nms: bool = False,
**kwargs):
"""Initializes a detection generator.
@@ -420,6 +422,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
generate.
use_batched_nms: A `bool` of whether or not use
`tf.image.combined_non_max_suppression`.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
**kwargs: Additional keyword arguments passed to Layer.
"""
self._config_dict = {
@@ -429,6 +432,7 @@ class DetectionGenerator(tf.keras.layers.Layer):
'nms_iou_threshold': nms_iou_threshold,
'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms,
'use_cpu_nms': use_cpu_nms,
}
super(DetectionGenerator, self).__init__(**kwargs)
@@ -513,23 +517,30 @@ class DetectionGenerator(tf.keras.layers.Layer):
'decoded_box_scores': box_scores,
}
- if self._config_dict['use_batched_nms']:
- nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
- _generate_detections_batched(
- decoded_boxes,
- box_scores,
- self._config_dict['pre_nms_score_threshold'],
- self._config_dict['nms_iou_threshold'],
- self._config_dict['max_num_detections']))
- else:
- nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
- _generate_detections_v2(
- decoded_boxes,
- box_scores,
- self._config_dict['pre_nms_top_k'],
- self._config_dict['pre_nms_score_threshold'],
- self._config_dict['nms_iou_threshold'],
- self._config_dict['max_num_detections']))
# Optionally force the NMS be run on CPU.
if self._config_dict['use_cpu_nms']:
nms_context = tf.device('cpu:0')
else:
nms_context = contextlib.nullcontext()
with nms_context:
if self._config_dict['use_batched_nms']:
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_batched(
decoded_boxes, box_scores,
self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
else:
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, _) = (
_generate_detections_v1(
decoded_boxes,
box_scores,
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
# Adds 1 to offset the background class which has index 0.
nmsed_classes += 1
@@ -560,6 +571,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
nms_iou_threshold: float = 0.5,
max_num_detections: int = 100,
use_batched_nms: bool = False,
use_cpu_nms: bool = False,
**kwargs):
"""Initializes a multi-level detection generator.
@@ -576,6 +588,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
generate.
use_batched_nms: A `bool` of whether or not use
`tf.image.combined_non_max_suppression`.
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
**kwargs: Additional keyword arguments passed to Layer.
"""
self._config_dict = {
@@ -585,6 +598,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
'nms_iou_threshold': nms_iou_threshold,
'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms,
'use_cpu_nms': use_cpu_nms,
}
super(MultilevelDetectionGenerator, self).__init__(**kwargs)
@@ -710,39 +724,38 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
'decoded_box_attributes': attributes,
}
- if self._config_dict['use_batched_nms']:
- if raw_attributes:
- raise ValueError('Attribute learning is not supported for batched NMS.')
- nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
- _generate_detections_batched(
- boxes,
- scores,
- self._config_dict['pre_nms_score_threshold'],
- self._config_dict['nms_iou_threshold'],
- self._config_dict['max_num_detections']))
- # Set `nmsed_attributes` to None for batched NMS.
- nmsed_attributes = {}
- else:
- if raw_attributes:
- nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes = (
- _generate_detections_v1(
- boxes,
- scores,
- attributes=attributes if raw_attributes else None,
- pre_nms_top_k=self._config_dict['pre_nms_top_k'],
- pre_nms_score_threshold=self
- ._config_dict['pre_nms_score_threshold'],
- nms_iou_threshold=self._config_dict['nms_iou_threshold'],
- max_num_detections=self._config_dict['max_num_detections']))
- else:
- nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
- _generate_detections_v2(
- boxes, scores, self._config_dict['pre_nms_top_k'],
- self._config_dict['pre_nms_score_threshold'],
- self._config_dict['nms_iou_threshold'],
- self._config_dict['max_num_detections']))
# Optionally force the NMS to run on CPU.
if self._config_dict['use_cpu_nms']:
nms_context = tf.device('cpu:0')
else:
nms_context = contextlib.nullcontext()
with nms_context:
if self._config_dict['use_batched_nms']:
if raw_attributes:
raise ValueError(
'Attribute learning is not supported for batched NMS.')
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
_generate_detections_batched(
boxes, scores, self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
# Set `nmsed_attributes` to None for batched NMS.
nmsed_attributes = {}
else:
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections,
nmsed_attributes) = (
_generate_detections_v1(
boxes,
scores,
attributes=attributes if raw_attributes else None,
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
# Adds 1 to offset the background class which has index 0.
nmsed_classes += 1
...
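The `use_cpu_nms` flag above only changes op placement. Below is a small, self-contained illustration of the same pattern, using `tf.image.non_max_suppression` as a stand-in for the generator's internal NMS helpers.

```python
import contextlib
import tensorflow as tf

def run_nms(boxes, scores, use_cpu_nms=False):
  # Pin the op to the CPU when requested; otherwise use a no-op context so
  # the surrounding code path stays identical.
  nms_context = tf.device('cpu:0') if use_cpu_nms else contextlib.nullcontext()
  with nms_context:
    return tf.image.non_max_suppression(
        boxes, scores, max_output_size=10, iou_threshold=0.5)

boxes = tf.random.uniform([100, 4])
scores = tf.random.uniform([100])
print(run_nms(boxes, scores, use_cpu_nms=True))  # indices of kept boxes
```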
@@ -43,11 +43,9 @@ class SelectTopKScoresTest(tf.test.TestCase):
class DetectionGeneratorTest(
parameterized.TestCase, tf.test.TestCase):
- @parameterized.parameters(
- (True),
- (False),
- )
- def testDetectionsOutputShape(self, use_batched_nms):
@parameterized.product(
use_batched_nms=[True, False], use_cpu_nms=[True, False])
def testDetectionsOutputShape(self, use_batched_nms, use_cpu_nms):
max_num_detections = 100
num_classes = 4
pre_nms_top_k = 5000
@@ -60,6 +58,7 @@ class DetectionGeneratorTest(
'nms_iou_threshold': 0.5,
'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms,
'use_cpu_nms': use_cpu_nms,
}
generator = detection_generator.DetectionGenerator(**kwargs)
@@ -99,6 +98,7 @@ class DetectionGeneratorTest(
'nms_iou_threshold': 0.5,
'max_num_detections': 10,
'use_batched_nms': False,
'use_cpu_nms': False,
}
generator = detection_generator.DetectionGenerator(**kwargs)
@@ -116,16 +116,20 @@ class MultilevelDetectionGeneratorTest(
parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
- (True, False),
- (False, False),
- (False, True),
(True, False, True),
(True, False, False),
(False, False, True),
(False, False, False),
(False, True, True),
(False, True, False),
)
- def testDetectionsOutputShape(self, use_batched_nms, has_att_heads):
def testDetectionsOutputShape(self, use_batched_nms, has_att_heads,
use_cpu_nms):
min_level = 4
max_level = 6
num_scales = 2
max_num_detections = 100
- aspect_ratios = [1.0, 2.0,]
aspect_ratios = [1.0, 2.0]
anchor_scale = 2.0
output_size = [64, 64]
num_classes = 4
@@ -139,6 +143,7 @@ class MultilevelDetectionGeneratorTest(
'nms_iou_threshold': 0.5,
'max_num_detections': max_num_detections,
'use_batched_nms': use_batched_nms,
'use_cpu_nms': use_cpu_nms,
}
input_anchor = anchor.build_anchor_generator(min_level, max_level,
@@ -219,6 +224,7 @@ class MultilevelDetectionGeneratorTest(
'nms_iou_threshold': 0.5,
'max_num_detections': 10,
'use_batched_nms': False,
'use_cpu_nms': False,
}
generator = detection_generator.MultilevelDetectionGenerator(**kwargs)
...
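The switch from `@parameterized.parameters` to `@parameterized.product` above generates the full cross product of the flag values. A minimal sketch of that decorator, assuming only `absl.testing.parameterized`:

```python
from absl.testing import parameterized
import tensorflow as tf

class ProductExample(parameterized.TestCase, tf.test.TestCase):

  # Expands into four test cases: (True, True), (True, False),
  # (False, True), (False, False).
  @parameterized.product(use_batched_nms=[True, False],
                         use_cpu_nms=[True, False])
  def test_flags(self, use_batched_nms, use_cpu_nms):
    self.assertIn(use_batched_nms, (True, False))
    self.assertIn(use_cpu_nms, (True, False))

if __name__ == '__main__':
  tf.test.main()
```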
@@ -72,6 +72,7 @@ class ResidualBlock(tf.keras.layers.Layer):
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
bn_trainable=True,
**kwargs):
"""Initializes a residual block with BN after convolutions.
@@ -99,6 +100,8 @@ class ResidualBlock(tf.keras.layers.Layer):
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
bn_trainable: A `bool` that indicates whether batch norm layers should be
trainable. Default to True.
**kwargs: Additional keyword arguments to be passed.
"""
super(ResidualBlock, self).__init__(**kwargs)
@@ -126,6 +129,7 @@ class ResidualBlock(tf.keras.layers.Layer):
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
self._bn_trainable = bn_trainable
def build(self, input_shape):
if self._use_projection:
@@ -140,7 +144,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
- epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
@@ -154,7 +159,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
- epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
@@ -168,7 +174,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
- epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
@@ -203,7 +210,8 @@ class ResidualBlock(tf.keras.layers.Layer):
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
- 'norm_epsilon': self._norm_epsilon
'norm_epsilon': self._norm_epsilon,
'bn_trainable': self._bn_trainable
}
base_config = super(ResidualBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -249,6 +257,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
bn_trainable=True,
**kwargs):
"""Initializes a standard bottleneck block with BN after convolutions.
@@ -277,6 +286,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
bn_trainable: A `bool` that indicates whether batch norm layers should be
trainable. Default to True.
**kwargs: Additional keyword arguments to be passed.
"""
super(BottleneckBlock, self).__init__(**kwargs)
@@ -303,6 +314,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._bn_axis = -1
else:
self._bn_axis = 1
self._bn_trainable = bn_trainable
def build(self, input_shape):
if self._use_projection:
@@ -330,7 +342,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
- epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
@@ -343,7 +356,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
- epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation1 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
@@ -360,7 +374,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
- epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation2 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
@@ -375,7 +390,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm3 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
- epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation3 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
@@ -414,7 +430,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
- 'norm_epsilon': self._norm_epsilon
'norm_epsilon': self._norm_epsilon,
'bn_trainable': self._bn_trainable
}
base_config = super(BottleneckBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
...
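For context, a short sketch of what `bn_trainable=False` does to each block's normalization layers: a Keras batch norm created with `trainable=False` contributes no trainable variables and keeps using its moving statistics even when called with `training=True` (TF 2.x behavior), which is what makes the flag useful for freezing a pretrained backbone.

```python
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization(trainable=False)
x = tf.random.uniform([2, 8, 8, 16])
_ = bn(x, training=True)  # still uses moving mean/variance because trainable=False

print(len(bn.trainable_variables))      # 0: gamma/beta are frozen
print(len(bn.non_trainable_variables))  # 4: gamma, beta, moving mean/variance
```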
@@ -165,7 +165,8 @@ class SqueezeExcitation(tf.keras.layers.Layer):
def build(self, input_shape):
num_reduced_filters = make_divisible(
- self._in_filters * self._se_ratio, divisor=self._divisible_by)
max(1, int(self._in_filters * self._se_ratio)),
divisor=self._divisible_by)
self._se_reduce = tf.keras.layers.Conv2D(
filters=num_reduced_filters,
@@ -424,7 +425,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
self._rezero = Scale(initializer=initializer, name='rezero')
state_prefix = state_prefix if state_prefix is not None else ''
self._state_prefix = state_prefix
- self._frame_count_name = f'{state_prefix}/pos_enc_frame_count'
self._frame_count_name = f'{state_prefix}_pos_enc_frame_count'
def get_config(self):
"""Returns a dictionary containing the config used for initialization."""
@@ -522,7 +523,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
inputs: An input `tf.Tensor`.
states: A `dict` of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s). Expected keys
- include `state_prefix + '/pos_enc_frame_count'`.
include `state_prefix + '_pos_enc_frame_count'`.
output_states: A `bool`. If True, returns the output tensor and output
states. Returns just the output tensor otherwise.
@@ -586,8 +587,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
state_prefix = state_prefix if state_prefix is not None else ''
self._state_prefix = state_prefix
- self._state_name = f'{state_prefix}/pool_buffer'
- self._frame_count_name = f'{state_prefix}/pool_frame_count'
self._state_name = f'{state_prefix}_pool_buffer'
self._frame_count_name = f'{state_prefix}_pool_frame_count'
def get_config(self):
"""Returns a dictionary containing the config used for initialization."""
@@ -610,8 +611,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
inputs: An input `tf.Tensor`.
states: A `dict` of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s).
- Expected keys include `state_prefix + '/pool_buffer'` and
- `state_prefix + '/pool_frame_count'`.
Expected keys include `state_prefix + '_pool_buffer'` and
`state_prefix + '_pool_frame_count'`.
output_states: A `bool`. If True, returns the output tensor and output
states. Returns just the output tensor otherwise.
...
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
- """Mask R-CNN model."""
"""R-CNN(-RS) models."""
- from typing import Any, List, Mapping, Optional, Union
from typing import Any, List, Mapping, Optional, Tuple, Union
import tensorflow as tf
@@ -24,7 +24,7 @@ from official.vision.beta.ops import box_ops
@tf.keras.utils.register_keras_serializable(package='Vision')
class MaskRCNNModel(tf.keras.Model):
- """The Mask R-CNN model."""
"""The Mask R-CNN(-RS) and Cascade RCNN-RS models."""
def __init__(self,
backbone: tf.keras.Model,
@@ -48,7 +48,7 @@ class MaskRCNNModel(tf.keras.Model):
aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None,
**kwargs):
- """Initializes the Mask R-CNN model.
"""Initializes the R-CNN(-RS) model.
Args:
backbone: `tf.keras.Model`, the backbone network.
@@ -65,19 +65,18 @@ class MaskRCNNModel(tf.keras.Model):
mask_roi_aligner: the ROI alginer for mask prediction.
class_agnostic_bbox_pred: if True, perform class agnostic bounding box
prediction. Needs to be `True` for Cascade RCNN models.
cascade_class_ensemble: if True, ensemble classification scores over all
detection heads.
min_level: Minimum level in output feature maps.
max_level: Maximum level in output feature maps.
num_scales: A number representing intermediate scales added on each level.
For instances, num_scales=2 adds one additional intermediate anchor
scales [2^0, 2^0.5] on each level.
aspect_ratios: A list representing the aspect raito anchors added on each
level. The number indicates the ratio of width to height. For instances,
aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each scale level.
anchor_size: A number representing the scale of size of the base anchor to
the feature stride 2^level.
**kwargs: keyword arguments to be passed.
"""
super(MaskRCNNModel, self).__init__(**kwargs)
@@ -143,6 +142,34 @@ class MaskRCNNModel(tf.keras.Model):
gt_classes: Optional[tf.Tensor] = None,
gt_masks: Optional[tf.Tensor] = None,
training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
model_outputs, intermediate_outputs = self._call_box_outputs(
images=images, image_shape=image_shape, anchor_boxes=anchor_boxes,
gt_boxes=gt_boxes, gt_classes=gt_classes, training=training)
if not self._include_mask:
return model_outputs
model_mask_outputs = self._call_mask_outputs(
model_box_outputs=model_outputs,
features=intermediate_outputs['features'],
current_rois=intermediate_outputs['current_rois'],
matched_gt_indices=intermediate_outputs['matched_gt_indices'],
matched_gt_boxes=intermediate_outputs['matched_gt_boxes'],
matched_gt_classes=intermediate_outputs['matched_gt_classes'],
gt_masks=gt_masks,
training=training)
model_outputs.update(model_mask_outputs)
return model_outputs
def _call_box_outputs(
self, images: tf.Tensor,
image_shape: tf.Tensor,
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
gt_boxes: Optional[tf.Tensor] = None,
gt_classes: Optional[tf.Tensor] = None,
training: Optional[bool] = None) -> Tuple[
Mapping[str, tf.Tensor], Mapping[str, tf.Tensor]]:
"""Implementation of the Faster-RCNN logic for boxes."""
model_outputs = {}
# Feature extraction.
@@ -239,9 +266,28 @@ class MaskRCNNModel(tf.keras.Model):
'decoded_box_scores': detections['decoded_box_scores']
})
- if not self._include_mask:
- return model_outputs
intermediate_outputs = {
'matched_gt_boxes': matched_gt_boxes,
'matched_gt_indices': matched_gt_indices,
'matched_gt_classes': matched_gt_classes,
'features': features,
'current_rois': current_rois,
}
return (model_outputs, intermediate_outputs)
def _call_mask_outputs(
self,
model_box_outputs: Mapping[str, tf.Tensor],
features: tf.Tensor,
current_rois: tf.Tensor,
matched_gt_indices: tf.Tensor,
matched_gt_boxes: tf.Tensor,
matched_gt_classes: tf.Tensor,
gt_masks: tf.Tensor,
training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
"""Implementation of Mask-RCNN mask prediction logic."""
model_outputs = dict(model_box_outputs)
if training:
current_rois, roi_classes, roi_masks = self.mask_sampler(
current_rois, matched_gt_boxes, matched_gt_classes,
...
@@ -384,7 +384,7 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
ckpt.save(os.path.join(save_dir, 'ckpt'))
partial_ckpt = tf.train.Checkpoint(backbone=backbone)
- partial_ckpt.restore(tf.train.latest_checkpoint(
partial_ckpt.read(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
if include_mask:
...
@@ -624,6 +624,76 @@ def bbox_overlap(boxes, gt_boxes):
return iou
def bbox_generalized_overlap(boxes, gt_boxes):
"""Calculates the GIOU between proposal and ground truth boxes.
The generalized intersection of union is an adjustment of the traditional IOU
metric which provides continuous updates even for predictions with no overlap.
This metric is defined in https://giou.stanford.edu/GIoU.pdf. Note, some
`gt_boxes` may have been padded. The returned `giou` tensor for these boxes
will be -1.
Args:
boxes: a `Tensor` with a shape of [batch_size, N, 4]. N is the number of
proposals before groundtruth assignment (e.g., rpn_post_nms_topn). The
last dimension is the pixel coordinates in [ymin, xmin, ymax, xmax] form.
gt_boxes: a `Tensor` with a shape of [batch_size, max_num_instances, 4].
This tensor may have paddings with a negative value and will also be in
the [ymin, xmin, ymax, xmax] format.
Returns:
giou: a `Tensor` with as a shape of [batch_size, N, max_num_instances].
"""
with tf.name_scope('bbox_generalized_overlap'):
assert boxes.shape.as_list(
)[-1] == 4, 'Boxes must be defined by 4 coordinates.'
assert gt_boxes.shape.as_list(
)[-1] == 4, 'Groundtruth boxes must be defined by 4 coordinates.'
bb_y_min, bb_x_min, bb_y_max, bb_x_max = tf.split(
value=boxes, num_or_size_splits=4, axis=2)
gt_y_min, gt_x_min, gt_y_max, gt_x_max = tf.split(
value=gt_boxes, num_or_size_splits=4, axis=2)
# Calculates the hull area for each pair of boxes, with one from
# boxes and the other from gt_boxes.
# Outputs for coordinates are of shape [batch_size, N, max_num_instances]
h_xmin = tf.minimum(bb_x_min, tf.transpose(gt_x_min, [0, 2, 1]))
h_xmax = tf.maximum(bb_x_max, tf.transpose(gt_x_max, [0, 2, 1]))
h_ymin = tf.minimum(bb_y_min, tf.transpose(gt_y_min, [0, 2, 1]))
h_ymax = tf.maximum(bb_y_max, tf.transpose(gt_y_max, [0, 2, 1]))
h_area = tf.maximum((h_xmax - h_xmin), 0) * tf.maximum((h_ymax - h_ymin), 0)
# Add a small epsilon to avoid divide-by-zero.
h_area = h_area + 1e-8
# Calculates the intersection area.
i_xmin = tf.maximum(bb_x_min, tf.transpose(gt_x_min, [0, 2, 1]))
i_xmax = tf.minimum(bb_x_max, tf.transpose(gt_x_max, [0, 2, 1]))
i_ymin = tf.maximum(bb_y_min, tf.transpose(gt_y_min, [0, 2, 1]))
i_ymax = tf.minimum(bb_y_max, tf.transpose(gt_y_max, [0, 2, 1]))
i_area = tf.maximum((i_xmax - i_xmin), 0) * tf.maximum((i_ymax - i_ymin), 0)
# Calculates the union area.
bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min)
gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min)
# Adds a small epsilon to avoid divide-by-zero.
u_area = bb_area + tf.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8
# Calculates IoU.
iou = i_area / u_area
# Calculates GIoU.
giou = iou - (h_area - u_area) / h_area
# Fills -1 for GIoU entries between the padded ground truth boxes.
gt_invalid_mask = tf.less(
tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
padding_mask = tf.broadcast_to(
tf.transpose(gt_invalid_mask, [0, 2, 1]), tf.shape(giou))
giou = tf.where(padding_mask, -tf.ones_like(giou), giou)
return giou
def box_matching(boxes, gt_boxes, gt_classes):
"""Match boxes to groundtruth boxes.
...
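A worked example, not part of the commit, of the GIoU formula that `bbox_generalized_overlap` above implements for a single pair of `[ymin, xmin, ymax, xmax]` boxes; the batched op should agree up to its 1e-8 epsilons.

```python
# Boxes [0, 0, 2, 2] and [1, 1, 3, 3]:
i_area = 1.0 * 1.0            # intersection [1, 1, 2, 2]
u_area = 4.0 + 4.0 - i_area   # union = 7
h_area = 3.0 * 3.0            # smallest enclosing box [0, 0, 3, 3]

iou = i_area / u_area                     # ~0.143
giou = iou - (h_area - u_area) / h_area   # ~0.143 - 2/9 ~ -0.079
print(iou, giou)
```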
@@ -555,3 +555,183 @@ def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
lambda: masks)
return image, normalized_boxes, masks
def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
aspect_ratio_range,
min_overlap_params, max_retry):
"""Crops a random slice from the input image.
The function will correspondingly recompute the bounding boxes and filter out
outside boxes and their labels.
References:
[1] End-to-End Object Detection with Transformers
https://arxiv.org/abs/2005.12872
The preprocessing steps:
1. Sample a minimum IoU overlap.
2. For each trial, sample the new image width, height, and top-left corner.
3. Compute the IoUs of bounding boxes with the cropped image and retry if
the maximum IoU is below the sampled threshold.
4. Find boxes whose centers are in the cropped image.
5. Compute new bounding boxes in the cropped region and only select those
boxes' labels.
Args:
img: a 'Tensor' of shape [height, width, 3] representing the input image.
boxes: a 'Tensor' of shape [N, 4] representing the ground-truth bounding
boxes with (ymin, xmin, ymax, xmax).
labels: a 'Tensor' of shape [N,] representing the class labels of the boxes.
min_scale: a 'float' in [0.0, 1.0) indicating the lower bound of the random
scale variable.
aspect_ratio_range: a list of two 'float' that specifies the lower and upper
bound of the random aspect ratio.
min_overlap_params: a list of four 'float' representing the min value, max
value, step size, and offset for the minimum overlap sample.
max_retry: an 'int' representing the number of trials for cropping. If it is
exhausted, no cropping will be performed.
Returns:
img: a Tensor representing the random cropped image. Can be the
original image if max_retry is exhausted.
boxes: a Tensor representing the bounding boxes in the cropped image.
labels: a Tensor representing the new bounding boxes' labels.
"""
shape = tf.shape(img)
original_h = shape[0]
original_w = shape[1]
minval, maxval, step, offset = min_overlap_params
min_overlap = tf.math.floordiv(
tf.random.uniform([], minval=minval, maxval=maxval), step) * step - offset
min_overlap = tf.clip_by_value(min_overlap, 0.0, 1.1)
if min_overlap > 1.0:
return img, boxes, labels
aspect_ratio_low = aspect_ratio_range[0]
aspect_ratio_high = aspect_ratio_range[1]
for _ in tf.range(max_retry):
scale_h = tf.random.uniform([], min_scale, 1.0)
scale_w = tf.random.uniform([], min_scale, 1.0)
new_h = tf.cast(
scale_h * tf.cast(original_h, dtype=tf.float32), dtype=tf.int32)
new_w = tf.cast(
scale_w * tf.cast(original_w, dtype=tf.float32), dtype=tf.int32)
# Aspect ratio has to be in the prespecified range
aspect_ratio = new_h / new_w
if aspect_ratio_low > aspect_ratio or aspect_ratio > aspect_ratio_high:
continue
left = tf.random.uniform([], 0, original_w - new_w, dtype=tf.int32)
right = left + new_w
top = tf.random.uniform([], 0, original_h - new_h, dtype=tf.int32)
bottom = top + new_h
normalized_left = tf.cast(
left, dtype=tf.float32) / tf.cast(
original_w, dtype=tf.float32)
normalized_right = tf.cast(
right, dtype=tf.float32) / tf.cast(
original_w, dtype=tf.float32)
normalized_top = tf.cast(
top, dtype=tf.float32) / tf.cast(
original_h, dtype=tf.float32)
normalized_bottom = tf.cast(
bottom, dtype=tf.float32) / tf.cast(
original_h, dtype=tf.float32)
cropped_box = tf.expand_dims(
tf.stack([
normalized_top,
normalized_left,
normalized_bottom,
normalized_right,
]),
axis=0)
iou = box_ops.bbox_overlap(
tf.expand_dims(cropped_box, axis=0),
tf.expand_dims(boxes, axis=0)) # (1, 1, n_ground_truth)
iou = tf.squeeze(iou, axis=[0, 1])
# If not a single bounding box has a Jaccard overlap of greater than
# the minimum, try again
if tf.reduce_max(iou) < min_overlap:
continue
centroids = box_ops.yxyx_to_cycxhw(boxes)
mask = tf.math.logical_and(
tf.math.logical_and(centroids[:, 0] > normalized_top,
centroids[:, 0] < normalized_bottom),
tf.math.logical_and(centroids[:, 1] > normalized_left,
centroids[:, 1] < normalized_right))
# If not a single bounding box has its center in the crop, try again.
if tf.reduce_sum(tf.cast(mask, dtype=tf.int32)) > 0:
indices = tf.squeeze(tf.where(mask), axis=1)
filtered_boxes = tf.gather(boxes, indices)
boxes = tf.clip_by_value(
(filtered_boxes[..., :] * tf.cast(
tf.stack([original_h, original_w, original_h, original_w]),
dtype=tf.float32) -
tf.cast(tf.stack([top, left, top, left]), dtype=tf.float32)) /
tf.cast(tf.stack([new_h, new_w, new_h, new_w]), dtype=tf.float32),
0.0, 1.0)
img = tf.image.crop_to_bounding_box(img, top, left, bottom - top,
right - left)
labels = tf.gather(labels, indices)
break
return img, boxes, labels
def random_crop(image,
boxes,
labels,
min_scale=0.3,
aspect_ratio_range=(0.5, 2.0),
min_overlap_params=(0.0, 1.4, 0.2, 0.1),
max_retry=50,
seed=None):
"""Randomly crop the image and boxes, filtering labels.
Args:
image: a 'Tensor' of shape [height, width, 3] representing the input image.
boxes: a 'Tensor' of shape [N, 4] representing the ground-truth bounding
boxes with (ymin, xmin, ymax, xmax).
labels: a 'Tensor' of shape [N,] representing the class labels of the boxes.
min_scale: a 'float' in [0.0, 1.0) indicating the lower bound of the random
scale variable.
aspect_ratio_range: a list of two 'float' that specifies the lower and upper
bound of the random aspect ratio.
min_overlap_params: a list of four 'float' representing the min value, max
value, step size, and offset for the minimum overlap sample.
max_retry: an 'int' representing the number of trials for cropping. If it is
exhausted, no cropping will be performed.
seed: the random number seed of int, but could be None.
Returns:
image: a Tensor representing the random cropped image. Can be the
original image if max_retry is exhausted.
boxes: a Tensor representing the bounding boxes in the cropped image.
labels: a Tensor representing the new bounding boxes' labels.
"""
with tf.name_scope('random_crop'):
do_crop = tf.greater(tf.random.uniform([], seed=seed), 0.5)
if do_crop:
return random_crop_image_with_boxes_and_labels(image, boxes, labels,
min_scale,
aspect_ratio_range,
min_overlap_params,
max_retry)
else:
return image, boxes, labels
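A hypothetical usage sketch of the new augmentation above; the `preprocess_ops` module path is assumed from the tests below, and boxes are assumed to be normalized `[ymin, xmin, ymax, xmax]` coordinates, as the crop math suggests.

```python
import tensorflow as tf
from official.vision.beta.ops import preprocess_ops  # assumed module path

image = tf.random.uniform([640, 640, 3])
boxes = tf.constant([[0.1, 0.1, 0.5, 0.5],
                     [0.2, 0.3, 0.9, 0.8]], tf.float32)
labels = tf.constant([1, 3], tf.int64)

# With probability ~0.5 the image is cropped; boxes are remapped to the crop
# and labels whose box centers fall outside the crop are dropped.
image, boxes, labels = preprocess_ops.random_crop(image, boxes, labels)
```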
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for preprocess_ops.py."""
import io
@@ -42,7 +41,7 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
([12, 2], 10),
([13, 2, 3], 10),
)
- def testPadToFixedSize(self, input_shape, output_size):
def test_pad_to_fixed_size(self, input_shape, output_size):
# Copies input shape to padding shape.
clip_shape = input_shape[:]
clip_shape[0] = min(output_size, clip_shape[0])
@@ -63,16 +62,11 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
(100, 256, 128, 256, 32, 1.0, 1.0, 128, 256),
(200, 512, 200, 128, 32, 0.25, 0.25, 224, 128),
)
- def testResizeAndCropImageRectangluarCase(self, input_height, input_width, desired_height, desired_width, stride, scale_y, scale_x, output_height, output_width):
def test_resize_and_crop_image_rectangluar_case(self, input_height, input_width, desired_height, desired_width, stride, scale_y, scale_x, output_height, output_width):
image = tf.convert_to_tensor(
np.random.rand(input_height, input_width, 3))
@@ -98,16 +92,10 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
(100, 200, 220, 220, 32, 1.1, 1.1, 224, 224),
(512, 512, 1024, 1024, 32, 2.0, 2.0, 1024, 1024),
)
- def testResizeAndCropImageSquareCase(self, input_height, input_width, desired_height, desired_width, stride, scale_y, scale_x, output_height, output_width):
def test_resize_and_crop_image_square_case(self, input_height, input_width, desired_height, desired_width, stride, scale_y, scale_x, output_height, output_width):
image = tf.convert_to_tensor(
np.random.rand(input_height, input_width, 3))
@@ -135,18 +123,10 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
(100, 200, 80, 100, 32, 0.5, 0.5, 50, 100, 96, 128),
(200, 100, 80, 100, 32, 0.5, 0.5, 100, 50, 128, 96),
)
- def testResizeAndCropImageV2(self, input_height, input_width, short_side, long_side, stride, scale_y, scale_x, desired_height, desired_width, output_height, output_width):
def test_resize_and_crop_image_v2(self, input_height, input_width, short_side, long_side, stride, scale_y, scale_x, desired_height, desired_width, output_height, output_width):
image = tf.convert_to_tensor(
np.random.rand(input_height, input_width, 3))
image_shape = tf.shape(image)[0:2]
@@ -176,9 +156,7 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(400, 600), (600, 400),
)
- def testCenterCropImage(self, input_height, input_width):
def test_center_crop_image(self, input_height, input_width):
image = tf.convert_to_tensor(
np.random.rand(input_height, input_width, 3))
cropped_image = preprocess_ops.center_crop_image(image)
@@ -188,9 +166,7 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(400, 600), (600, 400),
)
- def testCenterCropImageV2(self, input_height, input_width):
def test_center_crop_image_v2(self, input_height, input_width):
image_bytes = tf.constant(
_encode_image(
np.uint8(np.random.rand(input_height, input_width, 3) * 255),
@@ -204,9 +180,7 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(400, 600), (600, 400),
)
- def testRandomCropImage(self, input_height, input_width):
def test_random_crop_image(self, input_height, input_width):
image = tf.convert_to_tensor(
np.random.rand(input_height, input_width, 3))
_ = preprocess_ops.random_crop_image(image)
@@ -214,9 +188,7 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(400, 600), (600, 400),
)
- def testRandomCropImageV2(self, input_height, input_width):
def test_random_crop_image_v2(self, input_height, input_width):
image_bytes = tf.constant(
_encode_image(
np.uint8(np.random.rand(input_height, input_width, 3) * 255),
@@ -225,6 +197,21 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
_ = preprocess_ops.random_crop_image_v2(
image_bytes, tf.constant([input_height, input_width, 3], tf.int32))
@parameterized.parameters((640, 640, 20), (1280, 1280, 30))
def test_random_crop(self, input_height, input_width, num_boxes):
image = tf.convert_to_tensor(np.random.rand(input_height, input_width, 3))
boxes_height = np.random.randint(0, input_height, size=(num_boxes, 1))
top = np.random.randint(0, high=(input_height - boxes_height))
down = top + boxes_height
boxes_width = np.random.randint(0, input_width, size=(num_boxes, 1))
left = np.random.randint(0, high=(input_width - boxes_width))
right = left + boxes_width
boxes = tf.constant(
np.concatenate([top, left, down, right], axis=-1), tf.float32)
labels = tf.constant(
np.random.randint(low=0, high=num_boxes, size=(num_boxes,)), tf.int64)
_ = preprocess_ops.random_crop(image, boxes, labels)
if __name__ == '__main__':
  tf.test.main()
...@@ -82,6 +82,8 @@ $ python3 -m official.vision.beta.projects.deepmac_maskrcnn.train \
```
`CONFIG_FILE` can be any file in the `configs/experiments` directory.
When using SpineNet models, please specify
`--experiment=deep_mask_head_rcnn_spinenet_coco`
**Note:** The default eval batch size of 32 discards some samples during
validation. For accurate validation statistics, launch a dedicated eval job on
...@@ -93,11 +95,12 @@ In the following table, we report the Mask mAP of our models on the non-VOC
classes when only training with masks for the VOC classes. Performance is
measured on the `coco-val2017` set.
Backbone     | Mask head    | Config name                                      | Mask mAP
:----------- | :----------- | :------------------------------------------------| -------:
ResNet-50    | Default      | `deep_mask_head_rcnn_voc_r50.yaml`               | 25.9
ResNet-50    | Hourglass-52 | `deep_mask_head_rcnn_voc_r50_hg52.yaml`          | 33.1
ResNet-101   | Hourglass-52 | `deep_mask_head_rcnn_voc_r101_hg52.yaml`         | 34.4
SpineNet-143 | Hourglass-52 | `deep_mask_head_rcnn_voc_spinenet143_hg52.yaml`  | 38.7
## See also
...
...@@ -22,6 +22,9 @@ import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.configs import decoders
from official.vision.beta.configs import maskrcnn as maskrcnn_config
from official.vision.beta.configs import retinanet as retinanet_config
...@@ -59,20 +62,18 @@ def deep_mask_head_rcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
          annotation_file=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
                                       'instances_val2017.json'),
          model=DeepMaskHeadRCNN(
              num_classes=91, input_size=[1024, 1024, 3], include_mask=True),  # pytype: disable=wrong-keyword-args
          losses=maskrcnn_config.Losses(l2_weight_decay=0.00004),
          train_data=maskrcnn_config.DataConfig(
              input_path=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
                                      'train*'),
              is_training=True,
              global_batch_size=global_batch_size,
              parser=maskrcnn_config.Parser(
                  aug_rand_hflip=True, aug_scale_min=0.8, aug_scale_max=1.25)),
          validation_data=maskrcnn_config.DataConfig(
              input_path=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
                                      'val*'),
              is_training=False,
              global_batch_size=8)),  # pytype: disable=wrong-keyword-args
      trainer=cfg.TrainerConfig(
...@@ -110,3 +111,87 @@ def deep_mask_head_rcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
      ])
  return config
@exp_factory.register_config_factory('deep_mask_head_rcnn_spinenet_coco')
def deep_mask_head_rcnn_spinenet_coco() -> cfg.ExperimentConfig:
"""COCO object detection with Mask R-CNN with SpineNet backbone."""
steps_per_epoch = 463
coco_val_samples = 5000
train_batch_size = 256
eval_batch_size = 8
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=DeepMaskHeadRCNNTask(
annotation_file=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
'instances_val2017.json'), # pytype: disable=wrong-keyword-args
model=DeepMaskHeadRCNN(
backbone=backbones.Backbone(
type='spinenet',
spinenet=backbones.SpineNet(
model_id='49',
min_level=3,
max_level=7,
)),
decoder=decoders.Decoder(
type='identity', identity=decoders.Identity()),
anchor=maskrcnn_config.Anchor(anchor_size=3),
norm_activation=common.NormActivation(use_sync_bn=True),
num_classes=91,
input_size=[640, 640, 3],
min_level=3,
max_level=7,
include_mask=True), # pytype: disable=wrong-keyword-args
losses=maskrcnn_config.Losses(l2_weight_decay=0.00004),
train_data=maskrcnn_config.DataConfig(
input_path=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
'train*'),
is_training=True,
global_batch_size=train_batch_size,
parser=maskrcnn_config.Parser(
aug_rand_hflip=True, aug_scale_min=0.5, aug_scale_max=2.0)),
validation_data=maskrcnn_config.DataConfig(
input_path=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
'val*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False)), # pytype: disable=wrong-keyword-args
trainer=cfg.TrainerConfig(
train_steps=steps_per_epoch * 350,
validation_steps=coco_val_samples // eval_batch_size,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [
steps_per_epoch * 320, steps_per_epoch * 340
],
'values': [0.32, 0.032, 0.0032],
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 2000,
'warmup_learning_rate': 0.0067
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
'task.model.min_level == task.model.backbone.spinenet.min_level',
'task.model.max_level == task.model.backbone.spinenet.max_level',
])
return config
...@@ -25,6 +25,10 @@ class DeepMaskHeadRcnnConfigTest(tf.test.TestCase):
    config = deep_mask_head_rcnn.deep_mask_head_rcnn_resnetfpn_coco()
    self.assertIsInstance(config.task, deep_mask_head_rcnn.DeepMaskHeadRCNNTask)
def test_config_spinenet(self):
config = deep_mask_head_rcnn.deep_mask_head_rcnn_spinenet_coco()
self.assertIsInstance(config.task, deep_mask_head_rcnn.DeepMaskHeadRCNNTask)
if __name__ == '__main__':
  tf.test.main()
# Expect to reach: box mAP: 49.3%, mask mAP: 43.4% on COCO
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
allowed_mask_class_ids: [
8, 10, 11, 13, 14, 15, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56,
57, 58, 59, 60, 61, 65, 70, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84,
85, 86, 87, 88, 89, 90
]
per_category_metrics: true
init_checkpoint: null
train_data:
global_batch_size: 256
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
aug_scale_max: 2.0
losses:
l2_weight_decay: 0.00004
model:
mask_head:
class_agnostic: true
convnet_variant: 'hourglass52'
num_filters: 64
mask_roi_aligner:
crop_size: 32
use_gt_boxes_for_masks: true
anchor:
anchor_size: 4.0
num_scales: 3
min_level: 3
max_level: 7
input_size: [1280, 1280, 3]
backbone:
spinenet:
stochastic_depth_drop_rate: 0.2
model_id: '143'
type: 'spinenet'
decoder:
type: 'identity'
norm_activation:
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
detection_generator:
pre_nms_top_k: 1000
trainer:
train_steps: 231000
optimizer_config:
learning_rate:
type: 'stepwise'
stepwise:
boundaries: [219450, 226380]
values: [0.32, 0.032, 0.0032]
warmup:
type: 'linear'
linear:
warmup_steps: 2000
warmup_learning_rate: 0.0067
# Expect to reach: box mAP: 49.3%, mask mAP: 43.4% on COCO
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
allowed_mask_class_ids: [1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72]
per_category_metrics: true
init_checkpoint: null
train_data:
global_batch_size: 256
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
aug_scale_max: 2.0
losses:
l2_weight_decay: 0.00004
model:
mask_head:
class_agnostic: true
convnet_variant: 'hourglass52'
num_filters: 64
mask_roi_aligner:
crop_size: 32
use_gt_boxes_for_masks: true
anchor:
anchor_size: 4.0
num_scales: 3
min_level: 3
max_level: 7
input_size: [1280, 1280, 3]
backbone:
spinenet:
stochastic_depth_drop_rate: 0.2
model_id: '143'
type: 'spinenet'
decoder:
type: 'identity'
norm_activation:
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
detection_generator:
pre_nms_top_k: 1000
trainer:
train_steps: 231000
optimizer_config:
learning_rate:
type: 'stepwise'
stepwise:
boundaries: [219450, 226380]
values: [0.32, 0.032, 0.0032]
warmup:
type: 'linear'
linear:
warmup_steps: 2000
warmup_learning_rate: 0.0067
...@@ -14,12 +14,14 @@
"""Mask R-CNN model."""
from typing import List, Mapping, Optional, Union
# Import libraries
from absl import logging
import tensorflow as tf

from official.vision.beta.modeling import maskrcnn_model


def resize_as(source, size):
...@@ -30,21 +32,30 @@ def resize_as(source, size):
@tf.keras.utils.register_keras_serializable(package='Vision')
class DeepMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
  """The Mask R-CNN model."""
  def __init__(self,
               backbone: tf.keras.Model,
               decoder: tf.keras.Model,
               rpn_head: tf.keras.layers.Layer,
               detection_head: Union[tf.keras.layers.Layer,
                                     List[tf.keras.layers.Layer]],
               roi_generator: tf.keras.layers.Layer,
               roi_sampler: Union[tf.keras.layers.Layer,
                                  List[tf.keras.layers.Layer]],
               roi_aligner: tf.keras.layers.Layer,
               detection_generator: tf.keras.layers.Layer,
               mask_head: Optional[tf.keras.layers.Layer] = None,
               mask_sampler: Optional[tf.keras.layers.Layer] = None,
               mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
               class_agnostic_bbox_pred: bool = False,
               cascade_class_ensemble: bool = False,
               min_level: Optional[int] = None,
               max_level: Optional[int] = None,
               num_scales: Optional[int] = None,
               aspect_ratios: Optional[List[float]] = None,
               anchor_size: Optional[float] = None,
               use_gt_boxes_for_masks=False,
               **kwargs):
    """Initializes the Mask R-CNN model.
...@@ -53,122 +64,99 @@ class DeepMaskRCNNModel(tf.keras.Model):
      backbone: `tf.keras.Model`, the backbone network.
      decoder: `tf.keras.Model`, the decoder network.
      rpn_head: the RPN head.
      detection_head: the detection head or a list of heads.
      roi_generator: the ROI generator.
      roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
        detection heads.
      roi_aligner: the ROI aligner.
      detection_generator: the detection generator.
      mask_head: the mask head.
      mask_sampler: the mask sampler.
      mask_roi_aligner: the ROI aligner for mask prediction.
      class_agnostic_bbox_pred: if True, perform class agnostic bounding box
        prediction. Needs to be `True` for Cascade RCNN models.
      cascade_class_ensemble: if True, ensemble classification scores over all
        detection heads.
      min_level: Minimum level in output feature maps.
      max_level: Maximum level in output feature maps.
      num_scales: A number representing intermediate scales added on each level.
        For instance, num_scales=2 adds one additional intermediate anchor
        scale [2^0, 2^0.5] on each level.
      aspect_ratios: A list representing the aspect ratio anchors added on each
        level. The number indicates the ratio of width to height. For instance,
        aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each scale level.
      anchor_size: A number representing the scale of size of the base anchor to
        the feature stride 2^level.
      use_gt_boxes_for_masks: bool, if set, crop using groundtruth boxes instead
        of proposals for training the mask head.
      **kwargs: keyword arguments to be passed.
    """
    super(DeepMaskRCNNModel, self).__init__(
        backbone=backbone,
        decoder=decoder,
        rpn_head=rpn_head,
        detection_head=detection_head,
        roi_generator=roi_generator,
        roi_sampler=roi_sampler,
        roi_aligner=roi_aligner,
        detection_generator=detection_generator,
        mask_head=mask_head,
        mask_sampler=mask_sampler,
        mask_roi_aligner=mask_roi_aligner,
        class_agnostic_bbox_pred=class_agnostic_bbox_pred,
        cascade_class_ensemble=cascade_class_ensemble,
        min_level=min_level,
        max_level=max_level,
        num_scales=num_scales,
        aspect_ratios=aspect_ratios,
        anchor_size=anchor_size,
        **kwargs)

    self._config_dict['use_gt_boxes_for_masks'] = use_gt_boxes_for_masks
  def call(self,
           images: tf.Tensor,
           image_shape: tf.Tensor,
           anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
           gt_boxes: Optional[tf.Tensor] = None,
           gt_classes: Optional[tf.Tensor] = None,
           gt_masks: Optional[tf.Tensor] = None,
           training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
    model_outputs, intermediate_outputs = self._call_box_outputs(
        images=images, image_shape=image_shape, anchor_boxes=anchor_boxes,
        gt_boxes=gt_boxes, gt_classes=gt_classes, training=training)
    model_outputs = {}
    # Feature extraction.
    features = self.backbone(images)
if self.decoder:
features = self.decoder(features)
# Region proposal network.
rpn_scores, rpn_boxes = self.rpn_head(features)
model_outputs.update({
'rpn_boxes': rpn_boxes,
'rpn_scores': rpn_scores
})
# Generate RoIs.
rois, _ = self.roi_generator(
rpn_boxes, rpn_scores, anchor_boxes, image_shape, training)
if training:
rois = tf.stop_gradient(rois)
rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
self.roi_sampler(rois, gt_boxes, gt_classes))
# Assign target for the 2nd stage classification.
box_targets = box_ops.encode_boxes(
matched_gt_boxes, rois, weights=[10.0, 10.0, 5.0, 5.0])
# If the target is background, the box target is set to all 0s.
box_targets = tf.where(
tf.tile(
tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
[1, 1, 4]),
tf.zeros_like(box_targets),
box_targets)
model_outputs.update({
'class_targets': matched_gt_classes,
'box_targets': box_targets,
})
# RoI align.
roi_features = self.roi_aligner(features, rois)
# Detection head.
raw_scores, raw_boxes = self.detection_head(roi_features)
if training:
model_outputs.update({
'class_outputs': raw_scores,
'box_outputs': raw_boxes,
})
else:
# Post-processing.
detections = self.detection_generator(
raw_boxes, raw_scores, rois, image_shape)
model_outputs.update({
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
'detection_classes': detections['detection_classes'],
'num_detections': detections['num_detections'],
})
    if not self._include_mask:
      return model_outputs
model_mask_outputs = self._call_mask_outputs(
model_box_outputs=model_outputs,
features=intermediate_outputs['features'],
current_rois=intermediate_outputs['current_rois'],
matched_gt_indices=intermediate_outputs['matched_gt_indices'],
matched_gt_boxes=intermediate_outputs['matched_gt_boxes'],
matched_gt_classes=intermediate_outputs['matched_gt_classes'],
gt_masks=gt_masks,
gt_classes=gt_classes,
gt_boxes=gt_boxes,
training=training)
model_outputs.update(model_mask_outputs)
return model_outputs
def _call_mask_outputs(
self,
model_box_outputs: Mapping[str, tf.Tensor],
features: tf.Tensor,
current_rois: tf.Tensor,
matched_gt_indices: tf.Tensor,
matched_gt_boxes: tf.Tensor,
matched_gt_classes: tf.Tensor,
gt_masks: tf.Tensor,
gt_classes: tf.Tensor,
gt_boxes: tf.Tensor,
training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
model_outputs = dict(model_box_outputs)
    if training:
      if self._config_dict['use_gt_boxes_for_masks']:
        mask_size = (
...@@ -184,11 +172,8 @@ class DeepMaskRCNNModel(tf.keras.Model):
        })
      else:
        rois, roi_classes, roi_masks = self.mask_sampler(
            current_rois, matched_gt_boxes, matched_gt_classes,
            matched_gt_indices, gt_masks)
      roi_masks = tf.stop_gradient(roi_masks)
      model_outputs.update({
          'mask_class_targets': roi_classes,
...@@ -219,24 +204,3 @@ class DeepMaskRCNNModel(tf.keras.Model):
          'detection_masks': tf.math.sigmoid(raw_masks),
      })
    return model_outputs
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(
backbone=self.backbone,
rpn_head=self.rpn_head,
detection_head=self.detection_head)
if self.decoder is not None:
items.update(decoder=self.decoder)
if self._include_mask:
items.update(mask_head=self.mask_head)
return items
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config):
return cls(**config)
# TF Vision Example Project
This is a minimal example project to demonstrate how to use TF Model Garden's
building blocks to implement a new vision project from scratch.
Below we use classification as an example. We will walk you through the process
of creating a new project leveraging existing components, such as tasks, data
loaders, and models. You will get a better understanding of these components by
going through the process. You can also refer to the docstrings of the
corresponding components for more information.
## Create Model
In
[example_model.py](example_model.py),
we show how to create a new model. The `ExampleModel` is a subclass of
`tf.keras.Model` that defines the necessary parameters. Here, you need
`input_specs` to specify the input shape and dimensions, and you build the
layers within the constructor:
```python
class ExampleModel(tf.keras.Model):
def __init__(
self,
num_classes: int,
input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec(
shape=[None, None, None, 3]),
**kwargs):
# Build layers.
```
Given the `ExampleModel`, you can define a function that takes a model config as
input and returns an `ExampleModel` instance, similar to
[build_example_model](example_model.py#L80).
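For illustration, a builder of this kind might look like the sketch below. It
reuses the `ExampleModel` class defined above; the `num_classes` field on the
model config is an assumption for this sketch, not necessarily the actual
`example_config.py` definition.

```python
import tensorflow as tf


def build_example_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config) -> tf.keras.Model:
  """Builds an ExampleModel from a model config (illustrative sketch)."""
  # `model_config.num_classes` is an assumed field; map whatever your
  # config dataclass actually exposes onto the model constructor.
  return ExampleModel(
      num_classes=model_config.num_classes,
      input_specs=input_specs)
```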
As a simple example, we define a single model. However, you can split the model
implementation into individual components, such as backbones, decoders, and
heads, as we do
[here](https://github.com/tensorflow/models/blob/master/official/vision/beta/modeling).
Then, in the `build_example_model` function, you can hook these components
together to obtain your full model.
## Create Dataloader
A dataloader reads, decodes and parses the input data. We have created various
[dataloaders](https://github.com/tensorflow/models/blob/master/official/vision/beta/dataloaders)
to handle standard input formats for classification, detection and segmentation.
If you have non-standard or complex data, you may want to create your own
dataloader. It contains a `Decoder` and a `Parser`.
- The
[Decoder](example_input.py#L33)
decodes a TF Example record and returns a dictionary of decoded tensors:
```python
class Decoder(decoder.Decoder):
"""A tf.Example decoder for classification task."""
def __init__(self):
"""Initializes the decoder.
The constructor defines the mapping between the field name and the value
from an input tf.Example. For example, we define two fields for image bytes
and labels. There is no limit on the number of fields to decode.
"""
self._keys_to_features = {
'image/encoded':
tf.io.FixedLenFeature((), tf.string, default_value=''),
'image/class/label':
tf.io.FixedLenFeature((), tf.int64, default_value=-1)
}
```
- The
[Parser](example_input.py#L68)
parses the decoded tensors and performs pre-processing on the input data,
such as image decoding, augmentation, and resizing. It should have
`_parse_train_data` and `_parse_eval_data` functions, which return the
processed images and labels. A minimal sketch is shown below.
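The following is a minimal sketch of such a parser, assuming the
`image/encoded` and `image/class/label` keys produced by the `Decoder` above;
the base-class path and the exact preprocessing are illustrative assumptions,
not the actual `example_input.py` implementation.

```python
import tensorflow as tf

from official.vision.beta.dataloaders import parser


class Parser(parser.Parser):
  """Parses decoded tensors into (image, label) pairs."""

  def __init__(self, output_size, num_classes):
    self._output_size = output_size
    self._num_classes = num_classes

  def _parse_train_data(self, decoded_tensors):
    # Decode, resize, and normalize the image; add training-time
    # augmentation here as needed.
    image = tf.io.decode_image(
        decoded_tensors['image/encoded'], channels=3, expand_animations=False)
    image = tf.image.resize(image, self._output_size)
    image = tf.cast(image, tf.float32) / 255.0
    label = tf.cast(decoded_tensors['image/class/label'], tf.int32)
    return image, label

  def _parse_eval_data(self, decoded_tensors):
    # Evaluation uses the same preprocessing without augmentation.
    return self._parse_train_data(decoded_tensors)
```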
## Create Config
Next you will define configs for your project. All configs are defined as
`dataclass` objects, and can have default parameter values.
First, you will define your
[`ExampleDataConfig`](example_config.py#L27).
It inherits from `config_definitions.DataConfig` that already defines a few
common fields, like `input_path`, `file_type`, `global_batch_size`, etc. You can
add more fields in your own config as needed.
You can then define your model config
[`ExampleModel`](example_config.py#L39),
which inherits from `hyperparams.Config`. Expose your own model parameters here.
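As a rough sketch of what these configs can look like (the extra fields here
are illustrative assumptions, not the actual `example_config.py` contents):

```python
import dataclasses
from typing import List, Tuple

from official.core import config_definitions as cfg
from official.modeling import hyperparams


@dataclasses.dataclass
class ExampleDataConfig(cfg.DataConfig):
  """Input config, adding project-specific fields to the common ones."""
  input_path: str = ''
  global_batch_size: int = 0
  is_training: bool = True
  image_size: Tuple[int, int] = (224, 224)  # Illustrative extra field.


@dataclasses.dataclass
class ExampleModel(hyperparams.Config):
  """Model config exposing the parameters of the ExampleModel above."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(
      default_factory=lambda: [224, 224, 3])
```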
You can then define your `Loss` and `Evaluation` configs.
Next, you will put all the above configs into an
[`ExampleTask`](example_config.py#L56)
config. Here you list the configs for your data, model, loss, and evaluation,
etc.
Finally, you can define a
[`tf_vision_example_experiment`](example_config.py#L66),
which creates a template for your experiments and fills it with default
parameters. These default parameter values can be overridden by a YAML file,
like [example_config_tpu.yaml](example_config_tpu.yaml).
Also, make sure you give your experiment template a unique name via the
decorator:
```python
@exp_factory.register_config_factory('tf_vision_example_experiment')
def tf_vision_example_experiment() -> cfg.ExperimentConfig:
"""Definition of a full example experiment."""
# Create and return experiment template.
```
## Create Task
A task is a class that encapsules the logic of loading data, building models,
performing one-step training and validation, etc. It connects all components
together and is called by the base
[Trainer](https://github.com/tensorflow/models/blob/master/official/core/base_trainer.py).
You can create your own task by inheriting from the base
[Task](https://github.com/tensorflow/models/blob/master/official/core/base_task.py),
or from one of the
[tasks](https://github.com/tensorflow/models/blob/master/official/vision/beta/tasks/)
we already defined, if most of the operations can be reused. An `ExampleTask`
inheriting from
[ImageClassificationTask](https://github.com/tensorflow/models/blob/master/official/vision/beta/tasks/image_classification.py#L32)
can be found
[here](example_task.py).
We will go through each important component of the task below; a minimal sketch
of a few of these methods follows the list.
- `build_model`: you can instantiate a model you have defined above. It is
    also good practice to run a forward pass with a dummy input to ensure the
    layers within the model are properly initialized.
- `build_inputs`: here you can instantiate a Decoder object and a Parser
object. They are used to create an `InputReader` that will generate a
`tf.data.Dataset` object.
- `build_losses`: it takes groundtruth labels and model outputs as input, and
computes the loss. It will be called in `train_step` and `validation_step`.
You can also define different losses for training and validation, for
example, `build_train_losses` and `build_validation_losses`. Just make sure
they are called by the corresponding functions properly.
- `build_metrics`: here you can define your own metrics. It should return a
list of `tf.keras.metrics.Metric` objects. You can create your own metric
class by subclassing `tf.keras.metrics.Metric`.
- `train_step` and `validation_step`: they perform one-step training and
    validation. They take one batch of training/validation data, run a forward
    pass, gather losses, and update metrics. They assume the data format is
    consistent with that of the `Parser` output. `train_step` also contains
    the backward pass to update model weights.
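Below is a minimal sketch of such a task, assuming the `ExampleModel` and the
config classes discussed above; the registered config class, the config field
names, and the loss are placeholders for illustration, not the actual
`example_task.py` implementation.

```python
import tensorflow as tf

from official.core import task_factory
from official.vision.beta.projects.example import example_config as exp_cfg
from official.vision.beta.projects.example import example_model
from official.vision.beta.tasks import image_classification


@task_factory.register_task_cls(exp_cfg.ExampleTask)
class ExampleTask(image_classification.ImageClassificationTask):
  """Example task reusing most of the image classification task logic."""

  def build_model(self) -> tf.keras.Model:
    # Instantiate the model defined earlier and run a dummy forward pass
    # so that all layers are built.
    model = example_model.ExampleModel(
        num_classes=self.task_config.model.num_classes,
        input_specs=tf.keras.layers.InputSpec(
            shape=[None] + list(self.task_config.model.input_size)))
    dummy_inputs = tf.keras.Input(self.task_config.model.input_size)
    _ = model(dummy_inputs)
    return model

  def build_losses(self, labels, model_outputs, aux_losses=None):
    # A simple stand-in loss; swap in whatever your project needs.
    loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, model_outputs, from_logits=True)
    return tf.reduce_mean(loss)
```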
## Import registry
To use your custom dataloaders, models, tasks, etc., you will need to register
them properly. The recommended way is to have a single file that imports all
the relevant modules, for example,
[registry_imports.py](registry_imports.py).
You can see in this file we import all our custom components:
```python
# pylint: disable=unused-import
from official.common import registry_imports
from official.vision.beta.projects.example import example_config
from official.vision.beta.projects.example import example_input
from official.vision.beta.projects.example import example_model
from official.vision.beta.projects.example import example_task
```
## Training
You can create your own trainer by branching from our core
[trainer](https://github.com/tensorflow/models/blob/master/official/vision/beta/train.py).
Just make sure you import the registry like this:
```python
from official.vision.beta.projects.example import registry_imports # pylint: disable=unused-import
```
You can run training locally for testing purposes:
```bash
# Assume you are under official/vision/beta/projects.
python3 example/train.py \
--experiment=tf_vision_example_experiment \
--config_file=${PWD}/example/example_config_local.yaml \
--mode=train \
--model_dir=/tmp/tfvision_test/
```
Training can also run on Google Cloud using Cloud TPU.
[Here](https://cloud.google.com/tpu/docs/how-to) are the instructions for using
Cloud TPU, and here is a more detailed
[tutorial](https://cloud.google.com/tpu/docs/tutorials/resnet-rs-2.x) on
training a ResNet-RS model. Follow the instructions to set up Cloud TPU, then
launch training with:
```bash
EXP_TYPE=tf_vision_example_experiment # This should match the registered name of your experiment template.
EXP_NAME=exp_001 # You can give any name to the experiment.
TPU_NAME=experiment01
# Now launch the experiment.
python3 example/train.py \
--experiment=$EXP_TYPE \
--mode=train \
--tpu=$TPU_NAME \
--model_dir=/tmp/tfvision_test/ \
--config_file=third_party/tensorflow_models/official/vision/beta/projects/example/example_config_tpu.yaml
```