Unverified Commit 81e4a36c authored by André Araujo's avatar André Araujo Committed by GitHub
Browse files

Exporting code for DELG global feature (#9191)

* Merged commit includes the following changes:
253126424  by Andre Araujo:

    Scripts to compute metrics for Google Landmarks dataset.

    Also, a small fix to metric in retrieval case: avoids duplicate predicted images.

--
253118971  by Andre Araujo:

    Metrics for Google Landmarks dataset.

--
253106953  by Andre Araujo:

    Library to read files from Google Landmarks challenges.

--
250700636  by Andre Araujo:

    Handle case of aggregation extraction with empty set of input features.

--
250516819  by Andre Araujo:

    Add minimum size for DELF extractor.

--
250435822  by Andre Araujo:

    Add max_image_size/min_image_size for open-source DELF proto / module.

--
250414606  by Andre Araujo:

    Refactor extract_aggregation to allow reuse with different datasets.

--
250356863  by Andre Araujo:

    Remove unnecessary cmd_args variable from boxes_and_features_extraction.

--
249783379  by Andre Araujo:

    Create directory for writing mapping file if it does not exist.

--
249581591  by Andre Araujo:

    Refactor scripts to extract boxes and features from images in Revisited datasets.
    Also, change tf.logging.info --> print for easier logging in open source code.

--
249511821  by Andre Araujo:

    Small change to function for file/directory handling.

--
249289499  by Andre Araujo:

    Internal change.

--

PiperOrigin-RevId: 253126424

* Updating DELF init to adjust to latest changes

* Editing init files for python packages

* Edit D2R dataset reader to work with py3.

PiperOrigin-RevId: 253135576

* DELF package: fix import ordering

* Adding new requirements to setup.py

* Adding init file for training dir

* Merged commit includes the following changes:

FolderOrigin-RevId: /google/src/cloud/andrearaujo/delf_oss/google3/..

* Adding init file for training subdirs

* Working version of DELF training

* Internal change.

PiperOrigin-RevId: 253248648

* Fix variance loading in open-source code.

PiperOrigin-RevId: 260619120

* Separate image re-ranking as a standalone library, and add metric writing to dataset library.

PiperOrigin-RevId: 260998608

* Tool to read written D2R Revisited datasets metrics file. Test is added.

Also adds a unit test for previously-existing SaveMetricsFile function.

PiperOrigin-RevId: 263361410

* Add optional resize factor for feature extraction.

PiperOrigin-RevId: 264437080

* Fix NumPy's new version spacing changes.

PiperOrigin-RevId: 265127245

* Maker image matching function visible, and add support for RANSAC seed.

PiperOrigin-RevId: 277177468

* Avoid matplotlib failure due to missing display backend.

PiperOrigin-RevId: 287316435

* Removes tf.contrib dependency.

PiperOrigin-RevId: 288842237

* Fix tf contrib removal for feature_aggregation_extractor.

PiperOrigin-RevId: 289487669

* Merged commit includes the following changes:
309118395  by Andre Araujo:

    Make DELF open-source code compatible with TF2.

--
309067582  by Andre Araujo:

    Handle image resizing rounding properly for python extraction.

    New behavior is tested with unit tests.

--
308690144  by Andre Araujo:

    Several changes to improve DELF model/training code and make it work in TF 2.1.0:
    - Rename some files for better clarity
    - Using compat.v1 versions of functions
    - Formatting changes
    - Using more appropriate TF function names

--
308689397  by Andre Araujo:

    Internal change.

--
308341315  by Andre Araujo:

    Remove old slim dependency in DELF open-source model.

    This avoids issues with requiring old TF-v1, making it compatible with latest TF.

--
306777559  by Andre Araujo:

    Internal change

--
304505811  by Andre Araujo:

    Raise error during geometric verification if local features have different dimensionalities.

--
301739992  by Andre Araujo:

    Transform some geometric verification constants into arguments, to allow custom matching.

--
301300324  by Andre Araujo:

    Apply name change(experimental_run_v2 -> run) for all callers in Tensorflow.

--
299919057  by Andre Araujo:

    Automated refactoring to make code Python 3 compatible.

--
297953698  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297521242  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297278247  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297270405  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297238741  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297108605  by Andre Araujo:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
294676131  by Andre Araujo:

    Add option to resize images to square resolutions without aspect ratio preservation.

--
293849641  by Andre Araujo:

    Internal change.

--
293840896  by Andre Araujo:

    Changing Slim import to tf_slim codebase.

--
293661660  by Andre Araujo:

    Allow the delf training script to read from TFRecords dataset.

--
291755295  by Andre Araujo:

    Internal change.

--
291448508  by Andre Araujo:

    Internal change.

--
291414459  by Andre Araujo:

    Adding train script.

--
291384336  by Andre Araujo:

    Adding model export script and test.

--
291260565  by Andre Araujo:

    Adding placeholder for Google Landmarks dataset.

--
291205548  by Andre Araujo:

    Definition of DELF model using Keras ResNet50 as backbone.

--
289500793  by Andre Araujo:

    Add TFRecord building script for delf.

--

PiperOrigin-RevId: 309118395

* Updating README, dependency versions

* Updating training README

* Fixing init import of export_model

* Fixing init import of export_model_utils

* tkinter in INSTALL_INSTRUCTIONS

* Merged commit includes the following changes:

FolderOrigin-RevId: /google/src/cloud/andrearaujo/delf_oss/google3/..

* INSTALL_INSTRUCTIONS mentioning different cloning options

* Updating required TF version, since 2.1 is not available in pip

* Internal change.

PiperOrigin-RevId: 309136003

* Fix missing string_input_producer and start_queue_runners in TF2.

PiperOrigin-RevId: 309437512

* Handle RANSAC from skimage's latest versions.

PiperOrigin-RevId: 310170897

* DELF 2.1 version: badge and setup.py updated

* Add TF version badge in INSTALL_INSTRUCTIONS and paper badges in README

* Add paper badges in paper instructions

* Add paper badge to landmark detection instructions

* Small update to DELF training README

* Merged commit includes the following changes:
312614961  by Andre Araujo:

    Instructions/code to reproduce DELG paper results.

--
312523414  by Andre Araujo:

    Fix a minor bug when post-process extracted features, format config.delf_global_config.image_scales_ind to a list.

--
312340276  by Andre Araujo:

    Add support for global feature extraction in DELF open-source codebase.

--
311031367  by Andre Araujo:

    Add use_square_images as an option in DELF config. The default value is false. if it is set, then images are resized to square resolution before feature extraction (e.g. Starburst use case. ) Thought for a while, whether to have two constructor of DescriptorToImageTemplate, but in the end, decide to only keep one, may be less confusing.

--
310658638  by Andre Araujo:

    Option for producing local feature-based image match visualization.

--

PiperOrigin-RevId: 312614961

* DELF README update / DELG instructions

* DELF README update

* DELG instructions update

* Merged commit includes the following changes:

PiperOrigin-RevId: 312695597

* Merged commit includes the following changes:
312754894  by Andre Araujo:

    Code edits / instructions to reproduce GLDv2 results.

--

PiperOrigin-RevId: 312754894

* Markdown updates after adding GLDv2 stuff

* Small updates to DELF README

* Clarify that library must be installed before reproducing results

* Merged commit includes the following changes:
319114828  by Andre Araujo:

    Upgrade global feature model exporting to TF2.

--

PiperOrigin-RevId: 319114828

* Properly merging README

* small edits to README

* small edits to README

* small edits to README

* global feature exporting in training README

* Update to DELF README, install instructions

* Centralizing installation instructions

* Small readme update

* Fixing commas

* Mention DELG acceptance into ECCV'20

* Merged commit includes the following changes:
326723075  by Andre Araujo:

    Move image resize utility into utils.py.

--

PiperOrigin-RevId: 326723075

* Adding back matched_images_demo.png

* Merged commit includes the following changes:
327279047  by Andre Araujo:

    Adapt extractor to handle new form of joint local+global extraction.

--
326733524  by Andre Araujo:

    Internal change.

--

PiperOrigin-RevId: 327279047

* Updated DELG instructions after model extraction refactoring

* Updating GLDv2 paper model baseline

* Merged commit includes the following changes:
328982978  by Andre Araujo:

    Updated DELG model training so that the size of the output tensor is unchanged by the GeM pooling layer. Export global model trained with DELG global features.

--
328218938  by Andre Araujo:

    Internal change.

--

PiperOrigin-RevId: 328982978

* Updated training README after recent changes

* Updated training README to fix small typo
parent 59d3d2a3
# DELF Training Instructions # DELF/DELG Training Instructions
This README documents the end-to-end process for training a landmark detection This README documents the end-to-end process for training a local and/or global
and retrieval model using the DELF library on the image feature model on the
[Google Landmarks Dataset v2](https://github.com/cvdfoundation/google-landmark) [Google Landmarks Dataset v2](https://github.com/cvdfoundation/google-landmark)
(GLDv2). This can be achieved following these steps: (GLDv2). This can be achieved following these steps:
...@@ -166,7 +166,7 @@ the batch size to `256`: ...@@ -166,7 +166,7 @@ the batch size to `256`:
It is also possible to train the model with an improved global features head as It is also possible to train the model with an improved global features head as
introduced in the [DELG paper](https://arxiv.org/abs/2001.05027). To do this, introduced in the [DELG paper](https://arxiv.org/abs/2001.05027). To do this,
specify the additional parameter `--delg_global_features` when launching the specify the additional parameter `--delg_global_features` when launching the
training, like in the following example: training, like in the following example:
``` ```
...@@ -179,13 +179,18 @@ python3 train.py \ ...@@ -179,13 +179,18 @@ python3 train.py \
--delg_global_features --delg_global_features
``` ```
*NOTE*: We are currently working on adding the autoencoder described in the DELG
paper to this codebase. Currently, it is not yet implemented here. Stay tuned!
## Exporting the Trained Model ## Exporting the Trained Model
Assuming the training output, the TensorFlow checkpoint, is in the Assuming the training output, the TensorFlow checkpoint, is in the
`gldv2_training` directory, running the following commands exports the model. `gldv2_training` directory, running the following commands exports the model.
### DELF local feature model ### DELF local feature-only model
This should be used when you are only interested in having a local feature
model.
``` ```
python3 model/export_model.py \ python3 model/export_model.py \
...@@ -194,12 +199,35 @@ python3 model/export_model.py \ ...@@ -194,12 +199,35 @@ python3 model/export_model.py \
--block3_strides --block3_strides
``` ```
### DELG global feature-only model
This should be used when you are only interested in having a global feature
model.
```
python3 model/export_global_model.py \
--ckpt_path=gldv2_training/delf_weights \
--export_path=gldv2_model_global \
--delg_global_features
```
### DELG local+global feature model
Work in progress. Stay tuned, this will come soon.
### Kaggle-compatible global feature model ### Kaggle-compatible global feature model
To export a global feature model in the format required by the To export a global feature model in the format required by the
[2020 Landmark Retrieval challenge](https://www.kaggle.com/c/landmark-retrieval-2020), [2020 Landmark Retrieval challenge](https://www.kaggle.com/c/landmark-retrieval-2020),
you can use the following command: you can use the following command:
*NOTE*: this command is helpful to use the model directly in the above-mentioned
Kaggle competition; however, this is a different format than the one required in
this DELF/DELG codebase (ie, if you export the model this way, the commands
found in the [DELG instructions](../delg/DELG_INSTRUCTIONS.md) would not work).
To export the model in a manner compatible to this codebase, use a similar
command as the "DELG global feature-only model" above.
``` ```
python3 model/export_global_model.py \ python3 model/export_global_model.py \
--ckpt_path=gldv2_training/delf_weights \ --ckpt_path=gldv2_training/delf_weights \
......
...@@ -107,7 +107,7 @@ def cosine_classifier_logits(prelogits, ...@@ -107,7 +107,7 @@ def cosine_classifier_logits(prelogits,
"""Compute cosine classifier logits using ArFace margin. """Compute cosine classifier logits using ArFace margin.
Args: Args:
prelogits: float tensor of shape [batch_size, 1, 1, embedding_layer_dim]. prelogits: float tensor of shape [batch_size, embedding_layer_dim].
labels: int tensor of shape [batch_size]. labels: int tensor of shape [batch_size].
num_classes: int, number of classes. num_classes: int, number of classes.
cosine_weights: float tensor of shape [embedding_layer_dim, num_classes]. cosine_weights: float tensor of shape [embedding_layer_dim, num_classes].
...@@ -118,11 +118,8 @@ def cosine_classifier_logits(prelogits, ...@@ -118,11 +118,8 @@ def cosine_classifier_logits(prelogits,
Returns: Returns:
logits: Float tensor [batch_size, num_classes]. logits: Float tensor [batch_size, num_classes].
""" """
# Reshape from [batch_size, 1, 1, depth] to [batch_size, depth].
squeezed_prelogits = tf.squeeze(prelogits, [1, 2])
# L2-normalize prelogits, then obtain cosine similarity. # L2-normalize prelogits, then obtain cosine similarity.
normalized_prelogits = tf.math.l2_normalize(squeezed_prelogits, axis=1) normalized_prelogits = tf.math.l2_normalize(prelogits, axis=1)
normalized_weights = tf.math.l2_normalize(cosine_weights, axis=0) normalized_weights = tf.math.l2_normalize(cosine_weights, axis=0)
cosine_sim = tf.matmul(normalized_prelogits, normalized_weights) cosine_sim = tf.matmul(normalized_prelogits, normalized_weights)
......
...@@ -29,6 +29,7 @@ from absl import flags ...@@ -29,6 +29,7 @@ from absl import flags
import tensorflow as tf import tensorflow as tf
from delf.python.training.model import delf_model from delf.python.training.model import delf_model
from delf.python.training.model import delg_model
from delf.python.training.model import export_model_utils from delf.python.training.model import export_model_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -50,6 +51,16 @@ flags.DEFINE_enum( ...@@ -50,6 +51,16 @@ flags.DEFINE_enum(
"'global_descriptor'.") "'global_descriptor'.")
flags.DEFINE_boolean('normalize_global_descriptor', False, flags.DEFINE_boolean('normalize_global_descriptor', False,
'If True, L2-normalizes global descriptor.') 'If True, L2-normalizes global descriptor.')
flags.DEFINE_boolean('delg_global_features', False,
'Whether the model is a DELG model.')
flags.DEFINE_float(
'delg_gem_power', 3.0,
'Power for Generalized Mean pooling. Used only if --delg_global_features'
'is present.')
flags.DEFINE_integer(
'delg_embedding_layer_dim', 2048,
'Size of the FC whitening layer (embedding layer). Used only if'
'--delg_global_features is present.')
class _ExtractModule(tf.Module): class _ExtractModule(tf.Module):
...@@ -58,7 +69,10 @@ class _ExtractModule(tf.Module): ...@@ -58,7 +69,10 @@ class _ExtractModule(tf.Module):
def __init__(self, def __init__(self,
multi_scale_pool_type='None', multi_scale_pool_type='None',
normalize_global_descriptor=False, normalize_global_descriptor=False,
input_scales_tensor=None): input_scales_tensor=None,
delg_global_features=False,
delg_gem_power=3.0,
delg_embedding_layer_dim=2048):
"""Initialization of global feature model. """Initialization of global feature model.
Args: Args:
...@@ -69,6 +83,11 @@ class _ExtractModule(tf.Module): ...@@ -69,6 +83,11 @@ class _ExtractModule(tf.Module):
the exported model. If not None, the specified 1D tensor of floats will the exported model. If not None, the specified 1D tensor of floats will
be hard-coded as the desired input scales, in conjunction with be hard-coded as the desired input scales, in conjunction with
ExtractFeaturesFixedScales. ExtractFeaturesFixedScales.
delg_global_features: Whether the model is a DELG model.
delg_gem_power: Power for Generalized Mean pooling in the DELG model.
Used only if 'delg_global_features' is True.
delg_embedding_layer_dim: Size of the FC whitening layer (embedding
layer). Used only if 'delg_global_features' is True.
""" """
self._multi_scale_pool_type = multi_scale_pool_type self._multi_scale_pool_type = multi_scale_pool_type
self._normalize_global_descriptor = normalize_global_descriptor self._normalize_global_descriptor = normalize_global_descriptor
...@@ -78,7 +97,14 @@ class _ExtractModule(tf.Module): ...@@ -78,7 +97,14 @@ class _ExtractModule(tf.Module):
self._input_scales_tensor = input_scales_tensor self._input_scales_tensor = input_scales_tensor
# Setup the DELF model for extraction. # Setup the DELF model for extraction.
self._model = delf_model.Delf(block3_strides=False, name='DELF') if delg_global_features:
self._model = delg_model.Delg(
block3_strides=False,
name='DELG',
gem_power=delg_gem_power,
embedding_layer_dim=delg_embedding_layer_dim)
else:
self._model = delf_model.Delf(block3_strides=False, name='DELF')
def LoadWeights(self, checkpoint_path): def LoadWeights(self, checkpoint_path):
self._model.load_weights(checkpoint_path) self._model.load_weights(checkpoint_path)
...@@ -134,7 +160,10 @@ def main(argv): ...@@ -134,7 +160,10 @@ def main(argv):
name='input_scales') name='input_scales')
module = _ExtractModule(FLAGS.multi_scale_pool_type, module = _ExtractModule(FLAGS.multi_scale_pool_type,
FLAGS.normalize_global_descriptor, FLAGS.normalize_global_descriptor,
input_scales_tensor) input_scales_tensor,
FLAGS.delg_global_features,
FLAGS.delg_gem_power,
FLAGS.delg_embedding_layer_dim)
# Load the weights. # Load the weights.
checkpoint_path = FLAGS.ckpt_path checkpoint_path = FLAGS.ckpt_path
......
...@@ -473,12 +473,10 @@ def gem_pooling(feature_map, axis, power, threshold=1e-6): ...@@ -473,12 +473,10 @@ def gem_pooling(feature_map, axis, power, threshold=1e-6):
threshold: Optional float, threshold to use for activations. threshold: Optional float, threshold to use for activations.
Returns: Returns:
pooled_feature_map: Tensor of shape [batch, 1, 1, channels] for the pooled_feature_map: Tensor of shape [batch, channels].
"channels_last" format or [batch, channels, 1, 1] for the
"channels_first" format.
""" """
return tf.pow( return tf.pow(
tf.reduce_mean(tf.pow(tf.maximum(feature_map, threshold), power), tf.reduce_mean(tf.pow(tf.maximum(feature_map, threshold), power),
axis=axis, axis=axis,
keepdims=True), keepdims=False),
1.0 / power) 1.0 / power)
...@@ -42,8 +42,8 @@ class Resnet50Test(tf.test.TestCase): ...@@ -42,8 +42,8 @@ class Resnet50Test(tf.test.TestCase):
threshold=threshold) threshold=threshold)
# Define expected result. # Define expected result.
expected_pooled_feature_map = np.array([[[[0.707107, 1.414214]]], expected_pooled_feature_map = np.array([[0.707107, 1.414214],
[[[1.0, 70.710678]]]], [1.0, 70.710678]],
dtype=float) dtype=float)
# Compare actual and expected. # Compare actual and expected.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment