Unverified Commit df89d3e0 authored by BasiaFusinska, committed by GitHub

Merged commit includes the following changes: (#8753)

318938278  by Andre Araujo:

    Loading pretrained ImageNet weights to initialize the ResNet backbone. Changed the default batch size and initial learning rate to improve convergence on the GLDv2 dataset. Made the number of evaluation batches dynamic, depending on the global batch size.
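
    For illustration, a minimal sketch of the dynamic evaluation batch count,
    assuming a roughly 50000-image validation set as used in the training
    script:

        global_batch_size = 32
        num_validation_images = 50000  # assumed validation set size
        num_eval_batches = num_validation_images // global_batch_size  # 1562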

--
318911740  by Andre Araujo:

    Introduced additional shuffling of the TRAIN and VALIDATION datasets to ensure label variance across batches.
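
    For illustration, a minimal numpy sketch of the column-wise shuffle this
    relies on: each column holds one image's attributes, so permuting columns
    reorders images while keeping (path, id, label) aligned (array contents
    are made up):

        import numpy as np

        def shuffle_by_columns(np_array, random_state):
          # Permute the columns of a 2D attribute array.
          columns_indices = np.arange(np_array.shape[1])
          random_state.shuffle(columns_indices)
          return np_array[:, columns_indices]

        rs = np.random.RandomState(seed=0)
        attrs = np.array([['a.jpg', 'b.jpg', 'c.jpg'], ['7', '3', '7']])
        print(shuffle_by_columns(attrs, rs))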

--
318908335  by Andre Araujo:

    Export model migration to TF2.

--
318489123  by Andre Araujo:

    Model exporting script for the global-feature model trained with the DELF codebase.
    Additionally, makes a small change to replace back_prop=False in the
    tf.while_loop call (see the deprecation notice in
    https://www.tensorflow.org/api_docs/python/tf/while_loop).
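
    For illustration, a minimal sketch of that replacement, using a toy loop
    that doubles a value five times:

        import tensorflow as tf

        i0 = tf.constant(0)
        x0 = tf.constant(1.0)
        cond = lambda i, x: i < 5
        body = lambda i, x: (i + 1, x * 2.0)

        # Instead of tf.while_loop(..., back_prop=False), wrap the loop
        # outputs in tf.stop_gradient, as the deprecation notice suggests.
        _, result = tf.nest.map_structure(
            tf.stop_gradient,
            tf.while_loop(cond=cond, body=body, loop_vars=[i0, x0]))
        print(result.numpy())  # 32.0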

--
318401984  by Andre Araujo:

    Add attention visualization to DELF training script.

--
318168500  by Andre Araujo:

    Several small changes to DELF open-source training code:
    - Replace the deprecated "make_dataset_iterator" call with its more recent TF2 equivalent.
    - Add an image summary, allowing visualization of the augmented images during training.
    - Normalize images before feeding them to the model (see the sketch below).
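
    For illustration, a minimal sketch of the normalization, assuming
    NormalizeImages subtracts pixel_value_offset and divides by
    pixel_value_scale (mapping uint8 pixels to roughly [-1, 1]):

        import tensorflow as tf

        image = tf.constant([[[0, 128, 255]]], dtype=tf.uint8)
        normalized = (tf.cast(image, tf.float32) - 128.0) / 128.0
        print(normalized.numpy())  # [[[-1.  0.  0.9921875]]]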

--
316888714  by Andre Araujo:

    - Removed an unnecessary cast from feature_aggregation_extraction.py.
    - Fixed the clustering script.

--

PiperOrigin-RevId: 318938278
Co-authored-by: Andre Araujo <andrearaujo@google.com>
parent 9725a407
@@ -86,6 +86,9 @@ message DelfConfig {
   // Path to DELF model.
   optional string model_path = 1;  // Required.
 
+  // Whether model has been exported using TF version 2+.
+  optional bool is_tf2_exported = 10 [default = false];
+
   // Image scales to be used.
   repeated float image_scales = 2;
...
@@ -131,7 +131,7 @@ def main(argv):
   delf_dataset = tf.data.Dataset.from_tensor_slices((features_placeholder))
   delf_dataset = delf_dataset.shuffle(1000).batch(
       features_for_clustering.shape[0])
-  iterator = delf_dataset.make_initializable_iterator()
+  iterator = tf.compat.v1.data.make_initializable_iterator(delf_dataset)
 
   def _initializer_fn(sess):
     """Initialize dataset iterator, feed in the data."""
...
@@ -102,7 +102,15 @@ def MakeExtractor(config):
 
   Returns:
     Function that receives an image and returns features.
+
+  Raises:
+    ValueError: if config is invalid.
   """
+  # Assert the configuration.
+  if config.use_global_features and hasattr(
+      config, 'is_tf2_exported') and config.is_tf2_exported:
+    raise ValueError('use_global_features is incompatible with is_tf2_exported')
+
   # Load model.
   model = tf.saved_model.load(config.model_path)
@@ -178,6 +186,7 @@ def MakeExtractor(config):
   else:
     global_pca_parameters['variances'] = None
 
-  model = model.prune(feeds=feeds, fetches=fetches)
+  if not hasattr(config, 'is_tf2_exported') or not config.is_tf2_exported:
+    model = model.prune(feeds=feeds, fetches=fetches)
 
   def ExtractorFn(image, resize_factor=1.0):
@@ -197,7 +206,6 @@ def MakeExtractor(config):
       features (key 'local_features' mapping to a dict with keys 'locations',
       'descriptors', 'scales', 'attention').
     """
-
     resized_image, scale_factors = ResizeImage(
         image, config, resize_factor=resize_factor)
@@ -224,8 +232,20 @@ def MakeExtractor(config):
     output = None
     if config.use_local_features:
-      output = model(image_tensor, image_scales_tensor, score_threshold_tensor,
-                     max_feature_num_tensor)
+      if hasattr(config, 'is_tf2_exported') and config.is_tf2_exported:
+        predict = model.signatures['serving_default']
+        output_dict = predict(
+            input_image=image_tensor,
+            input_scales=image_scales_tensor,
+            input_max_feature_num=max_feature_num_tensor,
+            input_abs_thres=score_threshold_tensor)
+        output = [
+            output_dict['boxes'], output_dict['features'],
+            output_dict['scales'], output_dict['scores']
+        ]
+      else:
+        output = model(image_tensor, image_scales_tensor,
+                       score_threshold_tensor, max_feature_num_tensor)
     else:
       output = model(image_tensor, image_scales_tensor)
...
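
For illustration, a minimal sketch of driving a TF2-exported DELF model
through its serving signature, mirroring the new extractor branch above
('/tmp/delf_model' is a placeholder path):

    import tensorflow as tf

    model = tf.saved_model.load('/tmp/delf_model')
    predict = model.signatures['serving_default']
    output_dict = predict(
        input_image=tf.zeros([320, 320, 3], dtype=tf.uint8),
        input_scales=tf.constant([0.5, 1.0, 2.0]),
        input_max_feature_num=tf.constant(1000),
        input_abs_thres=tf.constant(100.0))
    boxes = output_dict['boxes']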
@@ -269,8 +269,7 @@ class ExtractAggregatedRepresentation(object):
               axis=0), [num_assignments, 1]) - tf.gather(
                   codebook, selected_visual_words[ind])
       return ind + 1, tf.tensor_scatter_nd_add(
-          vlad, tf.expand_dims(selected_visual_words[ind], axis=1),
-          tf.cast(diff, dtype=tf.float32))
+          vlad, tf.expand_dims(selected_visual_words[ind], axis=1), diff)
 
     ind_vlad = tf.constant(0, dtype=tf.int32)
     keep_going = lambda j, vlad: tf.less(j, num_features)
@@ -396,9 +395,7 @@ class ExtractAggregatedRepresentation(object):
       visual_words = tf.reshape(
           tf.where(
-              tf.greater(
-                  per_centroid_norms,
-                  tf.cast(tf.sqrt(_NORM_SQUARED_TOLERANCE), dtype=tf.float32))),
+              tf.greater(per_centroid_norms, tf.sqrt(_NORM_SQUARED_TOLERANCE))),
           [-1])
 
       per_centroid_normalized_vector = tf.math.l2_normalize(
...
@@ -302,6 +302,21 @@ def _write_relabeling_rules(relabeling_rules):
       csv_writer.writerow([new_label, old_label])
 
 
+def _shuffle_by_columns(np_array, random_state):
+  """Shuffle the columns of a 2D numpy array.
+
+  Args:
+    np_array: array to shuffle.
+    random_state: numpy RandomState to be used for shuffling.
+
+  Returns:
+    The shuffled array.
+  """
+  columns = np_array.shape[1]
+  columns_indices = np.arange(columns)
+  random_state.shuffle(columns_indices)
+  return np_array[:, columns_indices]
+
+
 def _build_train_and_validation_splits(image_paths, file_ids, labels,
                                        validation_split_size, seed):
   """Create TRAIN and VALIDATION splits containing all labels in equal proportion.
@@ -353,19 +368,21 @@ def _build_train_and_validation_splits(image_paths, file_ids, labels,
   for label, indexes in image_attrs_idx_by_label.items():
     # Create the subset for the current label.
     image_attrs_label = image_attrs[:, indexes]
-    images_per_label = image_attrs_label.shape[1]
 
     # Shuffle the current label subset.
-    columns_indices = np.arange(images_per_label)
-    rs.shuffle(columns_indices)
-    image_attrs_label = image_attrs_label[:, columns_indices]
+    image_attrs_label = _shuffle_by_columns(image_attrs_label, rs)
 
     # Split the current label subset into TRAIN and VALIDATION splits and add
     # each split to the list of all splits.
+    images_per_label = image_attrs_label.shape[1]
     cutoff_idx = max(1, int(validation_split_size * images_per_label))
     splits[_VALIDATION_SPLIT].append(image_attrs_label[:, 0 : cutoff_idx])
     splits[_TRAIN_SPLIT].append(image_attrs_label[:, cutoff_idx : ])
 
-  validation_split = np.concatenate(splits[_VALIDATION_SPLIT], axis=1)
-  train_split = np.concatenate(splits[_TRAIN_SPLIT], axis=1)
+  # Concatenate all subsets of image attributes into TRAIN and VALIDATION
+  # splits and reshuffle them again to ensure variance of labels across
+  # batches.
+  validation_split = _shuffle_by_columns(
+      np.concatenate(splits[_VALIDATION_SPLIT], axis=1), rs)
+  train_split = _shuffle_by_columns(
+      np.concatenate(splits[_TRAIN_SPLIT], axis=1), rs)
 
   # Unstack the image attribute arrays in the TRAIN and VALIDATION splits and
   # convert them back to lists. Convert labels back to 'int' from 'str'
...
@@ -29,11 +29,7 @@ import tensorflow as tf
 
 class _GoogleLandmarksInfo(object):
   """Metadata about the Google Landmarks dataset."""
-  num_classes = {
-      'gld_v1': 14951,
-      'gld_v2': 203094,
-      'gld_v2_clean': 81313
-  }
+  num_classes = {'gld_v1': 14951, 'gld_v2': 203094, 'gld_v2_clean': 81313}
 
 
 class _DataAugmentationParams(object):
@@ -123,6 +119,8 @@ def _ParseFunction(example, name_to_features, image_size, augmentation):
   # Parse to get image.
   image = parsed_example['image/encoded']
   image = tf.io.decode_jpeg(image)
+  image = NormalizeImages(
+      image, pixel_value_scale=128.0, pixel_value_offset=128.0)
   if augmentation:
     image = _ImageNetCrop(image)
   else:
@@ -130,6 +128,7 @@ def _ParseFunction(example, name_to_features, image_size, augmentation):
     image.set_shape([image_size, image_size, 3])
 
   # Parse to get label.
   label = parsed_example['image/class/label']
+
   return image, label
@@ -162,6 +161,7 @@ def CreateDataset(file_pattern,
       'image/width': tf.io.FixedLenFeature([], tf.int64, default_value=0),
       'image/channels': tf.io.FixedLenFeature([], tf.int64, default_value=0),
       'image/format': tf.io.FixedLenFeature([], tf.string, default_value=''),
+      'image/id': tf.io.FixedLenFeature([], tf.string, default_value=''),
       'image/filename': tf.io.FixedLenFeature([], tf.string, default_value=''),
       'image/encoded': tf.io.FixedLenFeature([], tf.string, default_value=''),
       'image/class/label': tf.io.FixedLenFeature([], tf.int64, default_value=0),
...
@@ -132,10 +132,12 @@ class Delf(tf.keras.Model):
         self.attn_classification.trainable_weights)
 
   def call(self, input_image, training=True):
-    blocks = {'block3': None}
-    self.backbone(input_image, intermediates_dict=blocks, training=training)
-    features = blocks['block3']
+    blocks = {}
+
+    self.backbone.build_call(
+        input_image, intermediates_dict=blocks, training=training)
+    features = blocks['block3']  # pytype: disable=key-error
 
     _, probs, _ = self.attention(features, training=training)
     return probs, features

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export global feature tensorflow inference model.

This model includes image pyramids for multi-scale processing.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
import tensorflow as tf

from delf.python.training.model import delf_model
from delf.python.training.model import export_model_utils

FLAGS = flags.FLAGS

flags.DEFINE_string('ckpt_path', '/tmp/delf-logdir/delf-weights',
                    'Path to saved checkpoint.')
flags.DEFINE_string('export_path', None, 'Path where model will be exported.')
flags.DEFINE_list(
    'input_scales_list', None,
    'Optional input image scales to use. If None (default), an input end-point '
    '"input_scales" is added for the exported model. If not None, the '
    'specified list of floats will be hard-coded as the desired input scales.')
flags.DEFINE_enum(
    'multi_scale_pool_type', 'None', ['None', 'average', 'sum'],
    "If 'None' (default), the model is exported with an output end-point "
    "'global_descriptors', where the global descriptor for each scale is "
    "returned separately. If not 'None', the global descriptor of each scale is"
    ' pooled and a 1D global descriptor is returned, with output end-point '
    "'global_descriptor'.")
flags.DEFINE_boolean('normalize_global_descriptor', False,
                     'If True, L2-normalizes global descriptor.')


def _build_tensor_info(tensor_dict):
  """Replace the dict's values by their tensor info.

  Args:
    tensor_dict: A dictionary containing <string, tensor> pairs.

  Returns:
    dict: New dictionary containing <string, tensor_info> pairs.
  """
  return {
      k: tf.compat.v1.saved_model.utils.build_tensor_info(t)
      for k, t in tensor_dict.items()
  }


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  export_path = FLAGS.export_path
  if os.path.exists(export_path):
    raise ValueError('Export_path already exists.')

  with tf.Graph().as_default() as g, tf.compat.v1.Session(graph=g) as sess:
    # Setup the model for extraction.
    model = delf_model.Delf(block3_strides=False, name='DELF')

    # Initial forward pass to build model.
    images = tf.zeros((1, 321, 321, 3), dtype=tf.float32)
    model(images)

    # Setup the multiscale extraction.
    input_image = tf.compat.v1.placeholder(
        tf.uint8, shape=(None, None, 3), name='input_image')
    if FLAGS.input_scales_list is None:
      input_scales = tf.compat.v1.placeholder(
          tf.float32, shape=[None], name='input_scales')
    else:
      input_scales = tf.constant([float(s) for s in FLAGS.input_scales_list],
                                 dtype=tf.float32,
                                 shape=[len(FLAGS.input_scales_list)],
                                 name='input_scales')

    extracted_features = export_model_utils.ExtractGlobalFeatures(
        input_image,
        input_scales,
        lambda x: model.backbone(x, training=False),
        multi_scale_pool_type=FLAGS.multi_scale_pool_type,
        normalize_global_descriptor=FLAGS.normalize_global_descriptor)

    # Load the weights.
    checkpoint_path = FLAGS.ckpt_path
    model.load_weights(checkpoint_path)
    print('Checkpoint loaded from', checkpoint_path)

    named_input_tensors = {'input_image': input_image}
    if FLAGS.input_scales_list is None:
      named_input_tensors['input_scales'] = input_scales

    # Outputs to the exported model.
    named_output_tensors = {}
    if FLAGS.multi_scale_pool_type == 'None':
      named_output_tensors['global_descriptors'] = tf.identity(
          extracted_features, name='global_descriptors')
    else:
      named_output_tensors['global_descriptor'] = tf.identity(
          extracted_features, name='global_descriptor')

    # Export the model.
    signature_def = (
        tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
            inputs=_build_tensor_info(named_input_tensors),
            outputs=_build_tensor_info(named_output_tensors)))

    print('Exporting trained model to:', export_path)
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path)

    init_op = None
    builder.add_meta_graph_and_variables(
        sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.compat.v1.saved_model.signature_constants
            .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature_def
        },
        main_op=init_op)
    builder.save()


if __name__ == '__main__':
  app.run(main)
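
For illustration, a rough sketch of consuming the exported global-feature
model; TF2's tf.saved_model.load can also read TF1-style SavedModels and
exposes their signatures ('/tmp/global_model' is a placeholder path, and this
assumes the default serving signature exported above):

    import tensorflow as tf

    loaded = tf.saved_model.load('/tmp/global_model')
    serving = loaded.signatures['serving_default']
    outputs = serving(
        input_image=tf.zeros([320, 320, 3], dtype=tf.uint8),
        input_scales=tf.constant([0.7071, 1.0, 1.4142]))
    descriptors = outputs['global_descriptors']  # [num_scales, D]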
@@ -42,67 +42,39 @@ flags.DEFINE_boolean('block3_strides', False,
 flags.DEFINE_float('iou', 1.0, 'IOU for non-max suppression.')
 
 
-def _build_tensor_info(tensor_dict):
-  """Replace the dict's value by the tensor info.
-
-  Args:
-    tensor_dict: A dictionary contains <string, tensor>.
-
-  Returns:
-    dict: New dictionary contains <string, tensor_info>.
-  """
-  return {
-      k: tf.compat.v1.saved_model.utils.build_tensor_info(t)
-      for k, t in tensor_dict.items()
-  }
-
-
-def main(argv):
-  if len(argv) > 1:
-    raise app.UsageError('Too many command-line arguments.')
-
-  export_path = FLAGS.export_path
-  if os.path.exists(export_path):
-    raise ValueError('Export_path already exists.')
-
-  with tf.Graph().as_default() as g, tf.compat.v1.Session(graph=g) as sess:
-    # Setup the DELF model for extraction.
-    model = delf_model.Delf(block3_strides=FLAGS.block3_strides, name='DELF')
-
-    # Initial forward pass to build model.
-    images = tf.zeros((1, 321, 321, 3), dtype=tf.float32)
-    model(images)
-
-    stride_factor = 2.0 if FLAGS.block3_strides else 1.0
-
-    # Setup the multiscale keypoint extraction.
-    input_image = tf.compat.v1.placeholder(
-        tf.uint8, shape=(None, None, 3), name='input_image')
-    input_abs_thres = tf.compat.v1.placeholder(
-        tf.float32, shape=(), name='input_abs_thres')
-    input_scales = tf.compat.v1.placeholder(
-        tf.float32, shape=[None], name='input_scales')
-    input_max_feature_num = tf.compat.v1.placeholder(
-        tf.int32, shape=(), name='input_max_feature_num')
-
-    extracted_features = export_model_utils.ExtractLocalFeatures(
-        input_image, input_scales, input_max_feature_num, input_abs_thres,
-        FLAGS.iou, lambda x: model(x, training=False), stride_factor)
-
-    # Load the weights.
-    checkpoint_path = FLAGS.ckpt_path
-    model.load_weights(checkpoint_path)
-    print('Checkpoint loaded from ', checkpoint_path)
-
-    named_input_tensors = {
-        'input_image': input_image,
-        'input_scales': input_scales,
-        'input_abs_thres': input_abs_thres,
-        'input_max_feature_num': input_max_feature_num,
-    }
-
-    # Outputs to the exported model.
-    named_output_tensors = {}
-    named_output_tensors['boxes'] = tf.identity(
-        extracted_features[0], name='boxes')
+class _ExtractModule(tf.Module):
+  """Helper module to build and save DELF model."""
+
+  def __init__(self, block3_strides, iou):
+    """Initialization of DELF model.
+
+    Args:
+      block3_strides: bool, whether to add strides to the output of block3.
+      iou: IOU for non-max suppression.
+    """
+    self._stride_factor = 2.0 if block3_strides else 1.0
+    self._iou = iou
+    # Setup the DELF model for extraction.
+    self._model = delf_model.Delf(block3_strides=block3_strides, name='DELF')
+
+  def LoadWeights(self, checkpoint_path):
+    self._model.load_weights(checkpoint_path)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec(shape=[None, None, 3], dtype=tf.uint8, name='input_image'),
+      tf.TensorSpec(shape=[None], dtype=tf.float32, name='input_scales'),
+      tf.TensorSpec(shape=(), dtype=tf.int32, name='input_max_feature_num'),
+      tf.TensorSpec(shape=(), dtype=tf.float32, name='input_abs_thres')
+  ])
+  def ExtractFeatures(self, input_image, input_scales, input_max_feature_num,
+                      input_abs_thres):
+    extracted_features = export_model_utils.ExtractLocalFeatures(
+        input_image, input_scales, input_max_feature_num, input_abs_thres,
+        self._iou, lambda x: self._model(x, training=False),
+        self._stride_factor)
+
+    named_output_tensors = {}
+    named_output_tensors['boxes'] = tf.identity(
+        extracted_features[0], name='boxes')
@@ -112,25 +84,27 @@ def main(argv):
     named_output_tensors['scales'] = tf.identity(
         extracted_features[2], name='scales')
     named_output_tensors['scores'] = tf.identity(
        extracted_features[3], name='scores')
+    return named_output_tensors
+
+
+def main(argv):
+  if len(argv) > 1:
+    raise app.UsageError('Too many command-line arguments.')
+
+  export_path = FLAGS.export_path
+  if os.path.exists(export_path):
+    raise ValueError(f'Export_path {export_path} already exists. Please '
+                     'specify a different path or delete the existing one.')
+
+  module = _ExtractModule(FLAGS.block3_strides, FLAGS.iou)
+
+  # Load the weights.
+  checkpoint_path = FLAGS.ckpt_path
+  module.LoadWeights(checkpoint_path)
+  print('Checkpoint loaded from ', checkpoint_path)
+
+  # Save the module.
+  tf.saved_model.save(module, export_path)
-    # Export the model.
-    signature_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
-        inputs=_build_tensor_info(named_input_tensors),
-        outputs=_build_tensor_info(named_output_tensors))
-
-    print('Exporting trained model to:', export_path)
-    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path)
-
-    init_op = None
-    builder.add_meta_graph_and_variables(
-        sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
-        signature_def_map={
-            tf.compat.v1.saved_model.signature_constants
-            .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
-                signature_def
-        },
-        main_op=init_op)
-    builder.save()
 
 
 if __name__ == '__main__':
...
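
For illustration, a minimal standalone sketch of the tf.Module export pattern
adopted above (toy names; the signature map is passed explicitly here for
clarity):

    import tensorflow as tf

    class ToyModule(tf.Module):

      def __init__(self):
        self._scale = tf.Variable(2.0)

      @tf.function(input_signature=[
          tf.TensorSpec(shape=[None], dtype=tf.float32, name='x')
      ])
      def Run(self, x):
        # Identity gives the output a stable name in the SavedModel.
        return {'y': tf.identity(self._scale * x, name='y')}

    module = ToyModule()
    tf.saved_model.save(
        module, '/tmp/toy_module',
        signatures={'serving_default': module.Run})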
@@ -142,7 +142,9 @@ def ExtractLocalFeatures(image, image_scales, max_feature_num, abs_thres, iou,
   keep_going = lambda j, b, f, scales, scores: tf.less(j, num_scales)
 
   (_, output_boxes, output_features, output_scales,
-   output_scores) = tf.while_loop(
-       cond=keep_going,
-       body=_ProcessSingleScale,
-       loop_vars=[
+   output_scores) = tf.nest.map_structure(
+       tf.stop_gradient,
+       tf.while_loop(
+           cond=keep_going,
+           body=_ProcessSingleScale,
+           loop_vars=[
@@ -154,8 +156,7 @@ def ExtractLocalFeatures(image, image_scales, max_feature_num, abs_thres, iou,
-           tf.TensorShape([None, feature_depth]),
-           tf.TensorShape([None]),
-           tf.TensorShape([None])
-       ],
-       back_prop=False)
+               tf.TensorShape([None, feature_depth]),
+               tf.TensorShape([None]),
+               tf.TensorShape([None])
+           ]))
 
   feature_boxes = box_list.BoxList(output_boxes)
   feature_boxes.add_field('features', output_features)
@@ -169,3 +170,109 @@ def ExtractLocalFeatures(image, image_scales, max_feature_num, abs_thres, iou,
   return final_boxes.get(), final_boxes.get_field(
       'features'), final_boxes.get_field('scales'), tf.expand_dims(
           final_boxes.get_field('scores'), 1)


def ExtractGlobalFeatures(image,
                          image_scales,
                          model_fn,
                          multi_scale_pool_type='None',
                          normalize_global_descriptor=False):
  """Extract global features for input image.

  Args:
    image: image tensor of type tf.uint8 with shape [h, w, channels].
    image_scales: 1D float tensor which contains float scales used for image
      pyramid construction.
    model_fn: model function. Follows the signature:
      * Args:
        * `images`: Image tensor which is re-scaled.
      * Returns:
        * `global_descriptors`: Global descriptors for input images.
    multi_scale_pool_type: If set, the global descriptor of each scale is
      pooled and a 1D global descriptor is returned.
    normalize_global_descriptor: If True, output global descriptors are
      L2-normalized.

  Returns:
    global_descriptors: If `multi_scale_pool_type` is 'None', returns a [S, D]
      float tensor. S is the number of scales, and D the global descriptor
      dimensionality. Each D-dimensional entry is a global descriptor, which
      may be L2-normalized depending on `normalize_global_descriptor`. If
      `multi_scale_pool_type` is not 'None', returns a [D] float tensor with
      the pooled global descriptor.
  """
  original_image_shape_float = tf.gather(
      tf.dtypes.cast(tf.shape(image), tf.float32), [0, 1])
  image_tensor = gld.NormalizeImages(
      image, pixel_value_offset=128.0, pixel_value_scale=128.0)
  image_tensor = tf.expand_dims(image_tensor, 0, name='image/expand_dims')

  def _ProcessSingleScale(scale_index, global_descriptors=None):
    """Resizes the image and runs feature extraction.

    This function will be passed into tf.while_loop() and be called
    repeatedly. We get the current scale by image_scales[scale_index], and
    run image resizing / feature extraction. In the end, we concat the
    previous global descriptors with the current descriptor as the output.

    Args:
      scale_index: A valid index in image_scales.
      global_descriptors: Global descriptor tensor with the shape of [S, D].
        If None, no previous global descriptors are used, and the output will
        be of shape [1, D].

    Returns:
      scale_index: The next scale index for processing.
      global_descriptors: A concatenated global descriptor tensor with the
        shape of [S+1, D].
    """
    scale = tf.gather(image_scales, scale_index)
    new_image_size = tf.dtypes.cast(
        tf.round(original_image_shape_float * scale), tf.int32)
    resized_image = tf.image.resize(image_tensor, new_image_size)

    global_descriptor = model_fn(resized_image)
    if global_descriptors is None:
      global_descriptors = global_descriptor
    else:
      global_descriptors = tf.concat([global_descriptors, global_descriptor],
                                     0)

    return scale_index + 1, global_descriptors

  # Process the first scale separately, the following scales will reuse the
  # graph variables.
  (_, output_global) = _ProcessSingleScale(0)
  i = tf.constant(1, dtype=tf.int32)
  num_scales = tf.shape(image_scales)[0]
  keep_going = lambda j, g: tf.less(j, num_scales)

  (_, output_global) = tf.nest.map_structure(
      tf.stop_gradient,
      tf.while_loop(
          cond=keep_going,
          body=_ProcessSingleScale,
          loop_vars=[i, output_global],
          shape_invariants=[i.get_shape(),
                            tf.TensorShape([None, None])]))

  normalization_axis = 1
  if multi_scale_pool_type == 'average':
    output_global = tf.reduce_mean(
        output_global,
        axis=0,
        keepdims=False,
        name='multi_scale_average_pooling')
    normalization_axis = 0
  elif multi_scale_pool_type == 'sum':
    output_global = tf.reduce_sum(
        output_global, axis=0, keepdims=False, name='multi_scale_sum_pooling')
    normalization_axis = 0

  if normalize_global_descriptor:
    output_global = tf.nn.l2_normalize(
        output_global, axis=normalization_axis, name='l2_normalization')

  return output_global
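
For illustration, a shape sketch for ExtractGlobalFeatures' pooling options
(D shown as 2048; the actual dimensionality depends on the backbone):

    import tensorflow as tf

    global_descriptors = tf.random.normal([3, 2048])     # [S, D], one row per scale
    pooled = tf.reduce_mean(global_descriptors, axis=0)  # 'average' pooling -> [D]
    normalized = tf.nn.l2_normalize(pooled, axis=0)      # axis 0 after pooling
    print(normalized.shape)  # (2048,)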
@@ -22,9 +22,14 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
+import os
+import tempfile
 
+from absl import logging
+import h5py
 import tensorflow as tf
 
 layers = tf.keras.layers
@@ -284,8 +289,8 @@ class ResNet50(tf.keras.Model):
     else:
       self.global_pooling = None
 
-  def call(self, inputs, training=True, intermediates_dict=None):
-    """Call the ResNet50 model.
+  def build_call(self, inputs, training=True, intermediates_dict=None):
+    """Building the ResNet50 model.
 
     Args:
       inputs: Images to compute features for.
@@ -356,3 +361,79 @@ class ResNet50(tf.keras.Model):
       return self.global_pooling(x)
     else:
       return x

  def call(self, inputs, training=True, intermediates_dict=None):
    """Call the ResNet50 model.

    Args:
      inputs: Images to compute features for.
      training: Whether model is in training phase.
      intermediates_dict: `None` or dictionary. If not None, accumulate
        feature maps from intermediate blocks into the dictionary.

    Returns:
      Tensor with featuremap.
    """
    return self.build_call(inputs, training, intermediates_dict)

  def restore_weights(self, filepath):
    """Load pretrained weights.

    This function loads a .h5 file from the filepath with saved model weights
    and assigns them to the model.

    Args:
      filepath: String, path to the .h5 file.

    Raises:
      ValueError: if the file referenced by `filepath` does not exist.
    """
    if not tf.io.gfile.exists(filepath):
      raise ValueError('Unable to load weights from %s. You must provide a '
                       'valid file.' % (filepath))

    # Create a local copy of the weights file for h5py to be able to read it.
    local_filename = os.path.basename(filepath)
    tmp_filename = os.path.join(tempfile.gettempdir(), local_filename)
    tf.io.gfile.copy(filepath, tmp_filename, overwrite=True)

    # Load the content of the weights file.
    f = h5py.File(tmp_filename, mode='r')
    saved_layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]

    try:
      # Iterate through all the layers assuming the max `depth` is 2.
      for layer in self.layers:
        if hasattr(layer, 'layers'):
          for inlayer in layer.layers:
            # Make sure the weights are in the saved model, and that we are in
            # the innermost layer.
            if inlayer.name not in saved_layer_names:
              raise ValueError('Layer %s absent from the pretrained weights. '
                               'Unable to load its weights.' % (inlayer.name))
            if hasattr(inlayer, 'layers'):
              raise ValueError('Layer %s is not a depth 2 layer. Unable to '
                               'load its weights.' % (inlayer.name))
            # Assign the weights in the current layer.
            g = f[inlayer.name]
            weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
            weight_values = [g[weight_name] for weight_name in weight_names]
            print('Setting the weights for layer %s' % (inlayer.name))
            inlayer.set_weights(weight_values)
    finally:
      # Clean up the temporary file.
      tf.io.gfile.remove(tmp_filename)

  def log_weights(self):
    """Log backbone weights."""
    logging.info('Logging backbone weights')
    logging.info('------------------------')
    for layer in self.layers:
      if hasattr(layer, 'layers'):
        for inlayer in layer.layers:
          logging.info('Weights for layer: %s, inlayer %s', layer.name,
                       inlayer.name)
          weights = inlayer.get_weights()
          logging.info(weights)
      else:
        logging.info('Layer %s does not have inner layers.', layer.name)
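
For illustration, a standalone sketch of the Keras-style .h5 layout that
restore_weights walks (the file path is an assumption):

    import h5py

    with h5py.File('/tmp/resnet50_imagenet_weights.h5', mode='r') as f:
      # A top-level attribute lists the saved layer names; each layer group
      # stores its weight arrays under the names in 'weight_names'.
      layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
      for name in layer_names:
        g = f[name]
        weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
        print(name, [g[w].shape for w in weight_names])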
@@ -43,17 +43,20 @@ flags.DEFINE_string('train_file_pattern', '/tmp/data/train*',
                     'File pattern of training dataset files.')
 flags.DEFINE_string('validation_file_pattern', '/tmp/data/validation*',
                     'File pattern of validation dataset files.')
-flags.DEFINE_enum('dataset_version', 'gld_v1',
-                  ['gld_v1', 'gld_v2', 'gld_v2_clean'],
-                  'Google Landmarks dataset version, used to determine the'
-                  'number of classes.')
+flags.DEFINE_enum(
+    'dataset_version', 'gld_v1', ['gld_v1', 'gld_v2', 'gld_v2_clean'],
+    'Google Landmarks dataset version, used to determine the '
+    'number of classes.')
 flags.DEFINE_integer('seed', 0, 'Seed to training dataset.')
-flags.DEFINE_float('initial_lr', 0.001, 'Initial learning rate.')
+flags.DEFINE_float('initial_lr', 0.01, 'Initial learning rate.')
 flags.DEFINE_integer('batch_size', 32, 'Global batch size.')
 flags.DEFINE_integer('max_iters', 500000, 'Maximum iterations.')
-flags.DEFINE_boolean('block3_strides', False, 'Whether to use block3_strides.')
+flags.DEFINE_boolean('block3_strides', True, 'Whether to use block3_strides.')
 flags.DEFINE_boolean('use_augmentation', True,
                      'Whether to use ImageNet style augmentation.')
+flags.DEFINE_string(
+    'imagenet_checkpoint', None,
+    'ImageNet checkpoint for ResNet backbone. If None, no checkpoint is used.')
 
 
 def _record_accuracy(metric, logits, labels):
@@ -64,6 +67,10 @@ def _record_accuracy(metric, logits, labels):
 
 def _attention_summaries(scores, global_step):
   """Record statistics of the attention score."""
+  tf.summary.image(
+      'batch_attention',
+      scores / tf.reduce_max(scores + 1e-3),
+      step=global_step)
   tf.summary.scalar('attention/max', tf.reduce_max(scores), step=global_step)
   tf.summary.scalar('attention/min', tf.reduce_min(scores), step=global_step)
   tf.summary.scalar('attention/mean', tf.reduce_mean(scores), step=global_step)
@@ -124,7 +131,7 @@ def main(argv):
   max_iters = FLAGS.max_iters
   global_batch_size = FLAGS.batch_size
   image_size = 321
-  num_eval = 1000
+  num_eval_batches = int(50000 / global_batch_size)
   report_interval = 100
   eval_interval = 1000
   save_interval = 20000
@@ -134,9 +141,10 @@ def main(argv):
   clip_val = tf.constant(10.0)
 
   if FLAGS.debug:
+    tf.config.run_functions_eagerly(True)
     global_batch_size = 4
-    max_iters = 4
-    num_eval = 1
+    max_iters = 100
+    num_eval_batches = 1
     save_interval = 1
     report_interval = 1
@@ -159,11 +167,12 @@ def main(argv):
       augmentation=False,
       seed=FLAGS.seed)
 
-  train_iterator = strategy.make_dataset_iterator(train_dataset)
-  validation_iterator = strategy.make_dataset_iterator(validation_dataset)
+  train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
+  validation_dist_dataset = strategy.experimental_distribute_dataset(
+      validation_dataset)
 
-  train_iterator.initialize()
-  validation_iterator.initialize()
+  train_iter = iter(train_dist_dataset)
+  validation_iter = iter(validation_dist_dataset)
 
   # Create a checkpoint directory to store the checkpoints.
   checkpoint_prefix = os.path.join(FLAGS.logdir, 'delf_tf2-ckpt')
@@ -219,11 +228,14 @@ def main(argv):
       labels = tf.clip_by_value(labels, 0, model.num_classes)
 
       global_step = optimizer.iterations
+      tf.summary.image('batch_images', (images + 1.0) / 2.0, step=global_step)
       tf.summary.scalar(
           'image_range/max', tf.reduce_max(images), step=global_step)
       tf.summary.scalar(
           'image_range/min', tf.reduce_min(images), step=global_step)
 
+      # TODO(andrearaujo): we should try to unify the backprop into a single
+      # function, instead of applying once to descriptor then to attention.
       def _backprop_loss(tape, loss, weights):
         """Backpropagate losses using clipped gradients.
@@ -344,12 +356,25 @@ def main(argv):
     with tf.summary.record_if(
         tf.math.equal(0, optimizer.iterations % report_interval)):
 
+      # TODO(dananghel): try to load pretrained weights at backbone creation.
+      # Load pretrained weights for ResNet50 trained on ImageNet.
+      if FLAGS.imagenet_checkpoint is not None:
+        logging.info('Attempting to load ImageNet pretrained weights.')
+        input_batch = next(train_iter)
+        _, _ = distributed_train_step(input_batch)
+        model.backbone.restore_weights(FLAGS.imagenet_checkpoint)
+        logging.info('Done.')
+      else:
+        logging.info('Skip loading ImageNet pretrained weights.')
+      if FLAGS.debug:
+        model.backbone.log_weights()
+
       global_step_value = optimizer.iterations.numpy()
       while global_step_value < max_iters:
 
        # input_batch : images(b, h, w, c), labels(b,).
        try:
-          input_batch = train_iterator.get_next()
+          input_batch = next(train_iter)
        except tf.errors.OutOfRangeError:
          # Break if we run out of data in the dataset.
          logging.info('Stopping training at global step %d, no more data',
@@ -392,9 +417,9 @@ def main(argv):
 
        # Validate once in {eval_interval*n, n \in N} steps.
        if global_step_value % eval_interval == 0:
-          for i in range(num_eval):
+          for i in range(num_eval_batches):
            try:
-              validation_batch = validation_iterator.get_next()
+              validation_batch = next(validation_iter)
              desc_validation_result, attn_validation_result = (
                  distributed_validation_step(validation_batch))
            except tf.errors.OutOfRangeError:
@@ -416,13 +441,17 @@ def main(argv):
            print(' : attn:', attn_validation_result.numpy())
 
        # Save checkpoint once (each save_interval*n, n \in N) steps.
+        # TODO(andrearaujo): save only in one of the two ways. They are
+        # identical, the only difference is that the manager adds some extra
+        # prefixes and variables (eg, optimizer variables).
        if global_step_value % save_interval == 0:
          save_path = manager.save()
-          logging.info('Saved({global_step_value}) at %s', save_path)
+          logging.info('Saved (%d) at %s', global_step_value, save_path)
 
          file_path = '%s/delf_weights' % FLAGS.logdir
          model.save_weights(file_path, save_format='tf')
-          logging.info('Saved weights({global_step_value}) at %s', file_path)
+          logging.info('Saved weights (%d) at %s', global_step_value,
+                       file_path)
 
        # Reset metrics for next step.
        desc_train_accuracy.reset_states()
...
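
For illustration, a minimal sketch of the TF2 iteration pattern that replaces
the deprecated strategy.make_dataset_iterator:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    dataset = tf.data.Dataset.from_tensor_slices(tf.range(8)).batch(4)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    data_iter = iter(dist_dataset)
    first_batch = next(data_iter)  # per-replica values under the strategy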