Unverified Commit df89d3e0 authored by BasiaFusinska, committed by GitHub

Merged commit includes the following changes: (#8753)

318938278  by Andre Araujo:

    Loaded pretrained ImageNet weights to initialize the ResNet backbone. Changed the batch size and initial learning rate defaults to improve convergence on the GLDv2 dataset. Made the number of evaluation batches dynamic, depending on the global batch size.
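
    For reference, a minimal sketch of the dynamic evaluation batch count, mirroring the updated training script (50000 approximates the validation split size):

        # Derive the number of evaluation batches from the global batch
        # size, so the validation split is roughly covered once.
        global_batch_size = 32
        num_eval_batches = int(50000 / global_batch_size)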

--
318911740  by Andre Araujo:

    Introduced additional shuffling of the TRAIN and VALIDATION datasets to ensure label variance across batches.
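
    The shuffle is a column permutation over the stacked image-attribute arrays; a minimal sketch, mirroring the _shuffle_by_columns helper added below:

        import numpy as np

        def shuffle_by_columns(np_array, random_state):
            # Permute the column indices, then reindex so rows stay
            # aligned while the sample order changes.
            columns_indices = np.arange(np_array.shape[1])
            random_state.shuffle(columns_indices)
            return np_array[:, columns_indices]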

--
318908335  by Andre Araujo:

    Export model migration to TF2.
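
    For reference, a minimal sketch of the TF2 export pattern used by the updated script; model and export_path are placeholders:

        import tensorflow as tf

        class ExtractModule(tf.Module):
            """Wraps the model so tf.saved_model.save can trace a serving
            signature, replacing the v1 Session/SavedModelBuilder flow."""

            def __init__(self, model):
                super().__init__()
                self._model = model

            @tf.function(input_signature=[
                tf.TensorSpec(shape=[None, None, 3], dtype=tf.uint8,
                              name='input_image')
            ])
            def ExtractFeatures(self, input_image):
                return self._model(input_image, training=False)

        tf.saved_model.save(ExtractModule(model), export_path)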

--
318489123  by Andre Araujo:

    Model export script for global features trained with the DELF codebase.
    Additionally, makes a small change to replace back_prop=False in the tf.while_loop call (see the deprecation notice at https://www.tensorflow.org/api_docs/python/tf/while_loop).
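
    For reference, the replacement pattern; keep_going, body and loop_vars stand in for the loop pieces:

        # back_prop=False is deprecated; wrap the loop outputs in
        # tf.stop_gradient instead, as the deprecation notice suggests.
        outputs = tf.nest.map_structure(
            tf.stop_gradient,
            tf.while_loop(cond=keep_going, body=body, loop_vars=loop_vars))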

--
318401984  by Andre Araujo:

    Add attention visualization to the DELF training script.
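
    For reference, the image summary written for the attention scores of the current batch (scores and global_step come from the training loop):

        # Rescale attention scores before writing them as an image
        # summary; the small constant guards against division by zero.
        tf.summary.image(
            'batch_attention',
            scores / tf.reduce_max(scores + 1e-3),
            step=global_step)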

--
318168500  by Andre Araujo:

    Several small changes to the DELF open-source training code:
    - Replace the deprecated "make_dataset_iterator" call with its more recent replacement, strategy.experimental_distribute_dataset.
    - Add an image summary, allowing visualization of the augmented images during training.
    - Normalize images before feeding them to the model (see the sketch below).
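
    A minimal sketch of the normalization, assuming NormalizeImages maps pixels via (image - offset) / scale with offset = scale = 128, i.e. roughly into [-1, 1]:

        import tensorflow as tf

        def normalize_images(image, pixel_value_scale=128.0,
                             pixel_value_offset=128.0):
            # Map uint8 pixels from [0, 255] to approximately [-1, 1].
            image = tf.cast(image, tf.float32)
            return (image - pixel_value_offset) / pixel_value_scale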

--
316888714  by Andre Araujo:

    - Removed an unnecessary cast from feature_aggregation_extraction.py.
    - Fixed the clustering script's dataset iterator creation (see the sketch below).
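
    For reference, the TF2-compatible iterator creation from the fixed clustering script:

        # make_initializable_iterator is no longer a Dataset method in
        # TF2; the v1 compatibility helper provides the same behavior.
        iterator = tf.compat.v1.data.make_initializable_iterator(delf_dataset)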

--

PiperOrigin-RevId: 318938278
Co-authored-by: Andre Araujo <andrearaujo@google.com>
parent 9725a407
......@@ -86,6 +86,9 @@ message DelfConfig {
// Path to DELF model.
optional string model_path = 1; // Required.
// Whether model has been exported using TF version 2+.
optional bool is_tf2_exported = 10 [default = false];
// Image scales to be used.
repeated float image_scales = 2;
......
......@@ -131,7 +131,7 @@ def main(argv):
delf_dataset = tf.data.Dataset.from_tensor_slices((features_placeholder))
delf_dataset = delf_dataset.shuffle(1000).batch(
features_for_clustering.shape[0])
iterator = delf_dataset.make_initializable_iterator()
iterator = tf.compat.v1.data.make_initializable_iterator(delf_dataset)
def _initializer_fn(sess):
"""Initialize dataset iterator, feed in the data."""
......
......@@ -102,7 +102,15 @@ def MakeExtractor(config):
Returns:
Function that receives an image and returns features.
Raises:
ValueError: if config is invalid.
"""
# Validate the configuration.
if config.use_global_features and hasattr(
config, 'is_tf2_exported') and config.is_tf2_exported:
raise ValueError('use_global_features is incompatible with is_tf2_exported')
# Load model.
model = tf.saved_model.load(config.model_path)
......@@ -178,7 +186,8 @@ def MakeExtractor(config):
else:
global_pca_parameters['variances'] = None
model = model.prune(feeds=feeds, fetches=fetches)
if not hasattr(config, 'is_tf2_exported') or not config.is_tf2_exported:
model = model.prune(feeds=feeds, fetches=fetches)
def ExtractorFn(image, resize_factor=1.0):
"""Receives an image and returns DELF global and/or local features.
......@@ -197,7 +206,6 @@ def MakeExtractor(config):
features (key 'local_features' mapping to a dict with keys 'locations',
'descriptors', 'scales', 'attention').
"""
resized_image, scale_factors = ResizeImage(
image, config, resize_factor=resize_factor)
......@@ -224,8 +232,20 @@ def MakeExtractor(config):
output = None
if config.use_local_features:
output = model(image_tensor, image_scales_tensor, score_threshold_tensor,
max_feature_num_tensor)
if hasattr(config, 'is_tf2_exported') and config.is_tf2_exported:
predict = model.signatures['serving_default']
output_dict = predict(
input_image=image_tensor,
input_scales=image_scales_tensor,
input_max_feature_num=max_feature_num_tensor,
input_abs_thres=score_threshold_tensor)
output = [
output_dict['boxes'], output_dict['features'],
output_dict['scales'], output_dict['scores']
]
else:
output = model(image_tensor, image_scales_tensor,
score_threshold_tensor, max_feature_num_tensor)
else:
output = model(image_tensor, image_scales_tensor)
......
......@@ -269,8 +269,7 @@ class ExtractAggregatedRepresentation(object):
axis=0), [num_assignments, 1]) - tf.gather(
codebook, selected_visual_words[ind])
return ind + 1, tf.tensor_scatter_nd_add(
vlad, tf.expand_dims(selected_visual_words[ind], axis=1),
tf.cast(diff, dtype=tf.float32))
vlad, tf.expand_dims(selected_visual_words[ind], axis=1), diff)
ind_vlad = tf.constant(0, dtype=tf.int32)
keep_going = lambda j, vlad: tf.less(j, num_features)
......@@ -396,9 +395,7 @@ class ExtractAggregatedRepresentation(object):
visual_words = tf.reshape(
tf.where(
tf.greater(
per_centroid_norms,
tf.cast(tf.sqrt(_NORM_SQUARED_TOLERANCE), dtype=tf.float32))),
tf.greater(per_centroid_norms, tf.sqrt(_NORM_SQUARED_TOLERANCE))),
[-1])
per_centroid_normalized_vector = tf.math.l2_normalize(
......
......@@ -302,6 +302,21 @@ def _write_relabeling_rules(relabeling_rules):
csv_writer.writerow([new_label, old_label])
def _shuffle_by_columns(np_array, random_state):
"""Shuffle the columns of a 2D numpy array.
Args:
np_array: array to shuffle.
random_state: numpy RandomState to be used for shuffling.
Returns:
The shuffled array.
"""
columns = np_array.shape[1]
columns_indices = np.arange(columns)
random_state.shuffle(columns_indices)
return np_array[:, columns_indices]
def _build_train_and_validation_splits(image_paths, file_ids, labels,
validation_split_size, seed):
"""Create TRAIN and VALIDATION splits containg all labels in equal proportion.
......@@ -353,19 +368,21 @@ def _build_train_and_validation_splits(image_paths, file_ids, labels,
for label, indexes in image_attrs_idx_by_label.items():
# Create the subset for the current label.
image_attrs_label = image_attrs[:, indexes]
images_per_label = image_attrs_label.shape[1]
# Shuffle the current label subset.
columns_indices = np.arange(images_per_label)
rs.shuffle(columns_indices)
image_attrs_label = image_attrs_label[:, columns_indices]
image_attrs_label = _shuffle_by_columns(image_attrs_label, rs)
# Split the current label subset into TRAIN and VALIDATION splits and add
# each split to the list of all splits.
images_per_label = image_attrs_label.shape[1]
cutoff_idx = max(1, int(validation_split_size * images_per_label))
splits[_VALIDATION_SPLIT].append(image_attrs_label[:, 0 : cutoff_idx])
splits[_TRAIN_SPLIT].append(image_attrs_label[:, cutoff_idx : ])
validation_split = np.concatenate(splits[_VALIDATION_SPLIT], axis=1)
train_split = np.concatenate(splits[_TRAIN_SPLIT], axis=1)
# Concatenate all subsets of image attributes into TRAIN and VALIDATION splits
# and reshuffle them again to ensure variance of labels across batches.
validation_split = _shuffle_by_columns(
np.concatenate(splits[_VALIDATION_SPLIT], axis=1), rs)
train_split = _shuffle_by_columns(
np.concatenate(splits[_TRAIN_SPLIT], axis=1), rs)
# Unstack the image attribute arrays in the TRAIN and VALIDATION splits and
# convert them back to lists. Convert labels back to 'int' from 'str'
......
......@@ -29,11 +29,7 @@ import tensorflow as tf
class _GoogleLandmarksInfo(object):
"""Metadata about the Google Landmarks dataset."""
num_classes = {
'gld_v1': 14951,
'gld_v2': 203094,
'gld_v2_clean': 81313
}
num_classes = {'gld_v1': 14951, 'gld_v2': 203094, 'gld_v2_clean': 81313}
class _DataAugmentationParams(object):
......@@ -123,6 +119,8 @@ def _ParseFunction(example, name_to_features, image_size, augmentation):
# Parse to get image.
image = parsed_example['image/encoded']
image = tf.io.decode_jpeg(image)
image = NormalizeImages(
image, pixel_value_scale=128.0, pixel_value_offset=128.0)
if augmentation:
image = _ImageNetCrop(image)
else:
......@@ -130,6 +128,7 @@ def _ParseFunction(example, name_to_features, image_size, augmentation):
image.set_shape([image_size, image_size, 3])
# Parse to get label.
label = parsed_example['image/class/label']
return image, label
......@@ -162,6 +161,7 @@ def CreateDataset(file_pattern,
'image/width': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'image/channels': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'image/format': tf.io.FixedLenFeature([], tf.string, default_value=''),
'image/id': tf.io.FixedLenFeature([], tf.string, default_value=''),
'image/filename': tf.io.FixedLenFeature([], tf.string, default_value=''),
'image/encoded': tf.io.FixedLenFeature([], tf.string, default_value=''),
'image/class/label': tf.io.FixedLenFeature([], tf.int64, default_value=0),
......
......@@ -132,10 +132,12 @@ class Delf(tf.keras.Model):
self.attn_classification.trainable_weights)
def call(self, input_image, training=True):
blocks = {'block3': None}
self.backbone(input_image, intermediates_dict=blocks, training=training)
blocks = {}
features = blocks['block3']
self.backbone.build_call(
input_image, intermediates_dict=blocks, training=training)
features = blocks['block3'] # pytype: disable=key-error
_, probs, _ = self.attention(features, training=training)
return probs, features
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export global feature tensorflow inference model.
This model includes image pyramids for multi-scale processing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
import tensorflow as tf
from delf.python.training.model import delf_model
from delf.python.training.model import export_model_utils
FLAGS = flags.FLAGS
flags.DEFINE_string('ckpt_path', '/tmp/delf-logdir/delf-weights',
'Path to saved checkpoint.')
flags.DEFINE_string('export_path', None, 'Path where model will be exported.')
flags.DEFINE_list(
'input_scales_list', None,
'Optional input image scales to use. If None (default), an input end-point '
'"input_scales" is added for the exported model. If not None, the '
'specified list of floats will be hard-coded as the desired input scales.')
flags.DEFINE_enum(
'multi_scale_pool_type', 'None', ['None', 'average', 'sum'],
"If 'None' (default), the model is exported with an output end-point "
"'global_descriptors', where the global descriptor for each scale is "
"returned separately. If not 'None', the global descriptor of each scale is"
' pooled and a 1D global descriptor is returned, with output end-point '
"'global_descriptor'.")
flags.DEFINE_boolean('normalize_global_descriptor', False,
'If True, L2-normalizes global descriptor.')
def _build_tensor_info(tensor_dict):
"""Replace the dict's value by the tensor info.
Args:
tensor_dict: A dictionary contains <string, tensor>.
Returns:
dict: New dictionary contains <string, tensor_info>.
"""
return {
k: tf.compat.v1.saved_model.utils.build_tensor_info(t)
for k, t in tensor_dict.items()
}
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
export_path = FLAGS.export_path
if os.path.exists(export_path):
raise ValueError('Export_path already exists.')
with tf.Graph().as_default() as g, tf.compat.v1.Session(graph=g) as sess:
# Setup the model for extraction.
model = delf_model.Delf(block3_strides=False, name='DELF')
# Initial forward pass to build model.
images = tf.zeros((1, 321, 321, 3), dtype=tf.float32)
model(images)
# Setup the multiscale extraction.
input_image = tf.compat.v1.placeholder(
tf.uint8, shape=(None, None, 3), name='input_image')
if FLAGS.input_scales_list is None:
input_scales = tf.compat.v1.placeholder(
tf.float32, shape=[None], name='input_scales')
else:
input_scales = tf.constant([float(s) for s in FLAGS.input_scales_list],
dtype=tf.float32,
shape=[len(FLAGS.input_scales_list)],
name='input_scales')
extracted_features = export_model_utils.ExtractGlobalFeatures(
input_image,
input_scales,
lambda x: model.backbone(x, training=False),
multi_scale_pool_type=FLAGS.multi_scale_pool_type,
normalize_global_descriptor=FLAGS.normalize_global_descriptor)
# Load the weights.
checkpoint_path = FLAGS.ckpt_path
model.load_weights(checkpoint_path)
print('Checkpoint loaded from ', checkpoint_path)
named_input_tensors = {'input_image': input_image}
if FLAGS.input_scales_list is None:
named_input_tensors['input_scales'] = input_scales
# Outputs to the exported model.
named_output_tensors = {}
if FLAGS.multi_scale_pool_type == 'None':
named_output_tensors['global_descriptors'] = tf.identity(
extracted_features, name='global_descriptors')
else:
named_output_tensors['global_descriptor'] = tf.identity(
extracted_features, name='global_descriptor')
# Export the model.
signature_def = (
tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs=_build_tensor_info(named_input_tensors),
outputs=_build_tensor_info(named_output_tensors)))
print('Exporting trained model to:', export_path)
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path)
init_op = None
builder.add_meta_graph_and_variables(
sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
signature_def_map={
tf.compat.v1.saved_model.signature_constants
.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
signature_def
},
main_op=init_op)
builder.save()
if __name__ == '__main__':
app.run(main)
......@@ -42,67 +42,39 @@ flags.DEFINE_boolean('block3_strides', False,
flags.DEFINE_float('iou', 1.0, 'IOU for non-max suppression.')
def _build_tensor_info(tensor_dict):
"""Replace the dict's value by the tensor info.
Args:
tensor_dict: A dictionary contains <string, tensor>.
Returns:
dict: New dictionary contains <string, tensor_info>.
"""
return {
k: tf.compat.v1.saved_model.utils.build_tensor_info(t)
for k, t in tensor_dict.items()
}
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
export_path = FLAGS.export_path
if os.path.exists(export_path):
raise ValueError('Export_path already exists.')
with tf.Graph().as_default() as g, tf.compat.v1.Session(graph=g) as sess:
class _ExtractModule(tf.Module):
"""Helper module to build and save DELF model."""
def __init__(self, block3_strides, iou):
"""Initialization of DELF model.
Args:
block3_strides: bool, whether to add strides to the output of block3.
iou: IOU for non-max suppression.
"""
self._stride_factor = 2.0 if block3_strides else 1.0
self._iou = iou
# Setup the DELF model for extraction.
model = delf_model.Delf(block3_strides=FLAGS.block3_strides, name='DELF')
# Initial forward pass to build model.
images = tf.zeros((1, 321, 321, 3), dtype=tf.float32)
model(images)
self._model = delf_model.Delf(
block3_strides=block3_strides, name='DELF')
stride_factor = 2.0 if FLAGS.block3_strides else 1.0
def LoadWeights(self, checkpoint_path):
self._model.load_weights(checkpoint_path)
# Setup the multiscale keypoint extraction.
input_image = tf.compat.v1.placeholder(
tf.uint8, shape=(None, None, 3), name='input_image')
input_abs_thres = tf.compat.v1.placeholder(
tf.float32, shape=(), name='input_abs_thres')
input_scales = tf.compat.v1.placeholder(
tf.float32, shape=[None], name='input_scales')
input_max_feature_num = tf.compat.v1.placeholder(
tf.int32, shape=(), name='input_max_feature_num')
@tf.function(input_signature=[
tf.TensorSpec(shape=[None, None, 3], dtype=tf.uint8, name='input_image'),
tf.TensorSpec(shape=[None], dtype=tf.float32, name='input_scales'),
tf.TensorSpec(shape=(), dtype=tf.int32, name='input_max_feature_num'),
tf.TensorSpec(shape=(), dtype=tf.float32, name='input_abs_thres')
])
def ExtractFeatures(self, input_image, input_scales, input_max_feature_num,
input_abs_thres):
extracted_features = export_model_utils.ExtractLocalFeatures(
input_image, input_scales, input_max_feature_num, input_abs_thres,
FLAGS.iou, lambda x: model(x, training=False), stride_factor)
self._iou, lambda x: self._model(x, training=False),
self._stride_factor)
# Load the weights.
checkpoint_path = FLAGS.ckpt_path
model.load_weights(checkpoint_path)
print('Checkpoint loaded from ', checkpoint_path)
named_input_tensors = {
'input_image': input_image,
'input_scales': input_scales,
'input_abs_thres': input_abs_thres,
'input_max_feature_num': input_max_feature_num,
}
# Outputs to the exported model.
named_output_tensors = {}
named_output_tensors['boxes'] = tf.identity(
extracted_features[0], name='boxes')
......@@ -112,25 +84,27 @@ def main(argv):
extracted_features[2], name='scales')
named_output_tensors['scores'] = tf.identity(
extracted_features[3], name='scores')
return named_output_tensors
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
export_path = FLAGS.export_path
if os.path.exists(export_path):
raise ValueError(f'Export_path {export_path} already exists. Please '
'specify a different path or delete the existing one.')
module = _ExtractModule(FLAGS.block3_strides, FLAGS.iou)
# Load the weights.
checkpoint_path = FLAGS.ckpt_path
module.LoadWeights(checkpoint_path)
print('Checkpoint loaded from ', checkpoint_path)
# Export the model.
signature_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs=_build_tensor_info(named_input_tensors),
outputs=_build_tensor_info(named_output_tensors))
print('Exporting trained model to:', export_path)
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path)
init_op = None
builder.add_meta_graph_and_variables(
sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
signature_def_map={
tf.compat.v1.saved_model.signature_constants
.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
signature_def
},
main_op=init_op)
builder.save()
# Save the module
tf.saved_model.save(module, export_path)
if __name__ == '__main__':
......
......@@ -142,20 +142,21 @@ def ExtractLocalFeatures(image, image_scales, max_feature_num, abs_thres, iou,
keep_going = lambda j, b, f, scales, scores: tf.less(j, num_scales)
(_, output_boxes, output_features, output_scales,
output_scores) = tf.while_loop(
cond=keep_going,
body=_ProcessSingleScale,
loop_vars=[
i, output_boxes, output_features, output_scales, output_scores
],
shape_invariants=[
i.get_shape(),
tf.TensorShape([None, 4]),
tf.TensorShape([None, feature_depth]),
tf.TensorShape([None]),
tf.TensorShape([None])
],
back_prop=False)
output_scores) = tf.nest.map_structure(
tf.stop_gradient,
tf.while_loop(
cond=keep_going,
body=_ProcessSingleScale,
loop_vars=[
i, output_boxes, output_features, output_scales, output_scores
],
shape_invariants=[
i.get_shape(),
tf.TensorShape([None, 4]),
tf.TensorShape([None, feature_depth]),
tf.TensorShape([None]),
tf.TensorShape([None])
]))
feature_boxes = box_list.BoxList(output_boxes)
feature_boxes.add_field('features', output_features)
......@@ -169,3 +170,109 @@ def ExtractLocalFeatures(image, image_scales, max_feature_num, abs_thres, iou,
return final_boxes.get(), final_boxes.get_field(
'features'), final_boxes.get_field('scales'), tf.expand_dims(
final_boxes.get_field('scores'), 1)
def ExtractGlobalFeatures(image,
image_scales,
model_fn,
multi_scale_pool_type='None',
normalize_global_descriptor=False):
"""Extract global features for input image.
Args:
image: image tensor of type tf.uint8 with shape [h, w, channels].
image_scales: 1D float tensor which contains float scales used for image
pyramid construction.
model_fn: model function. Follows the signature:
* Args:
* `images`: Image tensor which is re-scaled.
* Returns:
* `global_descriptors`: Global descriptors for input images.
multi_scale_pool_type: If set, the global descriptor of each scale is pooled
and a 1D global descriptor is returned.
normalize_global_descriptor: If True, output global descriptors are
L2-normalized.
Returns:
global_descriptors: If `multi_scale_pool_type` is 'None', returns a [S, D]
float tensor. S is the number of scales, and D the global descriptor
dimensionality. Each D-dimensional entry is a global descriptor, which may
be L2-normalized depending on `normalize_global_descriptor`. If
`multi_scale_pool_type` is not 'None', returns a [D] float tensor with the
pooled global descriptor.
"""
original_image_shape_float = tf.gather(
tf.dtypes.cast(tf.shape(image), tf.float32), [0, 1])
image_tensor = gld.NormalizeImages(
image, pixel_value_offset=128.0, pixel_value_scale=128.0)
image_tensor = tf.expand_dims(image_tensor, 0, name='image/expand_dims')
def _ProcessSingleScale(scale_index, global_descriptors=None):
"""Resizes the image and runs feature extraction.
This function will be passed into tf.while_loop() and be called
repeatedly. We get the current scale by image_scales[scale_index], and
run image resizing / feature extraction. In the end, we concat the
previous global descriptors with current descriptor as the output.
Args:
scale_index: A valid index in image_scales.
global_descriptors: Global descriptor tensor with the shape of [S, D]. If
None, no previous global descriptors are used, and the output will be of
shape [1, D].
Returns:
scale_index: The next scale index for processing.
global_descriptors: A concatenated global descriptor tensor with the shape
of [S+1, D].
"""
scale = tf.gather(image_scales, scale_index)
new_image_size = tf.dtypes.cast(
tf.round(original_image_shape_float * scale), tf.int32)
resized_image = tf.image.resize(image_tensor, new_image_size)
global_descriptor = model_fn(resized_image)
if global_descriptors is None:
global_descriptors = global_descriptor
else:
global_descriptors = tf.concat([global_descriptors, global_descriptor], 0)
return scale_index + 1, global_descriptors
# Process the first scale separately, the following scales will reuse the
# graph variables.
(_, output_global) = _ProcessSingleScale(0)
i = tf.constant(1, dtype=tf.int32)
num_scales = tf.shape(image_scales)[0]
keep_going = lambda j, g: tf.less(j, num_scales)
(_, output_global) = tf.nest.map_structure(
tf.stop_gradient,
tf.while_loop(
cond=keep_going,
body=_ProcessSingleScale,
loop_vars=[i, output_global],
shape_invariants=[i.get_shape(),
tf.TensorShape([None, None])]))
normalization_axis = 1
if multi_scale_pool_type == 'average':
output_global = tf.reduce_mean(
output_global,
axis=0,
keepdims=False,
name='multi_scale_average_pooling')
normalization_axis = 0
elif multi_scale_pool_type == 'sum':
output_global = tf.reduce_sum(
output_global, axis=0, keepdims=False, name='multi_scale_sum_pooling')
normalization_axis = 0
if normalize_global_descriptor:
output_global = tf.nn.l2_normalize(
output_global, axis=normalization_axis, name='l2_normalization')
return output_global
......@@ -22,9 +22,14 @@ from __future__ import division
from __future__ import print_function
import functools
import os
import tempfile
from absl import logging
import h5py
import tensorflow as tf
layers = tf.keras.layers
......@@ -284,8 +289,8 @@ class ResNet50(tf.keras.Model):
else:
self.global_pooling = None
def call(self, inputs, training=True, intermediates_dict=None):
"""Call the ResNet50 model.
def build_call(self, inputs, training=True, intermediates_dict=None):
"""Building the ResNet50 model.
Args:
inputs: Images to compute features for.
......@@ -356,3 +361,79 @@ class ResNet50(tf.keras.Model):
return self.global_pooling(x)
else:
return x
def call(self, inputs, training=True, intermediates_dict=None):
"""Call the ResNet50 model.
Args:
inputs: Images to compute features for.
training: Whether model is in training phase.
intermediates_dict: `None` or dictionary. If not None, accumulate feature
maps from intermediate blocks into the dictionary.
Returns:
Tensor with featuremap.
"""
return self.build_call(inputs, training, intermediates_dict)
def restore_weights(self, filepath):
"""Load pretrained weights.
This function loads a .h5 file from the filepath with saved model weights
and assigns them to the model.
Args:
filepath: String, path to the .h5 file
Raises:
ValueError: if the file referenced by `filepath` does not exist.
"""
if not tf.io.gfile.exists(filepath):
raise ValueError('Unable to load weights from %s. You must provide a '
'valid file.' % (filepath))
# Create a local copy of the weights file for h5py to be able to read it.
local_filename = os.path.basename(filepath)
tmp_filename = os.path.join(tempfile.gettempdir(), local_filename)
tf.io.gfile.copy(filepath, tmp_filename, overwrite=True)
# Load the content of the weights file.
f = h5py.File(tmp_filename, mode='r')
saved_layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
try:
# Iterate through all the layers assuming the max `depth` is 2.
for layer in self.layers:
if hasattr(layer, 'layers'):
for inlayer in layer.layers:
# Make sure the weights are in the saved model, and that we are in
# the innermost layer.
if inlayer.name not in saved_layer_names:
raise ValueError('Layer %s absent from the pretrained weights. '
'Unable to load its weights.' % (inlayer.name))
if hasattr(inlayer, 'layers'):
raise ValueError('Layer %s is not a depth 2 layer. Unable to load '
'its weights.' % (inlayer.name))
# Assign the weights in the current layer.
g = f[inlayer.name]
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
weight_values = [g[weight_name] for weight_name in weight_names]
print('Setting the weights for layer %s' % (inlayer.name))
inlayer.set_weights(weight_values)
finally:
# Clean up the temporary file.
tf.io.gfile.remove(tmp_filename)
def log_weights(self):
"""Log backbone weights."""
logging.info('Logging backbone weights')
logging.info('------------------------')
for layer in self.layers:
if hasattr(layer, 'layers'):
for inlayer in layer.layers:
logging.info('Weights for layer: %s, inlayer %s', layer.name,
inlayer.name)
weights = inlayer.get_weights()
logging.info(weights)
else:
logging.info('Layer %s does not have inner layers.',
layer.name)
......@@ -43,17 +43,20 @@ flags.DEFINE_string('train_file_pattern', '/tmp/data/train*',
'File pattern of training dataset files.')
flags.DEFINE_string('validation_file_pattern', '/tmp/data/validation*',
'File pattern of validation dataset files.')
flags.DEFINE_enum('dataset_version', 'gld_v1',
['gld_v1', 'gld_v2', 'gld_v2_clean'],
'Google Landmarks dataset version, used to determine the'
'number of classes.')
flags.DEFINE_enum(
'dataset_version', 'gld_v1', ['gld_v1', 'gld_v2', 'gld_v2_clean'],
'Google Landmarks dataset version, used to determine the '
'number of classes.')
flags.DEFINE_integer('seed', 0, 'Seed to training dataset.')
flags.DEFINE_float('initial_lr', 0.001, 'Initial learning rate.')
flags.DEFINE_float('initial_lr', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('batch_size', 32, 'Global batch size.')
flags.DEFINE_integer('max_iters', 500000, 'Maximum iterations.')
flags.DEFINE_boolean('block3_strides', False, 'Whether to use block3_strides.')
flags.DEFINE_boolean('block3_strides', True, 'Whether to use block3_strides.')
flags.DEFINE_boolean('use_augmentation', True,
'Whether to use ImageNet style augmentation.')
flags.DEFINE_string(
'imagenet_checkpoint', None,
'ImageNet checkpoint for ResNet backbone. If None, no checkpoint is used.')
def _record_accuracy(metric, logits, labels):
......@@ -64,6 +67,10 @@ def _record_accuracy(metric, logits, labels):
def _attention_summaries(scores, global_step):
"""Record statistics of the attention score."""
tf.summary.image(
'batch_attention',
scores / tf.reduce_max(scores + 1e-3),
step=global_step)
tf.summary.scalar('attention/max', tf.reduce_max(scores), step=global_step)
tf.summary.scalar('attention/min', tf.reduce_min(scores), step=global_step)
tf.summary.scalar('attention/mean', tf.reduce_mean(scores), step=global_step)
......@@ -124,7 +131,7 @@ def main(argv):
max_iters = FLAGS.max_iters
global_batch_size = FLAGS.batch_size
image_size = 321
num_eval = 1000
num_eval_batches = int(50000 / global_batch_size)
report_interval = 100
eval_interval = 1000
save_interval = 20000
......@@ -134,9 +141,10 @@ def main(argv):
clip_val = tf.constant(10.0)
if FLAGS.debug:
tf.config.run_functions_eagerly(True)
global_batch_size = 4
max_iters = 4
num_eval = 1
max_iters = 100
num_eval_batches = 1
save_interval = 1
report_interval = 1
......@@ -159,11 +167,12 @@ def main(argv):
augmentation=False,
seed=FLAGS.seed)
train_iterator = strategy.make_dataset_iterator(train_dataset)
validation_iterator = strategy.make_dataset_iterator(validation_dataset)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
validation_dist_dataset = strategy.experimental_distribute_dataset(
validation_dataset)
train_iterator.initialize()
validation_iterator.initialize()
train_iter = iter(train_dist_dataset)
validation_iter = iter(validation_dist_dataset)
# Create a checkpoint directory to store the checkpoints.
checkpoint_prefix = os.path.join(FLAGS.logdir, 'delf_tf2-ckpt')
......@@ -219,11 +228,14 @@ def main(argv):
labels = tf.clip_by_value(labels, 0, model.num_classes)
global_step = optimizer.iterations
tf.summary.image('batch_images', (images + 1.0) / 2.0, step=global_step)
tf.summary.scalar(
'image_range/max', tf.reduce_max(images), step=global_step)
tf.summary.scalar(
'image_range/min', tf.reduce_min(images), step=global_step)
# TODO(andrearaujo): we should try to unify the backprop into a single
# function, instead of applying once to descriptor then to attention.
def _backprop_loss(tape, loss, weights):
"""Backpropogate losses using clipped gradients.
......@@ -344,12 +356,25 @@ def main(argv):
with tf.summary.record_if(
tf.math.equal(0, optimizer.iterations % report_interval)):
# TODO(dananghel): try to load pretrained weights at backbone creation.
# Load pretrained weights for ResNet50 trained on ImageNet.
if FLAGS.imagenet_checkpoint is not None:
logging.info('Attempting to load ImageNet pretrained weights.')
input_batch = next(train_iter)
_, _ = distributed_train_step(input_batch)
model.backbone.restore_weights(FLAGS.imagenet_checkpoint)
logging.info('Done.')
else:
logging.info('Skip loading ImageNet pretrained weights.')
if FLAGS.debug:
model.backbone.log_weights()
global_step_value = optimizer.iterations.numpy()
while global_step_value < max_iters:
# input_batch : images(b, h, w, c), labels(b,).
try:
input_batch = train_iterator.get_next()
input_batch = next(train_iter)
except tf.errors.OutOfRangeError:
# Break if we run out of data in the dataset.
logging.info('Stopping training at global step %d, no more data',
......@@ -392,9 +417,9 @@ def main(argv):
# Validate once in {eval_interval*n, n \in N} steps.
if global_step_value % eval_interval == 0:
for i in range(num_eval):
for i in range(num_eval_batches):
try:
validation_batch = validation_iterator.get_next()
validation_batch = next(validation_iter)
desc_validation_result, attn_validation_result = (
distributed_validation_step(validation_batch))
except tf.errors.OutOfRangeError:
......@@ -416,13 +441,17 @@ def main(argv):
print(' : attn:', attn_validation_result.numpy())
# Save checkpoint once (each save_interval*n, n \in N) steps.
# TODO(andrearaujo): save only in one of the two ways. They are
# identical, the only difference is that the manager adds some extra
# prefixes and variables (eg, optimizer variables).
if global_step_value % save_interval == 0:
save_path = manager.save()
logging.info('Saved({global_step_value}) at %s', save_path)
logging.info('Saved (%d) at %s', global_step_value, save_path)
file_path = '%s/delf_weights' % FLAGS.logdir
model.save_weights(file_path, save_format='tf')
logging.info('Saved weights({global_step_value}) at %s', file_path)
logging.info('Saved weights (%d) at %s', global_step_value,
file_path)
# Reset metrics for next step.
desc_train_accuracy.reset_states()
......