Commit 7359586f authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

Merge remote-tracking branch 'upstream/master' into detr-push-3

parents c594cecf a78b05b9
...@@ -27,7 +27,9 @@ import functools ...@@ -27,7 +27,9 @@ import functools
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.meta_architectures import context_rcnn_lib from object_detection.meta_architectures import context_rcnn_lib
from object_detection.meta_architectures import context_rcnn_lib_tf2
from object_detection.meta_architectures import faster_rcnn_meta_arch from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.utils import tf_version
class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
...@@ -264,11 +266,17 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): ...@@ -264,11 +266,17 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
return_raw_detections_during_predict), return_raw_detections_during_predict),
output_final_box_features=output_final_box_features) output_final_box_features=output_final_box_features)
self._context_feature_extract_fn = functools.partial( if tf_version.is_tf1():
context_rcnn_lib.compute_box_context_attention, self._context_feature_extract_fn = functools.partial(
bottleneck_dimension=attention_bottleneck_dimension, context_rcnn_lib.compute_box_context_attention,
attention_temperature=attention_temperature, bottleneck_dimension=attention_bottleneck_dimension,
is_training=is_training) attention_temperature=attention_temperature,
is_training=is_training)
else:
self._context_feature_extract_fn = context_rcnn_lib_tf2.AttentionBlock(
bottleneck_dimension=attention_bottleneck_dimension,
attention_temperature=attention_temperature,
is_training=is_training)
@staticmethod @staticmethod
def get_side_inputs(features): def get_side_inputs(features):
...@@ -323,6 +331,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): ...@@ -323,6 +331,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
Returns: Returns:
A float32 Tensor with shape [K, new_height, new_width, depth]. A float32 Tensor with shape [K, new_height, new_width, depth].
""" """
box_features = self._crop_and_resize_fn( box_features = self._crop_and_resize_fn(
[features_to_crop], proposal_boxes_normalized, None, [features_to_crop], proposal_boxes_normalized, None,
[self._initial_crop_size, self._initial_crop_size]) [self._initial_crop_size, self._initial_crop_size])
......
...@@ -109,7 +109,6 @@ class FakeFasterRCNNKerasFeatureExtractor( ...@@ -109,7 +109,6 @@ class FakeFasterRCNNKerasFeatureExtractor(
]) ])
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase): class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
def _get_model(self, box_predictor, **common_kwargs): def _get_model(self, box_predictor, **common_kwargs):
...@@ -440,15 +439,16 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -440,15 +439,16 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
masks_are_class_agnostic=masks_are_class_agnostic, masks_are_class_agnostic=masks_are_class_agnostic,
share_box_across_classes=share_box_across_classes), **common_kwargs) share_box_across_classes=share_box_across_classes), **common_kwargs)
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
@mock.patch.object(context_rcnn_meta_arch, 'context_rcnn_lib') @mock.patch.object(context_rcnn_meta_arch, 'context_rcnn_lib')
def test_prediction_mock(self, mock_context_rcnn_lib): def test_prediction_mock_tf1(self, mock_context_rcnn_lib_v1):
"""Mocks the context_rcnn_lib module to test the prediction. """Mocks the context_rcnn_lib_v1 module to test the prediction.
Using mock object so that we can ensure compute_box_context_attention is Using mock object so that we can ensure compute_box_context_attention is
called in side the prediction function. called in side the prediction function.
Args: Args:
mock_context_rcnn_lib: mock module for the context_rcnn_lib. mock_context_rcnn_lib_v1: mock module for the context_rcnn_lib_v1.
""" """
model = self._build_model( model = self._build_model(
is_training=False, is_training=False,
...@@ -457,7 +457,7 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -457,7 +457,7 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
num_classes=42) num_classes=42)
mock_tensor = tf.ones([2, 8, 3, 3, 3], tf.float32) mock_tensor = tf.ones([2, 8, 3, 3, 3], tf.float32)
mock_context_rcnn_lib.compute_box_context_attention.return_value = mock_tensor mock_context_rcnn_lib_v1.compute_box_context_attention.return_value = mock_tensor
inputs_shape = (2, 20, 20, 3) inputs_shape = (2, 20, 20, 3)
inputs = tf.cast( inputs = tf.cast(
tf.random_uniform(inputs_shape, minval=0, maxval=255, dtype=tf.int32), tf.random_uniform(inputs_shape, minval=0, maxval=255, dtype=tf.int32),
...@@ -479,7 +479,7 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -479,7 +479,7 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
side_inputs = model.get_side_inputs(features) side_inputs = model.get_side_inputs(features)
_ = model.predict(preprocessed_inputs, true_image_shapes, **side_inputs) _ = model.predict(preprocessed_inputs, true_image_shapes, **side_inputs)
mock_context_rcnn_lib.compute_box_context_attention.assert_called_once() mock_context_rcnn_lib_v1.compute_box_context_attention.assert_called_once()
@parameterized.named_parameters( @parameterized.named_parameters(
{'testcase_name': 'static_shapes', 'static_shapes': True}, {'testcase_name': 'static_shapes', 'static_shapes': True},
...@@ -518,7 +518,6 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -518,7 +518,6 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
} }
side_inputs = model.get_side_inputs(features) side_inputs = model.get_side_inputs(features)
prediction_dict = model.predict(preprocessed_inputs, true_image_shapes, prediction_dict = model.predict(preprocessed_inputs, true_image_shapes,
**side_inputs) **side_inputs)
return (prediction_dict['rpn_box_predictor_features'], return (prediction_dict['rpn_box_predictor_features'],
......
...@@ -23,6 +23,7 @@ import os ...@@ -23,6 +23,7 @@ import os
import time import time
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
from object_detection import eval_util from object_detection import eval_util
from object_detection import inputs from object_detection import inputs
...@@ -117,7 +118,8 @@ def _compute_losses_and_predictions_dicts( ...@@ -117,7 +118,8 @@ def _compute_losses_and_predictions_dicts(
prediction_dict = model.predict( prediction_dict = model.predict(
preprocessed_images, preprocessed_images,
features[fields.InputDataFields.true_image_shape]) features[fields.InputDataFields.true_image_shape],
**model.get_side_inputs(features))
prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict) prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)
losses_dict = model.loss( losses_dict = model.loss(
...@@ -413,8 +415,9 @@ def train_loop( ...@@ -413,8 +415,9 @@ def train_loop(
train_steps=None, train_steps=None,
use_tpu=False, use_tpu=False,
save_final_config=False, save_final_config=False,
checkpoint_every_n=5000, checkpoint_every_n=1000,
checkpoint_max_to_keep=7, checkpoint_max_to_keep=7,
record_summaries=True,
**kwargs): **kwargs):
"""Trains a model using eager + functions. """Trains a model using eager + functions.
...@@ -444,6 +447,7 @@ def train_loop( ...@@ -444,6 +447,7 @@ def train_loop(
Checkpoint every n training steps. Checkpoint every n training steps.
checkpoint_max_to_keep: checkpoint_max_to_keep:
int, the number of most recent checkpoints to keep in the model directory. int, the number of most recent checkpoints to keep in the model directory.
record_summaries: Boolean, whether or not to record summaries.
**kwargs: Additional keyword arguments for configuration override. **kwargs: Additional keyword arguments for configuration override.
""" """
## Parse the configs ## Parse the configs
...@@ -530,8 +534,11 @@ def train_loop( ...@@ -530,8 +534,11 @@ def train_loop(
# is the chief. # is the chief.
summary_writer_filepath = get_filepath(strategy, summary_writer_filepath = get_filepath(strategy,
os.path.join(model_dir, 'train')) os.path.join(model_dir, 'train'))
summary_writer = tf.compat.v2.summary.create_file_writer( if record_summaries:
summary_writer_filepath) summary_writer = tf.compat.v2.summary.create_file_writer(
summary_writer_filepath)
else:
summary_writer = tf2.summary.create_noop_writer()
if use_tpu: if use_tpu:
num_steps_per_iteration = 100 num_steps_per_iteration = 100
...@@ -603,7 +610,9 @@ def train_loop( ...@@ -603,7 +610,9 @@ def train_loop(
if num_steps_per_iteration > 1: if num_steps_per_iteration > 1:
for _ in tf.range(num_steps_per_iteration - 1): for _ in tf.range(num_steps_per_iteration - 1):
_sample_and_train(strategy, train_step_fn, data_iterator) # Following suggestion on yaqs/5402607292645376
with tf.name_scope(''):
_sample_and_train(strategy, train_step_fn, data_iterator)
return _sample_and_train(strategy, train_step_fn, data_iterator) return _sample_and_train(strategy, train_step_fn, data_iterator)
......
...@@ -63,6 +63,11 @@ flags.DEFINE_integer( ...@@ -63,6 +63,11 @@ flags.DEFINE_integer(
'num_workers', 1, 'When num_workers > 1, training uses ' 'num_workers', 1, 'When num_workers > 1, training uses '
'MultiWorkerMirroredStrategy. When num_workers = 1 it uses ' 'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
'MirroredStrategy.') 'MirroredStrategy.')
flags.DEFINE_integer(
'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
flags.DEFINE_boolean('record_summaries', True,
('Whether or not to record summaries during'
' training.'))
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -101,7 +106,9 @@ def main(unused_argv): ...@@ -101,7 +106,9 @@ def main(unused_argv):
pipeline_config_path=FLAGS.pipeline_config_path, pipeline_config_path=FLAGS.pipeline_config_path,
model_dir=FLAGS.model_dir, model_dir=FLAGS.model_dir,
train_steps=FLAGS.num_train_steps, train_steps=FLAGS.num_train_steps,
use_tpu=FLAGS.use_tpu) use_tpu=FLAGS.use_tpu,
checkpoint_every_n=FLAGS.checkpoint_every_n,
record_summaries=FLAGS.record_summaries)
if __name__ == '__main__': if __name__ == '__main__':
tf.compat.v1.app.run() tf.compat.v1.app.run()
...@@ -62,8 +62,14 @@ class CenterNetHourglassFeatureExtractor( ...@@ -62,8 +62,14 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor.""" """Ther number of feature outputs returned by the feature extractor."""
return self._network.num_hourglasses return self._network.num_hourglasses
def get_model(self): def get_sub_model(self, sub_model_type):
return self._network if sub_model_type == 'detection':
return self._network
else:
supported_types = ['detection']
raise ValueError(
('Sub model {} is not defined for Hourglass.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def hourglass_104(channel_means, channel_stds, bgr_ordering): def hourglass_104(channel_means, channel_stds, bgr_ordering):
......
...@@ -101,8 +101,14 @@ class CenterNetMobileNetV2FeatureExtractor( ...@@ -101,8 +101,14 @@ class CenterNetMobileNetV2FeatureExtractor(
"""The number of feature outputs returned by the feature extractor.""" """The number of feature outputs returned by the feature extractor."""
return 1 return 1
def get_model(self): def get_sub_model(self, sub_model_type):
return self._network if sub_model_type == 'detection':
return self._network
else:
supported_types = ['detection']
raise ValueError(
('Sub model {} is not defined for MobileNet.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering): def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
......
...@@ -101,10 +101,6 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor): ...@@ -101,10 +101,6 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def load_feature_extractor_weights(self, path): def load_feature_extractor_weights(self, path):
self._base_model.load_weights(path) self._base_model.load_weights(path)
def get_base_model(self):
"""Get base resnet model for inspection and testing."""
return self._base_model
def call(self, inputs): def call(self, inputs):
"""Returns image features extracted by the backbone. """Returns image features extracted by the backbone.
...@@ -127,6 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor): ...@@ -127,6 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self): def out_stride(self):
return 4 return 4
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
supported_types = ['classification']
raise ValueError(
('Sub model {} is not defined for ResNet.'.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering): def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
"""The ResNet v2 101 feature extractor.""" """The ResNet v2 101 feature extractor."""
......
...@@ -137,10 +137,6 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -137,10 +137,6 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def load_feature_extractor_weights(self, path): def load_feature_extractor_weights(self, path):
self._base_model.load_weights(path) self._base_model.load_weights(path)
def get_base_model(self):
"""Get base resnet model for inspection and testing."""
return self._base_model
def call(self, inputs): def call(self, inputs):
"""Returns image features extracted by the backbone. """Returns image features extracted by the backbone.
...@@ -163,6 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -163,6 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self): def out_stride(self):
return 4 return 4
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
supported_types = ['classification']
raise ValueError(
('Sub model {} is not defined for ResNet FPN.'.format(sub_model_type)
+ 'Supported types are {}.'.format(supported_types)))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
"""The ResNet v1 101 FPN feature extractor.""" """The ResNet v1 101 FPN feature extractor."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment