Unverified commit a9fcda17 authored by André Araujo, committed by GitHub

Exporting option for DELG local+global feature model (#9198)

* Merged commit includes the following changes:
253126424  by Andre Araujo:

    Scripts to compute metrics for Google Landmarks dataset.

    Also, a small fix to the metric in the retrieval case: avoid duplicate predicted images.

--
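
A minimal sketch of the duplicate-removal idea mentioned in the entry above
(illustrative only, not the repository's metric implementation):

```
def deduplicate_predictions(predicted_ids):
  """Keeps the first occurrence of each predicted image, preserving rank order."""
  seen = set()
  unique_ids = []
  for image_id in predicted_ids:
    if image_id not in seen:
      seen.add(image_id)
      unique_ids.append(image_id)
  return unique_ids


print(deduplicate_predictions(['a', 'b', 'a', 'c']))  # ['a', 'b', 'c']
```
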
253118971  by Andre Araujo:

    Metrics for Google Landmarks dataset.

--
253106953  by Andre Araujo:

    Library to read files from Google Landmarks challenges.

--
250700636  by Andre Araujo:

    Handle case of aggregation extraction with empty set of input features.

--
250516819  by Andre Araujo:

    Add minimum size for DELF extractor.

--
250435822  by Andre Araujo:

    Add max_image_size/min_image_size for open-source DELF proto / module.

--
250414606  by Andre Araujo:

    Refactor extract_aggregation to allow reuse with different datasets.

--
250356863  by Andre Araujo:

    Remove unnecessary cmd_args variable from boxes_and_features_extraction.

--
249783379  by Andre Araujo:

    Create directory for writing mapping file if it does not exist.

--
249581591  by Andre Araujo:

    Refactor scripts to extract boxes and features from images in Revisited datasets.
    Also, change tf.logging.info --> print for easier logging in open source code.

--
249511821  by Andre Araujo:

    Small change to function for file/directory handling.

--
249289499  by Andre Araujo:

    Internal change.

--

PiperOrigin-RevId: 253126424

* Updating DELF init to adjust to latest changes

* Editing init files for python packages

* Edit D2R dataset reader to work with py3.

PiperOrigin-RevId: 253135576

* DELF package: fix import ordering

* Adding new requirements to setup.py

* Adding init file for training dir

* Merged commit includes the following changes:

FolderOrigin-RevId: /google/src/cloud/andrearaujo/delf_oss/google3/..

* Adding init file for training subdirs

* Working version of DELF training

* Internal change.

PiperOrigin-RevId: 253248648

* Fix variance loading in open-source code.

PiperOrigin-RevId: 260619120

* Separate image re-ranking as a standalone library, and add metric writing to dataset library.

PiperOrigin-RevId: 260998608

* Tool to read written D2R Revisited datasets metrics file. Test is added.

Also adds a unit test for previously-existing SaveMetricsFile function.

PiperOrigin-RevId: 263361410

* Add optional resize factor for feature extraction.

PiperOrigin-RevId: 264437080

* Fix spacing changes introduced by NumPy's new version.

PiperOrigin-RevId: 265127245

* Make image matching function visible, and add support for RANSAC seed.

PiperOrigin-RevId: 277177468
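
For reference, a minimal sketch of geometric verification with skimage's RANSAC,
using `random_state` as the seed; the data and thresholds below are illustrative,
not the repository's matching code:

```
import numpy as np
from skimage.measure import ransac
from skimage.transform import AffineTransform

# Toy correspondences; real usage would pass matched DELF keypoint locations.
src = np.random.rand(50, 2) * 100.0
dst = src + np.random.normal(scale=0.5, size=src.shape)

# random_state fixes the RANSAC seed, making the verification reproducible.
model, inliers = ransac((src, dst),
                        AffineTransform,
                        min_samples=3,
                        residual_threshold=5.0,
                        max_trials=1000,
                        random_state=0)
print('Inliers found:', int(inliers.sum()))
```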

* Avoid matplotlib failure due to missing display backend.

PiperOrigin-RevId: 287316435

* Removes tf.contrib dependency.

PiperOrigin-RevId: 288842237

* Fix tf contrib removal for feature_aggregation_extractor.

PiperOrigin-RevId: 289487669

* Merged commit includes the following changes:
309118395  by Andre Araujo:

    Make DELF open-source code compatible with TF2.

--
309067582  by Andre Araujo:

    Handle image resizing rounding properly for python extraction.

    New behavior is tested with unit tests.

--
308690144  by Andre Araujo:

    Several changes to improve DELF model/training code and make it work in TF 2.1.0:
    - Rename some files for better clarity
    - Using compat.v1 versions of functions
    - Formatting changes
    - Using more appropriate TF function names

--
308689397  by Andre Araujo:

    Internal change.

--
308341315  by Andre Araujo:

    Remove old slim dependency in DELF open-source model.

    This avoids issues with requiring old TF-v1, making it compatible with latest TF.

--
306777559  by Andre Araujo:

    Internal change

--
304505811  by Andre Araujo:

    Raise error during geometric verification if local features have different dimensionalities.

--
301739992  by Andre Araujo:

    Transform some geometric verification constants into arguments, to allow custom matching.

--
301300324  by Andre Araujo:

    Apply name change (experimental_run_v2 -> run) for all callers in TensorFlow.

--
299919057  by Andre Araujo:

    Automated refactoring to make code Python 3 compatible.

--
297953698  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297521242  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297278247  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297270405  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297238741  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297108605  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
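
The six entries above refer to the same mechanical change; a minimal sketch of
what it looks like in a migrated file (the placeholder usage is illustrative):

```
# Before (TF 1.x style):
#   import tensorflow as tf

# After, keeping TF 1.x semantics on a TF 2.x installation:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # for code that still relies on graph-mode behavior

# TF 1.x constructs such as placeholders remain available under compat.v1.
x = tf.placeholder(tf.float32, shape=[None, 3])
print(x)
```
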
294676131  by Andre Araujo:

    Add option to resize images to square resolutions without aspect ratio preservation.

--
293849641  by Andre Araujo:

    Internal change.

--
293840896  by Andre Araujo:

    Changing Slim import to tf_slim codebase.

--
293661660  by Andre Araujo:

    Allow the delf training script to read from TFRecords dataset.

--
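
A hedged sketch of reading a TFRecord-based training dataset with tf.data; the
feature keys and file pattern are assumptions for illustration, not necessarily
those used by the DELF training script:

```
import tensorflow as tf


def _parse_example(serialized):
  # Feature keys below are illustrative assumptions.
  features = tf.io.parse_single_example(
      serialized, {
          'image/encoded': tf.io.FixedLenFeature([], tf.string),
          'image/class/label': tf.io.FixedLenFeature([], tf.int64),
      })
  image = tf.io.decode_jpeg(features['image/encoded'], channels=3)
  return image, features['image/class/label']


files = tf.io.gfile.glob('/tmp/gld_tfrecord/train*')  # hypothetical path
dataset = (tf.data.TFRecordDataset(files)
           .map(_parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .shuffle(1024)
           .batch(32))
```
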
291755295  by Andre Araujo:

    Internal change.

--
291448508  by Andre Araujo:

    Internal change.

--
291414459  by Andre Araujo:

    Adding train script.

--
291384336  by Andre Araujo:

    Adding model export script and test.

--
291260565  by Andre Araujo:

    Adding placeholder for Google Landmarks dataset.

--
291205548  by Andre Araujo:

    Definition of DELF model using Keras ResNet50 as backbone.

--
289500793  by Andre Araujo:

    Add TFRecord building script for delf.

--

PiperOrigin-RevId: 309118395

* Updating README, dependency versions

* Updating training README

* Fixing init import of export_model

* Fixing init import of export_model_utils

* Mention tkinter in INSTALL_INSTRUCTIONS

* Merged commit includes the following changes:

FolderOrigin-RevId: /google/src/cloud/andrearaujo/delf_oss/google3/..

* INSTALL_INSTRUCTIONS mentioning different cloning options

* Updating required TF version, since 2.1 is not available in pip

* Internal change.

PiperOrigin-RevId: 309136003

* Fix missing string_input_producer and start_queue_runners in TF2.

PiperOrigin-RevId: 309437512

* Handle RANSAC from skimage's latest versions.

PiperOrigin-RevId: 310170897

* DELF 2.1 version: badge and setup.py updated

* Add TF version badge in INSTALL_INSTRUCTIONS and paper badges in README

* Add paper badges in paper instructions

* Add paper badge to landmark detection instructions

* Small update to DELF training README

* Merged commit includes the following changes:
312614961  by Andre Araujo:

    Instructions/code to reproduce DELG paper results.

--
312523414  by Andre Araujo:

    Fix a minor bug when post-processing extracted features: format config.delf_global_config.image_scales_ind as a list.

--
312340276  by Andre Araujo:

    Add support for global feature extraction in DELF open-source codebase.

--
311031367  by Andre Araujo:

    Add use_square_images as an option in the DELF config. The default value is false; if set, images are resized to a square resolution before feature extraction (e.g., the Starburst use case). Considered having two constructors for DescriptorToImageTemplate, but in the end decided to keep only one, which may be less confusing.

--
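
A minimal sketch of the square-resize behavior described above; the helper name
and target size are illustrative:

```
import tensorflow as tf


def resize_to_square(image, side=512):
  # Resize to a fixed square resolution, deliberately ignoring the original
  # aspect ratio (unlike the default aspect-preserving resizing).
  return tf.image.resize(image, [side, side])


image = tf.zeros([480, 640, 3], dtype=tf.uint8)
print(resize_to_square(image).shape)  # (512, 512, 3)
```
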
310658638  by Andre Araujo:

    Option for producing local feature-based image match visualization.

--

PiperOrigin-RevId: 312614961

* DELF README update / DELG instructions

* DELF README update

* DELG instructions update

* Merged commit includes the following changes:

PiperOrigin-RevId: 312695597

* Merged commit includes the following changes:
312754894  by Andre Araujo:

    Code edits / instructions to reproduce GLDv2 results.

--

PiperOrigin-RevId: 312754894

* Markdown updates after adding GLDv2 content

* Small updates to DELF README

* Clarify that library must be installed before reproducing results

* Merged commit includes the following changes:
319114828  by Andre Araujo:

    Upgrade global feature model exporting to TF2.

--

PiperOrigin-RevId: 319114828

* Properly merging README

* small edits to README

* small edits to README

* small edits to README

* global feature exporting in training README

* Update to DELF README, install instructions

* Centralizing installation instructions

* Small readme update

* Fixing commas

* Mention DELG acceptance into ECCV'20

* Merged commit includes the following changes:
326723075  by Andre Araujo:

    Move image resize utility into utils.py.

--

PiperOrigin-RevId: 326723075

* Adding back matched_images_demo.png

* Merged commit includes the following changes:
327279047  by Andre Araujo:

    Adapt extractor to handle new form of joint local+global extraction.

--
326733524  by Andre Araujo:

    Internal change.

--

PiperOrigin-RevId: 327279047

* Updated DELG instructions after model extraction refactoring

* Updating GLDv2 paper model baseline

* Merged commit includes the following changes:
328982978  by Andre Araujo:

    Updated DELG model training so that the size of the output tensor is unchanged by the GeM pooling layer. Export global model trained with DELG global features.

--
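
For context, a hedged sketch of Generalized Mean (GeM) pooling: only the spatial
dimensions are reduced, so the channel dimensionality of the output descriptor
is unchanged. This is not the exact layer implementation used in the DELG code.

```
import tensorflow as tf


def gem_pool(feature_map, power=3.0, eps=1e-6):
  # feature_map: [batch, height, width, channels].
  # Returns [batch, channels]: spatial dims are pooled, channels are unchanged.
  clipped = tf.maximum(feature_map, eps)
  pooled = tf.reduce_mean(tf.pow(clipped, power), axis=[1, 2])
  return tf.pow(pooled, 1.0 / power)


x = tf.random.uniform([2, 7, 7, 2048])
print(gem_pool(x).shape)  # (2, 2048)
```
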
328218938  by Andre Araujo:

    Internal change.

--

PiperOrigin-RevId: 328982978

* Updated training README after recent changes

* Updated training README to fix small typo

* Merged commit includes the following changes:
330022709  by Andre Araujo:

    Export joint local+global TF2 DELG model, and enable such joint extraction.

    Also, rename export_model.py -> export_local_model.py for better clarity.

    To check that the new exporting code is doing the right thing, I compared features extracted from the new exported model against those extracted from models exported with a single modality, using the same checkpoint. They are identical.

    Some other small changes:
    - small automatic reformatting
    - small documentation improvements

--

PiperOrigin-RevId: 330022709
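
A sketch of calling the jointly exported model's serving signature; the input
and output names mirror the extractor changes in this commit, while the image,
scales and thresholds below are illustrative:

```
import tensorflow as tf

model = tf.saved_model.load('gldv2_model_local_and_global')
serve = model.signatures['serving_default']

outputs = serve(
    input_image=tf.zeros([224, 224, 3], dtype=tf.uint8),  # dummy image
    input_scales=tf.constant([0.70710677, 1.0, 1.4142135], dtype=tf.float32),
    input_max_feature_num=tf.constant(1000, dtype=tf.int32),
    input_abs_thres=tf.constant(175.0, dtype=tf.float32),
    input_global_scales_ind=tf.constant([0, 1, 2], dtype=tf.int32))

# Expected keys: boxes, features, scales, scores, global_descriptors.
print(sorted(outputs.keys()))
```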

* Updated DELG exporting instructions

* Updated DELG exporting instructions: fix small typo
parent 94814c77
@@ -44,12 +44,9 @@ def MakeExtractor(config):
ValueError: if config is invalid.
"""
# Assert the configuration.
# TODO(andrearaujo): Handle this case.
if config.use_global_features and config.use_local_features and hasattr(
config, 'is_tf2_exported') and config.is_tf2_exported:
raise ValueError(
'Joint local+global extraction is currently incompatible with '
'is_tf2_exported')
if not config.use_local_features and not config.use_global_features:
raise ValueError('Invalid config: at least one of '
'{use_local_features, use_global_features} must be True')
# Load model.
model = tf.saved_model.load(config.model_path)
@@ -180,9 +177,22 @@ def MakeExtractor(config):
extracted_features = {}
output = None
if config.use_local_features:
if hasattr(config, 'is_tf2_exported') and config.is_tf2_exported:
predict = model.signatures['serving_default']
if hasattr(config, 'is_tf2_exported') and config.is_tf2_exported:
predict = model.signatures['serving_default']
if config.use_local_features and config.use_global_features:
if config.use_global_features:
output_dict = predict(
input_image=image_tensor,
input_scales=image_scales_tensor,
input_max_feature_num=max_feature_num_tensor,
input_abs_thres=score_threshold_tensor,
input_global_scales_ind=global_scales_ind_tensor)
output = [
output_dict['boxes'], output_dict['features'],
output_dict['scales'], output_dict['scores'],
output_dict['global_descriptors']
]
elif config.use_local_features:
output_dict = predict(
input_image=image_tensor,
input_scales=image_scales_tensor,
@@ -193,21 +203,19 @@
output_dict['scales'], output_dict['scores']
]
else:
if config.use_global_features:
output = model(image_tensor, image_scales_tensor,
score_threshold_tensor, max_feature_num_tensor,
global_scales_ind_tensor)
else:
output = model(image_tensor, image_scales_tensor,
score_threshold_tensor, max_feature_num_tensor)
else:
if hasattr(config, 'is_tf2_exported') and config.is_tf2_exported:
predict = model.signatures['serving_default']
output_dict = predict(
input_image=image_tensor,
input_scales=image_scales_tensor,
input_global_scales_ind=global_scales_ind_tensor)
output = [output_dict['global_descriptors']]
else:
if config.use_local_features and config.use_global_features:
output = model(image_tensor, image_scales_tensor,
score_threshold_tensor, max_feature_num_tensor,
global_scales_ind_tensor)
elif config.use_local_features:
output = model(image_tensor, image_scales_tensor,
score_threshold_tensor, max_feature_num_tensor)
else:
output = model(image_tensor, image_scales_tensor,
global_scales_ind_tensor)
......
@@ -193,7 +193,7 @@ This should be used when you are only interested in having a local feature
model.
```
python3 model/export_model.py \
python3 model/export_local_model.py \
--ckpt_path=gldv2_training/delf_weights \
--export_path=gldv2_model_local \
--block3_strides
@@ -213,7 +213,16 @@ python3 model/export_global_model.py \
### DELG local+global feature model
Work in progress. Stay tuned, this will come soon.
This should be used when you are interested in jointly extracting local and
global features.
```
python3 model/export_local_and_global_model.py \
--ckpt_path=gldv2_training/delf_weights \
--export_path=gldv2_model_local_and_global \
--delg_global_features \
--block3_strides
```
### Kaggle-compatible global feature model
@@ -237,7 +246,9 @@ python3 model/export_global_model.py \
--normalize_global_descriptor
```
## Testing the Trained Model
## Testing the trained model
### Testing the trained local feature model
After the trained model has been exported, it can be used to extract DELF
features from 2 images of the same landmark and to perform a matching test
@@ -310,3 +321,13 @@ python3 ../examples/match_images.py \
The generated image `matched_images.png` should look similar to this one:
![MatchedImagesDemo](./matched_images_demo.png)
### Testing the trained global (or global+local) feature model
Please follow the [DELG instructions](../delg/DELG_INSTRUCTIONS.md). The only
modification should be to pass a different `delf_config_path` when doing feature
extraction, pointing to the newly-trained model. As described in
[DelfConfig](../../protos/delf_config.proto), set `use_local_features` and
`use_global_features` according to which feature modalities you are using. Note
also that you should set `is_tf2_exported` to `true`.
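
As an illustration of the paragraph above, a hedged sketch of the relevant
`DelfConfig` fields set through the Python proto API; field names follow the
linked proto, other fields are omitted, and the model path is the one exported
earlier in this README:

```
from google.protobuf import text_format
from delf import delf_config_pb2

config = text_format.Parse(
    """
    use_local_features: true
    use_global_features: true
    is_tf2_exported: true
    model_path: "gldv2_model_local_and_global"
    """, delf_config_pb2.DelfConfig())
```
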
@@ -89,8 +89,13 @@ class Delf(tf.keras.Model):
from conv_4 are used to compute an attention map of the same resolution.
"""
def __init__(self, block3_strides=True, name='DELF', pooling='avg',
gem_power=3.0, embedding_layer=False, embedding_layer_dim=2048):
def __init__(self,
block3_strides=True,
name='DELF',
pooling='avg',
gem_power=3.0,
embedding_layer=False,
embedding_layer_dim=2048):
"""Initialization of DELF model.
Args:
@@ -98,8 +103,8 @@ class Delf(tf.keras.Model):
name: str, name to identify model.
pooling: str, pooling mode for global feature extraction; possible values
are 'None', 'avg', 'max', 'gem.'
gem_power: float, GeM power for GeM pooling. Only used if
pooling == 'gem'.
gem_power: float, GeM power for GeM pooling. Only used if pooling ==
'gem'.
embedding_layer: bool, whether to create an embedding layer (FC whitening
layer).
embedding_layer_dim: int, size of the embedding layer.
@@ -125,10 +130,8 @@ class Delf(tf.keras.Model):
"""Define classifiers for training backbone and attention models."""
self.num_classes = num_classes
if desc_classification is None:
self.desc_classification = layers.Dense(num_classes,
activation=None,
kernel_regularizer=None,
name='desc_fc')
self.desc_classification = layers.Dense(
num_classes, activation=None, kernel_regularizer=None, name='desc_fc')
else:
self.desc_classification = desc_classification
self.attn_classification = layers.Dense(
@@ -146,13 +149,17 @@ class Delf(tf.keras.Model):
return (self.attention.trainable_weights +
self.attn_classification.trainable_weights)
def call(self, input_image, training=True):
def build_call(self, input_image, training=True):
blocks = {}
self.backbone.build_call(
global_feature = self.backbone.build_call(
input_image, intermediates_dict=blocks, training=training)
features = blocks['block3'] # pytype: disable=key-error
_, probs, _ = self.attention(features, training=training)
return global_feature, probs, features
def call(self, input_image, training=True):
_, probs, features = self.build_call(input_image, training=training)
return probs, features
@@ -15,7 +15,7 @@
# ==============================================================================
"""Export global feature tensorflow inference model.
This model includes image pyramids for multi-scale processing.
The exported model may leverage image pyramids for multi-scale processing.
"""
from __future__ import absolute_import
@@ -52,7 +52,7 @@ flags.DEFINE_enum(
flags.DEFINE_boolean('normalize_global_descriptor', False,
'If True, L2-normalizes global descriptor.')
flags.DEFINE_boolean('delg_global_features', False,
'Whether the model is a DELG model.')
'Whether the model uses a DELG-like global feature head.')
flags.DEFINE_float(
'delg_gem_power', 3.0,
'Power for Generalized Mean pooling. Used only if --delg_global_features'
@@ -83,9 +83,10 @@ class _ExtractModule(tf.Module):
the exported model. If not None, the specified 1D tensor of floats will
be hard-coded as the desired input scales, in conjunction with
ExtractFeaturesFixedScales.
delg_global_features: Whether the model is a DELG model.
delg_gem_power: Power for Generalized Mean pooling in the DELG model.
Used only if 'delg_global_features' is True.
delg_global_features: Whether the model uses a DELG-like global feature
head.
delg_gem_power: Power for Generalized Mean pooling in the DELG model. Used
only if 'delg_global_features' is True.
delg_embedding_layer_dim: Size of the FC whitening layer (embedding
layer). Used only if 'delg_global_features' is True.
"""
@@ -160,10 +161,8 @@ def main(argv):
name='input_scales')
module = _ExtractModule(FLAGS.multi_scale_pool_type,
FLAGS.normalize_global_descriptor,
input_scales_tensor,
FLAGS.delg_global_features,
FLAGS.delg_gem_power,
FLAGS.delg_embedding_layer_dim)
input_scales_tensor, FLAGS.delg_global_features,
FLAGS.delg_gem_power, FLAGS.delg_embedding_layer_dim)
# Load the weights.
checkpoint_path = FLAGS.ckpt_path
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export DELG tensorflow inference model.

The exported model can be used to jointly extract local and global features. It
may use an image pyramid for multi-scale processing, and will include receptive
field calculation and keypoint selection for the local feature head.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
import tensorflow as tf

from delf.python.training.model import delf_model
from delf.python.training.model import delg_model
from delf.python.training.model import export_model_utils

FLAGS = flags.FLAGS

flags.DEFINE_string('ckpt_path', '/tmp/delf-logdir/delf-weights',
                    'Path to saved checkpoint.')
flags.DEFINE_string('export_path', None, 'Path where model will be exported.')
flags.DEFINE_boolean('delg_global_features', True,
                     'Whether the model uses a DELG-like global feature head.')
flags.DEFINE_float(
    'delg_gem_power', 3.0,
    'Power for Generalized Mean pooling. Used only if --delg_global_features '
    'is present.')
flags.DEFINE_integer(
    'delg_embedding_layer_dim', 2048,
    'Size of the FC whitening layer (embedding layer). Used only if '
    '--delg_global_features is present.')
flags.DEFINE_boolean(
    'block3_strides', True,
    'Whether to apply strides after block3, used for local feature head.')
flags.DEFINE_float('iou', 1.0,
                   'IOU for non-max suppression used in local feature head.')


class _ExtractModule(tf.Module):
  """Helper module to build and save DELG model."""

  def __init__(self,
               delg_global_features=True,
               delg_gem_power=3.0,
               delg_embedding_layer_dim=2048,
               block3_strides=True,
               iou=1.0):
    """Initialization of DELG model.

    Args:
      delg_global_features: Whether the model uses a DELG-like global feature
        head.
      delg_gem_power: Power for Generalized Mean pooling in the DELG model. Used
        only if 'delg_global_features' is True.
      delg_embedding_layer_dim: Size of the FC whitening layer (embedding
        layer). Used only if 'delg_global_features' is True.
      block3_strides: bool, whether to add strides to the output of block3.
      iou: IOU for non-max suppression.
    """
    self._stride_factor = 2.0 if block3_strides else 1.0
    self._iou = iou

    # Setup the DELG model for extraction.
    if delg_global_features:
      self._model = delg_model.Delg(
          block3_strides=block3_strides,
          name='DELG',
          gem_power=delg_gem_power,
          embedding_layer_dim=delg_embedding_layer_dim)
    else:
      self._model = delf_model.Delf(block3_strides=block3_strides, name='DELF')

  def LoadWeights(self, checkpoint_path):
    self._model.load_weights(checkpoint_path)

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None, None, 3], dtype=tf.uint8, name='input_image'),
      tf.TensorSpec(shape=[None], dtype=tf.float32, name='input_scales'),
      tf.TensorSpec(shape=(), dtype=tf.int32, name='input_max_feature_num'),
      tf.TensorSpec(shape=(), dtype=tf.float32, name='input_abs_thres'),
      tf.TensorSpec(
          shape=[None], dtype=tf.int32, name='input_global_scales_ind')
  ])
  def ExtractFeatures(self, input_image, input_scales, input_max_feature_num,
                      input_abs_thres, input_global_scales_ind):
    extracted_features = export_model_utils.ExtractLocalAndGlobalFeatures(
        input_image, input_scales, input_max_feature_num, input_abs_thres,
        input_global_scales_ind, self._iou,
        lambda x: self._model.build_call(x, training=False),
        self._stride_factor)

    named_output_tensors = {}
    named_output_tensors['boxes'] = tf.identity(
        extracted_features[0], name='boxes')
    named_output_tensors['features'] = tf.identity(
        extracted_features[1], name='features')
    named_output_tensors['scales'] = tf.identity(
        extracted_features[2], name='scales')
    named_output_tensors['scores'] = tf.identity(
        extracted_features[3], name='scores')
    named_output_tensors['global_descriptors'] = tf.identity(
        extracted_features[4], name='global_descriptors')
    return named_output_tensors


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  export_path = FLAGS.export_path
  if os.path.exists(export_path):
    raise ValueError(f'Export_path {export_path} already exists. Please '
                     'specify a different path or delete the existing one.')

  module = _ExtractModule(FLAGS.delg_global_features, FLAGS.delg_gem_power,
                          FLAGS.delg_embedding_layer_dim, FLAGS.block3_strides,
                          FLAGS.iou)

  # Load the weights.
  checkpoint_path = FLAGS.ckpt_path
  module.LoadWeights(checkpoint_path)
  print('Checkpoint loaded from ', checkpoint_path)

  # Save the module.
  tf.saved_model.save(module, export_path)


if __name__ == '__main__':
  app.run(main)
@@ -15,8 +15,9 @@
# ==============================================================================
"""Export DELF tensorflow inference model.
This model includes feature extraction, receptive field calculation and
key-point selection and outputs the selected feature descriptors.
The exported model may use an image pyramid for multi-scale processing, with
local feature extraction including receptive field calculation and keypoint
selection.
"""
from __future__ import absolute_import
@@ -55,8 +56,7 @@ class _ExtractModule(tf.Module):
self._stride_factor = 2.0 if block3_strides else 1.0
self._iou = iou
# Setup the DELF model for extraction.
self._model = delf_model.Delf(
block3_strides=block3_strides, name='DELF')
self._model = delf_model.Delf(block3_strides=block3_strides, name='DELF')
def LoadWeights(self, checkpoint_path):
self._model.load_weights(checkpoint_path)
......
@@ -27,6 +27,9 @@ from object_detection.core import box_list
from object_detection.core import box_list_ops
# TODO(andrearaujo): Rewrite this function to be more similar to
# "ExtractLocalAndGlobalFeatures" below, leveraging autograph to avoid the need
# for tf.while loop.
def ExtractLocalFeatures(image, image_scales, max_feature_num, abs_thres, iou,
attention_model_fn, stride_factor):
"""Extract local features for input image.
@@ -35,9 +38,9 @@ def ExtractLocalFeatures(image, image_scales, max_feature_num, abs_thres, iou,
image: image tensor of type tf.uint8 with shape [h, w, channels].
image_scales: 1D float tensor which contains float scales used for image
pyramid construction.
max_feature_num: int tensor denotes the maximum selected feature points.
abs_thres: float tensor denotes the score threshold for feature selection.
iou: float scalar denotes the iou threshold for NMS.
max_feature_num: int tensor denoting the maximum selected feature points.
abs_thres: float tensor denoting the score threshold for feature selection.
iou: float scalar denoting the iou threshold for NMS.
attention_model_fn: model function. Follows the signature:
* Args:
* `images`: Image tensor which is re-scaled.
@@ -55,7 +58,7 @@ def ExtractLocalFeatures(image, image_scales, max_feature_num, abs_thres, iou,
scales such that larger image scales correspond to larger image regions,
which is compatible with keypoints detected with other techniques, for
example Congas.
scores: [N, 1] float tensor denotes the attention score.
scores: [N, 1] float tensor denoting the attention score.
"""
original_image_shape_float = tf.gather(
@@ -66,6 +69,8 @@ def ExtractLocalFeatures(image, image_scales, max_feature_num, abs_thres, iou,
image_tensor = tf.expand_dims(image_tensor, 0, name='image/expand_dims')
# Hard code the feature depth and receptive field parameters for now.
# We need to revisit this once we change the architecture and selected
# convolutional blocks to use as local features.
rf, stride, padding = [291.0, 16.0 * stride_factor, 145.0]
feature_depth = 1024
@@ -189,7 +194,7 @@ def ExtractGlobalFeatures(image,
`image_scales`, those with corresponding indices from this tensor.
model_fn: model function. Follows the signature:
* Args:
* `images`: Image tensor which is re-scaled.
* `images`: Batched image tensor.
* Returns:
* `global_descriptors`: Global descriptors for input images.
multi_scale_pool_type: If set, the global descriptor of each scale is pooled
@@ -266,3 +271,138 @@
output_global, axis=normalization_axis, name='l2_normalization')
return output_global
@tf.function
def ExtractLocalAndGlobalFeatures(image, image_scales, max_feature_num,
                                  abs_thres, global_scales_ind, iou, model_fn,
                                  stride_factor):
  """Extract local+global features for input image.

  Args:
    image: image tensor of type tf.uint8 with shape [h, w, channels].
    image_scales: 1D float tensor which contains float scales used for image
      pyramid construction.
    max_feature_num: int tensor denoting the maximum selected feature points.
    abs_thres: float tensor denoting the score threshold for feature selection.
    global_scales_ind: Global feature extraction happens only for a subset of
      `image_scales`, those with corresponding indices from this tensor.
    iou: float scalar denoting the iou threshold for NMS.
    model_fn: model function. Follows the signature:
      * Args:
        * `images`: Batched image tensor.
      * Returns:
        * `global_descriptors`: Global descriptors for input images.
        * `attention_prob`: Attention map after the non-linearity.
        * `feature_map`: Feature map after ResNet convolution.
    stride_factor: integer accounting for striding after block3.

  Returns:
    boxes: [N, 4] float tensor which denotes the selected receptive boxes. N is
      the number of final feature points which pass through keypoint selection
      and NMS steps.
    local_descriptors: [N, depth] float tensor.
    feature_scales: [N] float tensor. It is the inverse of the input image
      scales such that larger image scales correspond to larger image regions,
      which is compatible with keypoints detected with other techniques, for
      example Congas.
    scores: [N, 1] float tensor denoting the attention score.
    global_descriptors: [S, D] float tensor, with the global descriptors for
      each scale; S is the number of scales, and D the global descriptor
      dimensionality.
  """
  original_image_shape_float = tf.gather(
      tf.dtypes.cast(tf.shape(image), tf.float32), [0, 1])
  image_tensor = gld.NormalizeImages(
      image, pixel_value_offset=128.0, pixel_value_scale=128.0)
  image_tensor = tf.expand_dims(image_tensor, 0, name='image/expand_dims')

  # Hard code the receptive field parameters for now.
  # We need to revisit this once we change the architecture and selected
  # convolutional blocks to use as local features.
  rf, stride, padding = [291.0, 16.0 * stride_factor, 145.0]

  def _ResizeAndExtract(scale_index):
    """Helper function to resize image then extract features.

    Args:
      scale_index: A valid index in image_scales.

    Returns:
      global_descriptor: [1,D] tensor denoting the extracted global descriptor.
      boxes: Box tensor with the shape of [K, 4].
      local_descriptors: Local descriptor tensor with the shape of [K, depth].
      scales: Scale tensor with the shape of [K].
      scores: Score tensor with the shape of [K].
    """
    scale = tf.gather(image_scales, scale_index)
    new_image_size = tf.dtypes.cast(
        tf.round(original_image_shape_float * scale), tf.int32)
    resized_image = tf.image.resize(image_tensor, new_image_size)
    global_descriptor, attention_prob, feature_map = model_fn(resized_image)

    attention_prob = tf.squeeze(attention_prob, axis=[0])
    feature_map = tf.squeeze(feature_map, axis=[0])

    # Compute RF boxes and re-project them to the original image space.
    rf_boxes = feature_extractor.CalculateReceptiveBoxes(
        tf.shape(feature_map)[0],
        tf.shape(feature_map)[1], rf, stride, padding)
    rf_boxes = tf.divide(rf_boxes, scale)

    attention_prob = tf.reshape(attention_prob, [-1])
    feature_map = tf.reshape(feature_map, [-1, tf.shape(feature_map)[2]])

    # Use attention score to select local features.
    indices = tf.reshape(tf.where(attention_prob >= abs_thres), [-1])
    boxes = tf.gather(rf_boxes, indices)
    local_descriptors = tf.gather(feature_map, indices)
    scores = tf.gather(attention_prob, indices)
    scales = tf.ones_like(scores, tf.float32) / scale

    return global_descriptor, boxes, local_descriptors, scales, scores

  # TODO(andrearaujo): Currently, a global feature is extracted even for scales
  # which are not using it. The obtained result is correct, however feature
  # extraction is slower than expected. We should try to fix this in the future.
  # Run first scale.
  (output_global_descriptors, output_boxes, output_local_descriptors,
   output_scales, output_scores) = _ResizeAndExtract(0)
  if not tf.reduce_any(tf.equal(global_scales_ind, 0)):
    # If global descriptor is not using the first scale, clear it out.
    output_global_descriptors = tf.zeros(
        [0, tf.shape(output_global_descriptors)[1]])

  # Loop over subsequent scales.
  num_scales = tf.shape(image_scales)[0]
  for scale_index in tf.range(1, num_scales):
    # Allow an undefined number of global feature scales to be extracted.
    tf.autograph.experimental.set_loop_options(
        shape_invariants=[(output_global_descriptors,
                           tf.TensorShape([None, None]))])

    (global_descriptor, boxes, local_descriptors, scales,
     scores) = _ResizeAndExtract(scale_index)
    output_boxes = tf.concat([output_boxes, boxes], 0)
    output_local_descriptors = tf.concat(
        [output_local_descriptors, local_descriptors], 0)
    output_scales = tf.concat([output_scales, scales], 0)
    output_scores = tf.concat([output_scores, scores], 0)
    if tf.reduce_any(tf.equal(global_scales_ind, scale_index)):
      output_global_descriptors = tf.concat(
          [output_global_descriptors, global_descriptor], 0)

  feature_boxes = box_list.BoxList(output_boxes)
  feature_boxes.add_field('local_descriptors', output_local_descriptors)
  feature_boxes.add_field('scales', output_scales)
  feature_boxes.add_field('scores', output_scores)

  nms_max_boxes = tf.minimum(max_feature_num, feature_boxes.num_boxes())
  final_boxes = box_list_ops.non_max_suppression(feature_boxes, iou,
                                                 nms_max_boxes)

  return (final_boxes.get(), final_boxes.get_field('local_descriptors'),
          final_boxes.get_field('scales'),
          tf.expand_dims(final_boxes.get_field('scores'), 1),
          output_global_descriptors)