Commit 7359586f authored by Kaushik Shivakumar

Merge remote-tracking branch 'upstream/master' into detr-push-3

parents c594cecf a78b05b9
@@ -27,7 +27,9 @@ import functools
from object_detection.core import standard_fields as fields
from object_detection.meta_architectures import context_rcnn_lib
from object_detection.meta_architectures import context_rcnn_lib_tf2
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.utils import tf_version
class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
@@ -264,11 +266,17 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
return_raw_detections_during_predict),
output_final_box_features=output_final_box_features)
self._context_feature_extract_fn = functools.partial(
context_rcnn_lib.compute_box_context_attention,
bottleneck_dimension=attention_bottleneck_dimension,
attention_temperature=attention_temperature,
is_training=is_training)
if tf_version.is_tf1():
self._context_feature_extract_fn = functools.partial(
context_rcnn_lib.compute_box_context_attention,
bottleneck_dimension=attention_bottleneck_dimension,
attention_temperature=attention_temperature,
is_training=is_training)
else:
self._context_feature_extract_fn = context_rcnn_lib_tf2.AttentionBlock(
bottleneck_dimension=attention_bottleneck_dimension,
attention_temperature=attention_temperature,
is_training=is_training)
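The TF2 branch swaps the functools.partial over a module-level function for a callable Keras layer, so self._context_feature_extract_fn presents the same call interface on both paths. A minimal sketch of that dispatch pattern, with placeholder bodies and assumed argument names rather than the real context_rcnn_lib internals:

```python
import functools
import tensorflow as tf


def compute_box_context_attention(box_features, context_features,
                                  bottleneck_dimension, attention_temperature,
                                  is_training):
  # Placeholder for the TF1 module function (argument names are assumed).
  del context_features, bottleneck_dimension, attention_temperature, is_training
  return box_features


class AttentionBlock(tf.keras.layers.Layer):
  """Placeholder layer mirroring the TF2 AttentionBlock call interface."""

  def __init__(self, bottleneck_dimension, attention_temperature, is_training):
    super().__init__()
    self._bottleneck_dimension = bottleneck_dimension
    self._attention_temperature = attention_temperature
    self._is_training = is_training

  def call(self, box_features, context_features):
    return box_features


def make_context_feature_extract_fn(is_tf1, **kwargs):
  # Either branch yields a callable taking the same arguments, so downstream
  # code never needs to know which implementation it holds.
  if is_tf1:
    return functools.partial(compute_box_context_attention, **kwargs)
  return AttentionBlock(**kwargs)
```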
@staticmethod
def get_side_inputs(features):
@@ -323,6 +331,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
Returns:
A float32 Tensor with shape [K, new_height, new_width, depth].
"""
box_features = self._crop_and_resize_fn(
[features_to_crop], proposal_boxes_normalized, None,
[self._initial_crop_size, self._initial_crop_size])
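The cropping call now passes a list of feature maps plus an explicit box-indices argument (None here). A rough, self-contained stand-in for such a helper, built on tf.image.crop_and_resize and assuming one [batch, height, width, depth] feature map per list entry:

```python
import tensorflow as tf


def crop_and_resize(feature_maps, boxes, box_indices, crop_size):
  """Hypothetical helper: crops normalized boxes from the first feature map."""
  features = feature_maps[0]                      # [batch, h, w, depth]
  batch = tf.shape(features)[0]
  num_boxes = tf.shape(boxes)[1]                  # boxes: [batch, num_boxes, 4]
  flat_boxes = tf.reshape(boxes, [-1, 4])
  if box_indices is None:
    # Derive batch indices so box i of image b maps back to image b.
    box_indices = tf.repeat(tf.range(batch), num_boxes)
  return tf.image.crop_and_resize(features, flat_boxes, box_indices, crop_size)
```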
......
@@ -109,7 +109,6 @@ class FakeFasterRCNNKerasFeatureExtractor(
])
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
def _get_model(self, box_predictor, **common_kwargs):
@@ -440,15 +439,16 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
masks_are_class_agnostic=masks_are_class_agnostic,
share_box_across_classes=share_box_across_classes), **common_kwargs)
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
@mock.patch.object(context_rcnn_meta_arch, 'context_rcnn_lib')
def test_prediction_mock(self, mock_context_rcnn_lib):
"""Mocks the context_rcnn_lib module to test the prediction.
def test_prediction_mock_tf1(self, mock_context_rcnn_lib_v1):
"""Mocks the context_rcnn_lib_v1 module to test the prediction.
Using a mock object so that we can ensure compute_box_context_attention is
called inside the prediction function.
Args:
mock_context_rcnn_lib: mock module for the context_rcnn_lib.
mock_context_rcnn_lib_v1: mock module for the context_rcnn_lib_v1.
"""
model = self._build_model(
is_training=False,
@@ -457,7 +457,7 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
num_classes=42)
mock_tensor = tf.ones([2, 8, 3, 3, 3], tf.float32)
mock_context_rcnn_lib.compute_box_context_attention.return_value = mock_tensor
mock_context_rcnn_lib_v1.compute_box_context_attention.return_value = mock_tensor
inputs_shape = (2, 20, 20, 3)
inputs = tf.cast(
tf.random_uniform(inputs_shape, minval=0, maxval=255, dtype=tf.int32),
@@ -479,7 +479,7 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
side_inputs = model.get_side_inputs(features)
_ = model.predict(preprocessed_inputs, true_image_shapes, **side_inputs)
mock_context_rcnn_lib.compute_box_context_attention.assert_called_once()
mock_context_rcnn_lib_v1.compute_box_context_attention.assert_called_once()
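The renamed test keeps the standard mock.patch.object recipe: patch the module attribute on the module under test, stub the return value, run the code, then assert the call happened. A self-contained illustration of that recipe using generic, made-up names:

```python
import types
import unittest
from unittest import mock

# Made-up stand-ins for the helper module and the code that calls it.
context_lib = types.SimpleNamespace(
    compute_box_context_attention=lambda features: features)
meta_arch = types.SimpleNamespace(context_lib=context_lib)


def predict(features):
  # Reaches the helper through the module attribute, as the meta-arch does.
  return meta_arch.context_lib.compute_box_context_attention(features)


class PredictionMockTest(unittest.TestCase):

  @mock.patch.object(meta_arch, 'context_lib')
  def test_attention_called_once(self, mock_context_lib):
    mock_context_lib.compute_box_context_attention.return_value = 'mocked'
    self.assertEqual(predict('features'), 'mocked')
    mock_context_lib.compute_box_context_attention.assert_called_once()


if __name__ == '__main__':
  unittest.main()
```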
@parameterized.named_parameters(
{'testcase_name': 'static_shapes', 'static_shapes': True},
@@ -518,7 +518,6 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
}
side_inputs = model.get_side_inputs(features)
prediction_dict = model.predict(preprocessed_inputs, true_image_shapes,
**side_inputs)
return (prediction_dict['rpn_box_predictor_features'],
......
@@ -23,6 +23,7 @@ import os
import time
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
from object_detection import eval_util
from object_detection import inputs
@@ -117,7 +118,8 @@ def _compute_losses_and_predictions_dicts(
prediction_dict = model.predict(
preprocessed_images,
features[fields.InputDataFields.true_image_shape])
features[fields.InputDataFields.true_image_shape],
**model.get_side_inputs(features))
prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)
losses_dict = model.loss(
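This change makes the training path forward whatever get_side_inputs extracts (context features, for Context R-CNN) into predict as keyword arguments, mirroring what the test code above already does. Schematically, with an illustrative function name:

```python
def compute_predictions(model, features, preprocessed_images, true_image_shapes):
  # get_side_inputs selects model-specific extras from the features dict;
  # predict accepts them as keyword arguments alongside the images.
  side_inputs = model.get_side_inputs(features)
  return model.predict(preprocessed_images, true_image_shapes, **side_inputs)
```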
@@ -413,8 +415,9 @@ def train_loop(
train_steps=None,
use_tpu=False,
save_final_config=False,
checkpoint_every_n=5000,
checkpoint_every_n=1000,
checkpoint_max_to_keep=7,
record_summaries=True,
**kwargs):
"""Trains a model using eager + functions.
@@ -444,6 +447,7 @@ def train_loop(
Checkpoint every n training steps.
checkpoint_max_to_keep:
int, the number of most recent checkpoints to keep in the model directory.
record_summaries: Boolean, whether or not to record summaries.
**kwargs: Additional keyword arguments for configuration override.
"""
## Parse the configs
@@ -530,8 +534,11 @@ def train_loop(
# is the chief.
summary_writer_filepath = get_filepath(strategy,
os.path.join(model_dir, 'train'))
summary_writer = tf.compat.v2.summary.create_file_writer(
summary_writer_filepath)
if record_summaries:
summary_writer = tf.compat.v2.summary.create_file_writer(
summary_writer_filepath)
else:
summary_writer = tf2.summary.create_noop_writer()
if use_tpu:
num_steps_per_iteration = 100
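When summaries are disabled the loop still gets a writer object, so the surrounding code can enter the writer context unconditionally; tf.summary.create_noop_writer returns a writer that silently drops everything. A small illustration of the toggle:

```python
import tensorflow as tf


def make_summary_writer(log_dir, record_summaries=True):
  # Real file writer when recording, no-op sink otherwise; callers can use
  # `with writer.as_default():` either way without branching.
  if record_summaries:
    return tf.summary.create_file_writer(log_dir)
  return tf.summary.create_noop_writer()


writer = make_summary_writer('/tmp/train', record_summaries=False)
with writer.as_default():
  tf.summary.scalar('loss', 0.5, step=1)  # discarded by the no-op writer
```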
@@ -603,7 +610,9 @@ def train_loop(
if num_steps_per_iteration > 1:
for _ in tf.range(num_steps_per_iteration - 1):
_sample_and_train(strategy, train_step_fn, data_iterator)
# Following suggestion on yaqs/5402607292645376
with tf.name_scope(''):
_sample_and_train(strategy, train_step_fn, data_iterator)
return _sample_and_train(strategy, train_step_fn, data_iterator)
......
@@ -63,6 +63,11 @@ flags.DEFINE_integer(
'num_workers', 1, 'When num_workers > 1, training uses '
'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
'MirroredStrategy.')
flags.DEFINE_integer(
'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
flags.DEFINE_boolean('record_summaries', True,
('Whether or not to record summaries during'
' training.'))
FLAGS = flags.FLAGS
@@ -101,7 +106,9 @@ def main(unused_argv):
pipeline_config_path=FLAGS.pipeline_config_path,
model_dir=FLAGS.model_dir,
train_steps=FLAGS.num_train_steps,
use_tpu=FLAGS.use_tpu)
use_tpu=FLAGS.use_tpu,
checkpoint_every_n=FLAGS.checkpoint_every_n,
record_summaries=FLAGS.record_summaries)
if __name__ == '__main__':
tf.compat.v1.app.run()
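Both new flags flow straight through to train_loop. A condensed, hypothetical sketch of the wiring (train_loop is stubbed here; in the real script it comes from model_lib_v2):

```python
from absl import app
from absl import flags

flags.DEFINE_integer(
    'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
flags.DEFINE_boolean(
    'record_summaries', True,
    'Whether or not to record summaries during training.')
FLAGS = flags.FLAGS


def train_loop(checkpoint_every_n, record_summaries):
  # Stub standing in for model_lib_v2.train_loop.
  print(checkpoint_every_n, record_summaries)


def main(unused_argv):
  train_loop(
      checkpoint_every_n=FLAGS.checkpoint_every_n,
      record_summaries=FLAGS.record_summaries)


if __name__ == '__main__':
  app.run(main)
```

Since record_summaries is an absl boolean flag, it can also be disabled on the command line with --norecord_summaries or --record_summaries=false.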
@@ -62,8 +62,14 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
return self._network.num_hourglasses
def get_model(self):
return self._network
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network
else:
supported_types = ['detection']
raise ValueError(
('Sub model {} is not defined for Hourglass. '.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
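All four CenterNet extractors touched in this commit replace the bare get_model/get_base_model accessors with a type-keyed get_sub_model that fails loudly on unknown keys (Hourglass and MobileNet expose 'detection'; the ResNet variants expose 'classification'). A minimal sketch of the pattern and its usage:

```python
class HourglassLikeExtractor:
  """Illustrative extractor with the type-keyed sub-model accessor."""

  def __init__(self, network):
    self._network = network

  def get_sub_model(self, sub_model_type):
    if sub_model_type == 'detection':
      return self._network
    supported_types = ['detection']
    raise ValueError(
        'Sub model {} is not defined for Hourglass. '
        'Supported types are {}.'.format(sub_model_type, supported_types))


extractor = HourglassLikeExtractor(network=object())
detection_net = extractor.get_sub_model('detection')  # returns the network
# extractor.get_sub_model('classification')           # raises ValueError
```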
def hourglass_104(channel_means, channel_stds, bgr_ordering):
......
@@ -101,8 +101,14 @@ class CenterNetMobileNetV2FeatureExtractor(
"""The number of feature outputs returned by the feature extractor."""
return 1
def get_model(self):
return self._network
def get_sub_model(self, sub_model_type):
if sub_model_type == 'detection':
return self._network
else:
supported_types = ['detection']
raise ValueError(
('Sub model {} is not defined for MobileNet. '.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
......
@@ -101,10 +101,6 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def load_feature_extractor_weights(self, path):
self._base_model.load_weights(path)
def get_base_model(self):
"""Get base resnet model for inspection and testing."""
return self._base_model
def call(self, inputs):
"""Returns image features extracted by the backbone.
@@ -127,6 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self):
return 4
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
supported_types = ['classification']
raise ValueError(
('Sub model {} is not defined for ResNet. '.format(sub_model_type) +
'Supported types are {}.'.format(supported_types)))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
"""The ResNet v2 101 feature extractor."""
......
@@ -137,10 +137,6 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def load_feature_extractor_weights(self, path):
self._base_model.load_weights(path)
def get_base_model(self):
"""Get base resnet model for inspection and testing."""
return self._base_model
def call(self, inputs):
"""Returns image features extracted by the backbone.
@@ -163,6 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def out_stride(self):
return 4
def get_sub_model(self, sub_model_type):
if sub_model_type == 'classification':
return self._base_model
else:
supported_types = ['classification']
raise ValueError(
('Sub model {} is not defined for ResNet FPN. '.format(sub_model_type)
+ 'Supported types are {}.'.format(supported_types)))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
"""The ResNet v1 101 FPN feature extractor."""
......