Unverified Commit 4b8f7d47 authored by Dan Anghel's avatar Dan Anghel Committed by GitHub
Browse files

Autoencoder in the DELG model (#9555)



* Merged commit includes the following changes:
326369548  by Andre Araujo:

    Fix import issues.

--
326159826  by Andre Araujo:

    Changed the implementation of the cosine weights from Keras layer to tf.Variable to manually control for L2 normalization.

--
326139082  by Andre Araujo:

    Support local feature matching using ratio test.

    To allow for easily choosing which matching type to use, we rename a flag/argument and modify all related files to avoid breakages.

    Also include a small change when computing nearest neighbors for geometric matching, to parallelize computation, which saves a little bit of time during execution (argument "n_jobs=-1").

--
326119848  by Andre Araujo:

    Option to measure DELG latency taking binarization into account.

--
324316608  by Andre Araujo:

    DELG global features training.

--
323693131  by Andre Araujo:

    PY3 conversion for delf public lib.

--
321046157  by Andre Araujo:

    Purely Google refactor

--

PiperOrigin-RevId: 326369548

* Added export of delg_model module.

* README update to explain training DELG global features head

* Added guidelines for DELF hyperparameter values

* Fixed typo

* Added mention about remaining training flags.

* Merged commit includes the following changes:
334723489  by Andre Araujo:

    Backpropagate global and attention layers together.

--
334228310  by Andre Araujo:

    Enable scaling of local feature locations to the resized resolution.

--

PiperOrigin-RevId: 334723489

* Merged commit includes the following changes:
347032253  by Andre Araujo:

    Updated local and global_and_local model export scripts for exporting models trained with the autoencoder layer.

--
344312455  by Andre Araujo:

    Implement autoencoder in training pipeline.

--
341116593  by Andre Araujo:

    Reduce the default save_interval, to get more frequent checkpoints.

--
341111808  by Andre Araujo:

    Allow checkpoint restoration in DELF training, to enable resuming of training jobs.

--
340138315  by Andre Araujo:

    DELF training script: make it always save the last checkpoint.

--
338731551  by Andre Araujo:

    Add image_size flag in DELF/G OSS training script.

--
338684879  by Andre Araujo:

    Clean up summaries in DELF/G training script.

    - Previously, the call to tf.summary.record_if() was not working, which led to summaries being recorded at every step, leading to too large events files. This is fixed.
    - Previously, some summaries were computed at iteration k, while others at iteration k+1. Now, we standardize summary computations to always run after backpropagation (which means that summaries are reported for step k+1, referring to the batch k).
    - Added a new summary: number of global steps per second; useful to see how fast training is making progress.

    Also a few other small modifications are included:
    - Improved description of the train.py script.
    - Some small automatic reformattings.

--

PiperOrigin-RevId: 347032253
Co-authored-by: default avatarAndre Araujo <andrearaujo@google.com>
parent a58cd931
...@@ -80,6 +80,38 @@ class AttentionModel(tf.keras.Model): ...@@ -80,6 +80,38 @@ class AttentionModel(tf.keras.Model):
return feat, prob, score return feat, prob, score
class AutoencoderModel(tf.keras.Model):
  """Instantiates the Keras Autoencoder model.

  Two stacked Conv2D layers form a bottleneck: the first projects the input
  feature map down to `reduced_dimension` channels, the second expands it
  back to `expand_dimension` channels.
  """

  def __init__(self, reduced_dimension, expand_dimension, kernel_size=1,
               name='autoencoder'):
    """Initialization of Autoencoder model.

    Args:
      reduced_dimension: int, the output dimension of the autoencoder layer.
      expand_dimension: int, the input dimension of the autoencoder layer.
      kernel_size: int or tuple, height and width of the 2D convolution window.
      name: str, name to identify model.
    """
    super(AutoencoderModel, self).__init__(name=name)
    # Bottleneck projection to `reduced_dimension` channels. No activation is
    # applied here, so the reduced features are a linear map of the input.
    self.conv1 = layers.Conv2D(
        reduced_dimension,
        kernel_size,
        padding='same',
        name='autoenc_conv1')
    # Expansion back to `expand_dimension` channels, with ReLU activation.
    self.conv2 = layers.Conv2D(
        expand_dimension,
        kernel_size,
        activation=tf.keras.activations.relu,
        padding='same',
        name='autoenc_conv2')

  def call(self, inputs):
    """Runs the autoencoder.

    Args:
      inputs: image tensor (feature map to be compressed and reconstructed).

    Returns:
      dim_expanded_features: the reconstructed feature map, with
        `expand_dimension` channels.
      dim_reduced_features: the bottleneck feature map, with
        `reduced_dimension` channels.
    """
    dim_reduced_features = self.conv1(inputs)
    dim_expanded_features = self.conv2(dim_reduced_features)
    return dim_expanded_features, dim_reduced_features
class Delf(tf.keras.Model): class Delf(tf.keras.Model):
"""Instantiates Keras DELF model using ResNet50 as backbone. """Instantiates Keras DELF model using ResNet50 as backbone.
...@@ -95,7 +127,10 @@ class Delf(tf.keras.Model): ...@@ -95,7 +127,10 @@ class Delf(tf.keras.Model):
pooling='avg', pooling='avg',
gem_power=3.0, gem_power=3.0,
embedding_layer=False, embedding_layer=False,
embedding_layer_dim=2048): embedding_layer_dim=2048,
use_dim_reduction=False,
reduced_dimension=128,
dim_expand_channels=1024):
"""Initialization of DELF model. """Initialization of DELF model.
Args: Args:
...@@ -108,6 +143,14 @@ class Delf(tf.keras.Model): ...@@ -108,6 +143,14 @@ class Delf(tf.keras.Model):
embedding_layer: bool, whether to create an embedding layer (FC whitening embedding_layer: bool, whether to create an embedding layer (FC whitening
layer). layer).
embedding_layer_dim: int, size of the embedding layer. embedding_layer_dim: int, size of the embedding layer.
use_dim_reduction: Whether to integrate dimensionality reduction layers.
If True, extra layers are added to reduce the dimensionality of the
extracted features.
reduced_dimension: int, only used if use_dim_reduction is True. The output
dimension of the autoencoder layer.
dim_expand_channels: int, only used if use_dim_reduction is True. The
number of channels of the backbone block used. Default value 1024 is the
number of channels of backbone block 'block3'.
""" """
super(Delf, self).__init__(name=name) super(Delf, self).__init__(name=name)
...@@ -126,6 +169,13 @@ class Delf(tf.keras.Model): ...@@ -126,6 +169,13 @@ class Delf(tf.keras.Model):
# Attention model. # Attention model.
self.attention = AttentionModel(name='attention') self.attention = AttentionModel(name='attention')
# Autoencoder model.
self._use_dim_reduction = use_dim_reduction
if self._use_dim_reduction:
self.autoencoder = AutoencoderModel(reduced_dimension,
dim_expand_channels,
name='autoencoder')
def init_classifiers(self, num_classes, desc_classification=None): def init_classifiers(self, num_classes, desc_classification=None):
"""Define classifiers for training backbone and attention models.""" """Define classifiers for training backbone and attention models."""
self.num_classes = num_classes self.num_classes = num_classes
...@@ -156,14 +206,25 @@ class Delf(tf.keras.Model): ...@@ -156,14 +206,25 @@ class Delf(tf.keras.Model):
# https://arxiv.org/abs/2001.05027. # https://arxiv.org/abs/2001.05027.
block3 = backbone_blocks['block3'] # pytype: disable=key-error block3 = backbone_blocks['block3'] # pytype: disable=key-error
block3 = tf.stop_gradient(block3) block3 = tf.stop_gradient(block3)
attn_prelogits, attn_scores, _ = self.attention(block3, training=training) if self._use_dim_reduction:
return desc_prelogits, attn_prelogits, attn_scores, backbone_blocks (dim_expanded_features, dim_reduced_features) = self.autoencoder(block3)
attn_prelogits, attn_scores, _ = self.attention(dim_expanded_features,
training=training)
else:
attn_prelogits, attn_scores, _ = self.attention(block3, training=training)
dim_expanded_features = None
dim_reduced_features = None
return (desc_prelogits, attn_prelogits, attn_scores, backbone_blocks,
dim_expanded_features, dim_reduced_features)
def build_call(self, input_image, training=True): def build_call(self, input_image, training=True):
(global_feature, _, attn_scores, (global_feature, _, attn_scores, backbone_blocks, _,
backbone_blocks) = self.global_and_local_forward_pass(input_image, dim_reduced_features) = self.global_and_local_forward_pass(input_image,
training) training)
features = backbone_blocks['block3'] # pytype: disable=key-error if self._use_dim_reduction:
features = dim_reduced_features
else:
features = backbone_blocks['block3'] # pytype: disable=key-error
return global_feature, attn_scores, features return global_feature, attn_scores, features
def call(self, input_image, training=True): def call(self, input_image, training=True):
......
...@@ -88,7 +88,7 @@ class DelfTest(tf.test.TestCase, parameterized.TestCase): ...@@ -88,7 +88,7 @@ class DelfTest(tf.test.TestCase, parameterized.TestCase):
per_example_loss, global_batch_size=batch_size) per_example_loss, global_batch_size=batch_size)
with tf.GradientTape() as gradient_tape: with tf.GradientTape() as gradient_tape:
(desc_prelogits, attn_prelogits, _, (desc_prelogits, attn_prelogits, _, _, _,
_) = model.global_and_local_forward_pass(images) _) = model.global_and_local_forward_pass(images)
# Calculate global loss by applying the descriptor classifier. # Calculate global loss by applying the descriptor classifier.
desc_logits = model.desc_classification(desc_prelogits) desc_logits = model.desc_classification(desc_prelogits)
......
...@@ -46,7 +46,10 @@ class Delg(delf_model.Delf): ...@@ -46,7 +46,10 @@ class Delg(delf_model.Delf):
gem_power=3.0, gem_power=3.0,
embedding_layer_dim=2048, embedding_layer_dim=2048,
scale_factor_init=45.25, # sqrt(2048) scale_factor_init=45.25, # sqrt(2048)
arcface_margin=0.1): arcface_margin=0.1,
use_dim_reduction=False,
reduced_dimension=128,
dim_expand_channels=1024):
"""Initialization of DELG model. """Initialization of DELG model.
Args: Args:
...@@ -56,6 +59,14 @@ class Delg(delf_model.Delf): ...@@ -56,6 +59,14 @@ class Delg(delf_model.Delf):
embedding_layer_dim : int, dimension of the embedding layer. embedding_layer_dim : int, dimension of the embedding layer.
scale_factor_init: float. scale_factor_init: float.
arcface_margin: float, ArcFace margin. arcface_margin: float, ArcFace margin.
use_dim_reduction: Whether to integrate dimensionality reduction layers.
If True, extra layers are added to reduce the dimensionality of the
extracted features.
reduced_dimension: Only used if use_dim_reduction is True, the output
dimension of the dim_reduction layer.
dim_expand_channels: Only used if use_dim_reduction is True, the
number of channels of the backbone block used. Default value 1024 is the
number of channels of backbone block 'block3'.
""" """
logging.info('Creating Delg model, gem_power %d, embedding_layer_dim %d', logging.info('Creating Delg model, gem_power %d, embedding_layer_dim %d',
gem_power, embedding_layer_dim) gem_power, embedding_layer_dim)
...@@ -64,7 +75,10 @@ class Delg(delf_model.Delf): ...@@ -64,7 +75,10 @@ class Delg(delf_model.Delf):
pooling='gem', pooling='gem',
gem_power=gem_power, gem_power=gem_power,
embedding_layer=True, embedding_layer=True,
embedding_layer_dim=embedding_layer_dim) embedding_layer_dim=embedding_layer_dim,
use_dim_reduction=use_dim_reduction,
reduced_dimension=reduced_dimension,
dim_expand_channels=dim_expand_channels)
self._embedding_layer_dim = embedding_layer_dim self._embedding_layer_dim = embedding_layer_dim
self._scale_factor_init = scale_factor_init self._scale_factor_init = scale_factor_init
self._arcface_margin = arcface_margin self._arcface_margin = arcface_margin
......
...@@ -36,11 +36,12 @@ from delf.python.training.model import export_model_utils ...@@ -36,11 +36,12 @@ from delf.python.training.model import export_model_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('ckpt_path', '/tmp/delf-logdir/delf-weights', flags.DEFINE_string(
'Path to saved checkpoint.') 'ckpt_path', '/tmp/delf-logdir/delf-weights', 'Path to saved checkpoint.')
flags.DEFINE_string('export_path', None, 'Path where model will be exported.') flags.DEFINE_string('export_path', None, 'Path where model will be exported.')
flags.DEFINE_boolean('delg_global_features', True, flags.DEFINE_boolean(
'Whether the model uses a DELG-like global feature head.') 'delg_global_features', True,
'Whether the model uses a DELG-like global feature head.')
flags.DEFINE_float( flags.DEFINE_float(
'delg_gem_power', 3.0, 'delg_gem_power', 3.0,
'Power for Generalized Mean pooling. Used only if --delg_global_features' 'Power for Generalized Mean pooling. Used only if --delg_global_features'
...@@ -52,8 +53,20 @@ flags.DEFINE_integer( ...@@ -52,8 +53,20 @@ flags.DEFINE_integer(
flags.DEFINE_boolean( flags.DEFINE_boolean(
'block3_strides', True, 'block3_strides', True,
'Whether to apply strides after block3, used for local feature head.') 'Whether to apply strides after block3, used for local feature head.')
flags.DEFINE_float('iou', 1.0, flags.DEFINE_float(
'IOU for non-max suppression used in local feature head.') 'iou', 1.0, 'IOU for non-max suppression used in local feature head.')
flags.DEFINE_boolean(
'use_autoencoder', True,
'Whether the exported model should use an autoencoder.')
flags.DEFINE_float(
'autoencoder_dimensions', 128,
'Number of dimensions of the autoencoder. Used only if'
'use_autoencoder=True.')
flags.DEFINE_float(
'local_feature_map_channels', 1024,
'Number of channels at backbone layer used for local feature extraction. '
'Default value 1024 is the number of channels of block3. Used only if'
'use_autoencoder=True.')
class _ExtractModule(tf.Module): class _ExtractModule(tf.Module):
...@@ -86,9 +99,17 @@ class _ExtractModule(tf.Module): ...@@ -86,9 +99,17 @@ class _ExtractModule(tf.Module):
block3_strides=block3_strides, block3_strides=block3_strides,
name='DELG', name='DELG',
gem_power=delg_gem_power, gem_power=delg_gem_power,
embedding_layer_dim=delg_embedding_layer_dim) embedding_layer_dim=delg_embedding_layer_dim,
use_dim_reduction=FLAGS.use_autoencoder,
reduced_dimension=FLAGS.autoencoder_dimensions,
dim_expand_channels=FLAGS.local_feature_map_channels)
else: else:
self._model = delf_model.Delf(block3_strides=block3_strides, name='DELF') self._model = delf_model.Delf(
block3_strides=block3_strides,
name='DELF',
use_dim_reduction=FLAGS.use_autoencoder,
reduced_dimension=FLAGS.autoencoder_dimensions,
dim_expand_channels=FLAGS.local_feature_map_channels)
def LoadWeights(self, checkpoint_path): def LoadWeights(self, checkpoint_path):
self._model.load_weights(checkpoint_path) self._model.load_weights(checkpoint_path)
......
...@@ -35,12 +35,24 @@ from delf.python.training.model import export_model_utils ...@@ -35,12 +35,24 @@ from delf.python.training.model import export_model_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('ckpt_path', '/tmp/delf-logdir/delf-weights', flags.DEFINE_string(
'Path to saved checkpoint.') 'ckpt_path', '/tmp/delf-logdir/delf-weights', 'Path to saved checkpoint.')
flags.DEFINE_string('export_path', None, 'Path where model will be exported.') flags.DEFINE_string('export_path', None, 'Path where model will be exported.')
flags.DEFINE_boolean('block3_strides', False, flags.DEFINE_boolean(
'Whether to apply strides after block3.') 'block3_strides', True, 'Whether to apply strides after block3.')
flags.DEFINE_float('iou', 1.0, 'IOU for non-max suppression.') flags.DEFINE_float('iou', 1.0, 'IOU for non-max suppression.')
flags.DEFINE_boolean(
'use_autoencoder', True,
'Whether the exported model should use an autoencoder.')
flags.DEFINE_float(
'autoencoder_dimensions', 128,
'Number of dimensions of the autoencoder. Used only if'
'use_autoencoder=True.')
flags.DEFINE_float(
'local_feature_map_channels', 1024,
'Number of channels at backbone layer used for local feature extraction. '
'Default value 1024 is the number of channels of block3. Used only if'
'use_autoencoder=True.')
class _ExtractModule(tf.Module): class _ExtractModule(tf.Module):
...@@ -56,7 +68,12 @@ class _ExtractModule(tf.Module): ...@@ -56,7 +68,12 @@ class _ExtractModule(tf.Module):
self._stride_factor = 2.0 if block3_strides else 1.0 self._stride_factor = 2.0 if block3_strides else 1.0
self._iou = iou self._iou = iou
# Setup the DELF model for extraction. # Setup the DELF model for extraction.
self._model = delf_model.Delf(block3_strides=block3_strides, name='DELF') self._model = delf_model.Delf(
block3_strides=block3_strides,
name='DELF',
use_dim_reduction=FLAGS.use_autoencoder,
reduced_dimension=FLAGS.autoencoder_dimensions,
dim_expand_channels=FLAGS.local_feature_map_channels)
def LoadWeights(self, checkpoint_path): def LoadWeights(self, checkpoint_path):
self._model.load_weights(checkpoint_path) self._model.load_weights(checkpoint_path)
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Training script for DELF on Google Landmarks Dataset. """Training script for DELF/G on Google Landmarks Dataset.
Script to train DELF using classification loss on Google Landmarks Dataset Uses classification loss, with MirroredStrategy, to support running on multiple
using MirroredStrategy to so it can run on multiple GPUs. GPUs.
""" """
from __future__ import absolute_import from __future__ import absolute_import
...@@ -24,6 +24,7 @@ from __future__ import division ...@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import time
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -40,36 +41,63 @@ FLAGS = flags.FLAGS ...@@ -40,36 +41,63 @@ FLAGS = flags.FLAGS
flags.DEFINE_boolean('debug', False, 'Debug mode.') flags.DEFINE_boolean('debug', False, 'Debug mode.')
flags.DEFINE_string('logdir', '/tmp/delf', 'WithTensorBoard logdir.') flags.DEFINE_string('logdir', '/tmp/delf', 'WithTensorBoard logdir.')
flags.DEFINE_string('train_file_pattern', '/tmp/data/train*', flags.DEFINE_string(
'File pattern of training dataset files.') 'train_file_pattern', '/tmp/data/train*',
flags.DEFINE_string('validation_file_pattern', '/tmp/data/validation*', 'File pattern of training dataset files.')
'File pattern of validation dataset files.') flags.DEFINE_string(
'validation_file_pattern', '/tmp/data/validation*',
'File pattern of validation dataset files.')
flags.DEFINE_enum( flags.DEFINE_enum(
'dataset_version', 'gld_v1', ['gld_v1', 'gld_v2', 'gld_v2_clean'], 'dataset_version', 'gld_v1', ['gld_v1', 'gld_v2', 'gld_v2_clean'],
'Google Landmarks dataset version, used to determine the' 'Google Landmarks dataset version, used to determine the number of '
'number of classes.') 'classes.')
flags.DEFINE_integer('seed', 0, 'Seed to training dataset.') flags.DEFINE_integer('seed', 0, 'Seed to training dataset.')
flags.DEFINE_float('initial_lr', 0.01, 'Initial learning rate.') flags.DEFINE_float('initial_lr', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('batch_size', 32, 'Global batch size.') flags.DEFINE_integer('batch_size', 32, 'Global batch size.')
flags.DEFINE_integer('max_iters', 500000, 'Maximum iterations.') flags.DEFINE_integer('max_iters', 500000, 'Maximum iterations.')
flags.DEFINE_boolean('block3_strides', True, 'Whether to use block3_strides.') flags.DEFINE_boolean('block3_strides', True, 'Whether to use block3_strides.')
flags.DEFINE_boolean('use_augmentation', True, flags.DEFINE_boolean(
'Whether to use ImageNet style augmentation.') 'use_augmentation', True, 'Whether to use ImageNet style augmentation.')
flags.DEFINE_string( flags.DEFINE_string(
'imagenet_checkpoint', None, 'imagenet_checkpoint', None,
'ImageNet checkpoint for ResNet backbone. If None, no checkpoint is used.') 'ImageNet checkpoint for ResNet backbone. If None, no checkpoint is used.')
flags.DEFINE_float('attention_loss_weight', 1.0, flags.DEFINE_float(
'Weight to apply to the attention loss when calculating the ' 'attention_loss_weight', 1.0,
'total loss of the model.') 'Weight to apply to the attention loss when calculating the '
flags.DEFINE_boolean('delg_global_features', False, 'total loss of the model.')
'Whether to train a DELG model.') flags.DEFINE_boolean(
flags.DEFINE_float('delg_gem_power', 3.0, 'Power for Generalized Mean pooling.') 'delg_global_features', False, 'Whether to train a DELG model.')
flags.DEFINE_integer('delg_embedding_layer_dim', 2048, flags.DEFINE_float(
'Size of the FC whitening layer (embedding layer).') 'delg_gem_power', 3.0,
flags.DEFINE_float('delg_scale_factor_init', 45.25, 'Power for Generalized Mean pooling. Used only if '
('Initial value of the scaling factor of the cosine logits.' 'delg_global_features=True.')
'The default value is sqrt(2048).')) flags.DEFINE_integer(
flags.DEFINE_float('delg_arcface_margin', 0.1, 'ArcFace margin.') 'delg_embedding_layer_dim', 2048,
'Size of the FC whitening layer (embedding layer). Used only if'
'delg_global_features:True.')
flags.DEFINE_float(
'delg_scale_factor_init', 45.25,
'Initial value of the scaling factor of the cosine logits. The default '
'value is sqrt(2048). Used only if delg_global_features=True.')
flags.DEFINE_float(
'delg_arcface_margin', 0.1,
'ArcFace margin. Used only if delg_global_features=True.')
flags.DEFINE_integer('image_size', 321, 'Size of each image side to use.')
flags.DEFINE_boolean(
'use_autoencoder', True, 'Whether to train an autoencoder.')
flags.DEFINE_float(
'reconstruction_loss_weight', 10.0,
'Weight to apply to the reconstruction loss from the autoencoder when'
'calculating total loss of the model. Used only if use_autoencoder=True.')
flags.DEFINE_float(
'autoencoder_dimensions', 128,
'Number of dimensions of the autoencoder. Used only if'
'use_autoencoder=True.')
flags.DEFINE_float(
'local_feature_map_channels', 1024,
'Number of channels at backbone layer used for local feature extraction. '
'Default value 1024 is the number of channels of block3. Used only if'
'use_autoencoder=True.')
def _record_accuracy(metric, logits, labels): def _record_accuracy(metric, logits, labels):
...@@ -104,14 +132,23 @@ def _attention_summaries(scores, global_step): ...@@ -104,14 +132,23 @@ def _attention_summaries(scores, global_step):
def create_model(num_classes): def create_model(num_classes):
"""Define DELF model, and initialize classifiers.""" """Define DELF model, and initialize classifiers."""
if FLAGS.delg_global_features: if FLAGS.delg_global_features:
model = delg_model.Delg(block3_strides=FLAGS.block3_strides, model = delg_model.Delg(
name='DELG', block3_strides=FLAGS.block3_strides,
gem_power=FLAGS.delg_gem_power, name='DELG',
embedding_layer_dim=FLAGS.delg_embedding_layer_dim, gem_power=FLAGS.delg_gem_power,
scale_factor_init=FLAGS.delg_scale_factor_init, embedding_layer_dim=FLAGS.delg_embedding_layer_dim,
arcface_margin=FLAGS.delg_arcface_margin) scale_factor_init=FLAGS.delg_scale_factor_init,
arcface_margin=FLAGS.delg_arcface_margin,
use_dim_reduction=FLAGS.use_autoencoder,
reduced_dimension=FLAGS.autoencoder_dimensions,
dim_expand_channels=FLAGS.local_feature_map_channels)
else: else:
model = delf_model.Delf(block3_strides=FLAGS.block3_strides, name='DELF') model = delf_model.Delf(
block3_strides=FLAGS.block3_strides,
name='DELF',
use_dim_reduction=FLAGS.use_autoencoder,
reduced_dimension=FLAGS.autoencoder_dimensions,
dim_expand_channels=FLAGS.local_feature_map_channels)
model.init_classifiers(num_classes) model.init_classifiers(num_classes)
return model return model
...@@ -151,11 +188,11 @@ def main(argv): ...@@ -151,11 +188,11 @@ def main(argv):
max_iters = FLAGS.max_iters max_iters = FLAGS.max_iters
global_batch_size = FLAGS.batch_size global_batch_size = FLAGS.batch_size
image_size = 321 image_size = FLAGS.image_size
num_eval_batches = int(50000 / global_batch_size) num_eval_batches = int(50000 / global_batch_size)
report_interval = 100 report_interval = 100
eval_interval = 1000 eval_interval = 1000
save_interval = 20000 save_interval = 1000
initial_lr = FLAGS.initial_lr initial_lr = FLAGS.initial_lr
...@@ -167,7 +204,7 @@ def main(argv): ...@@ -167,7 +204,7 @@ def main(argv):
max_iters = 100 max_iters = 100
num_eval_batches = 1 num_eval_batches = 1
save_interval = 1 save_interval = 1
report_interval = 1 report_interval = 10
# Determine the number of classes based on the version of the dataset. # Determine the number of classes based on the version of the dataset.
gld_info = gld.GoogleLandmarksInfo() gld_info = gld.GoogleLandmarksInfo()
...@@ -238,7 +275,12 @@ def main(argv): ...@@ -238,7 +275,12 @@ def main(argv):
# Setup checkpoint directory. # Setup checkpoint directory.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager( manager = tf.train.CheckpointManager(
checkpoint, checkpoint_prefix, max_to_keep=3) checkpoint,
checkpoint_prefix,
max_to_keep=10,
keep_checkpoint_every_n_hours=3)
# Restores the checkpoint, if existing.
checkpoint.restore(manager.latest_checkpoint)
# ------------------------------------------------------------ # ------------------------------------------------------------
# Train step to run on one GPU. # Train step to run on one GPU.
...@@ -248,13 +290,6 @@ def main(argv): ...@@ -248,13 +290,6 @@ def main(argv):
# Temporary workaround to avoid some corrupted labels. # Temporary workaround to avoid some corrupted labels.
labels = tf.clip_by_value(labels, 0, model.num_classes) labels = tf.clip_by_value(labels, 0, model.num_classes)
global_step = optimizer.iterations
tf.summary.image('batch_images', (images + 1.0) / 2.0, step=global_step)
tf.summary.scalar(
'image_range/max', tf.reduce_max(images), step=global_step)
tf.summary.scalar(
'image_range/min', tf.reduce_min(images), step=global_step)
def _backprop_loss(tape, loss, weights): def _backprop_loss(tape, loss, weights):
"""Backpropogate losses using clipped gradients. """Backpropogate losses using clipped gradients.
...@@ -270,8 +305,8 @@ def main(argv): ...@@ -270,8 +305,8 @@ def main(argv):
# Record gradients and loss through backbone. # Record gradients and loss through backbone.
with tf.GradientTape() as gradient_tape: with tf.GradientTape() as gradient_tape:
# Make a forward pass to calculate prelogits. # Make a forward pass to calculate prelogits.
(desc_prelogits, attn_prelogits, attn_scores, (desc_prelogits, attn_prelogits, attn_scores, backbone_blocks,
backbone_blocks) = model.global_and_local_forward_pass(images) dim_expanded_features, _) = model.global_and_local_forward_pass(images)
# Calculate global loss by applying the descriptor classifier. # Calculate global loss by applying the descriptor classifier.
if FLAGS.delg_global_features: if FLAGS.delg_global_features:
...@@ -284,18 +319,36 @@ def main(argv): ...@@ -284,18 +319,36 @@ def main(argv):
attn_logits = model.attn_classification(attn_prelogits) attn_logits = model.attn_classification(attn_prelogits)
attn_loss = compute_loss(labels, attn_logits) attn_loss = compute_loss(labels, attn_logits)
# Cumulate global loss and attention loss. # Calculate reconstruction loss between the attention prelogits and the
total_loss = desc_loss + FLAGS.attention_loss_weight * attn_loss # backbone.
if FLAGS.use_autoencoder:
block3 = tf.stop_gradient(backbone_blocks['block3'])
reconstruction_loss = tf.math.reduce_mean(
tf.keras.losses.MSE(block3, dim_expanded_features))
else:
reconstruction_loss = 0
# Cumulate global loss, attention loss and reconstruction loss.
total_loss = (desc_loss
+ FLAGS.attention_loss_weight * attn_loss
+ FLAGS.reconstruction_loss_weight * reconstruction_loss)
# Perform backpropagation through the descriptor layer and attention layer # Perform backpropagation through the descriptor and attention layers
# together. # together. Note that this will increment the number of iterations of
# "optimizer".
_backprop_loss(gradient_tape, total_loss, model.trainable_weights) _backprop_loss(gradient_tape, total_loss, model.trainable_weights)
# Report scaling factor for cosine logits for a DELG model. # Step number, for summary purposes.
if FLAGS.delg_global_features: global_step = optimizer.iterations
tf.summary.scalar('desc/scale_factor', model.scale_factor,
step=global_step) # Input image-related summaries.
# Report attention and sparsity summaries. tf.summary.image('batch_images', (images + 1.0) / 2.0, step=global_step)
tf.summary.scalar(
'image_range/max', tf.reduce_max(images), step=global_step)
tf.summary.scalar(
'image_range/min', tf.reduce_min(images), step=global_step)
# Attention and sparsity summaries.
_attention_summaries(attn_scores, global_step) _attention_summaries(attn_scores, global_step)
activations_zero_fractions = { activations_zero_fractions = {
'sparsity/%s' % k: tf.nn.zero_fraction(v) 'sparsity/%s' % k: tf.nn.zero_fraction(v)
...@@ -303,12 +356,17 @@ def main(argv): ...@@ -303,12 +356,17 @@ def main(argv):
} }
for k, v in activations_zero_fractions.items(): for k, v in activations_zero_fractions.items():
tf.summary.scalar(k, v, step=global_step) tf.summary.scalar(k, v, step=global_step)
# Record descriptor train accuracy.
# Scaling factor summary for cosine logits for a DELG model.
if FLAGS.delg_global_features:
tf.summary.scalar(
'desc/scale_factor', model.scale_factor, step=global_step)
# Record train accuracies.
_record_accuracy(desc_train_accuracy, desc_logits, labels) _record_accuracy(desc_train_accuracy, desc_logits, labels)
# Record attention train accuracy.
_record_accuracy(attn_train_accuracy, attn_logits, labels) _record_accuracy(attn_train_accuracy, attn_logits, labels)
return desc_loss, attn_loss return desc_loss, attn_loss, reconstruction_loss
# ------------------------------------------------------------ # ------------------------------------------------------------
def validation_step(inputs): def validation_step(inputs):
...@@ -350,7 +408,7 @@ def main(argv): ...@@ -350,7 +408,7 @@ def main(argv):
def distributed_train_step(dataset_inputs): def distributed_train_step(dataset_inputs):
"""Get the actual losses.""" """Get the actual losses."""
# Each (desc, attn) is a list of 3 losses - crossentropy, reg, total. # Each (desc, attn) is a list of 3 losses - crossentropy, reg, total.
desc_per_replica_loss, attn_per_replica_loss = ( desc_per_replica_loss, attn_per_replica_loss, recon_per_replica_loss = (
strategy.run(train_step, args=(dataset_inputs,))) strategy.run(train_step, args=(dataset_inputs,)))
# Reduce over the replicas. # Reduce over the replicas.
...@@ -358,8 +416,10 @@ def main(argv): ...@@ -358,8 +416,10 @@ def main(argv):
tf.distribute.ReduceOp.SUM, desc_per_replica_loss, axis=None) tf.distribute.ReduceOp.SUM, desc_per_replica_loss, axis=None)
attn_global_loss = strategy.reduce( attn_global_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, attn_per_replica_loss, axis=None) tf.distribute.ReduceOp.SUM, attn_per_replica_loss, axis=None)
recon_global_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, recon_per_replica_loss, axis=None)
return desc_global_loss, attn_global_loss return desc_global_loss, attn_global_loss, recon_global_loss
@tf.function @tf.function
def distributed_validation_step(dataset_inputs): def distributed_validation_step(dataset_inputs):
...@@ -368,15 +428,16 @@ def main(argv): ...@@ -368,15 +428,16 @@ def main(argv):
# ------------------------------------------------------------ # ------------------------------------------------------------
# *** TRAIN LOOP *** # *** TRAIN LOOP ***
with summary_writer.as_default(): with summary_writer.as_default():
with tf.summary.record_if( record_cond = lambda: tf.equal(optimizer.iterations % report_interval, 0)
tf.math.equal(0, optimizer.iterations % report_interval)): with tf.summary.record_if(record_cond):
global_step_value = optimizer.iterations.numpy()
# TODO(dananghel): try to load pretrained weights at backbone creation. # TODO(dananghel): try to load pretrained weights at backbone creation.
# Load pretrained weights for ResNet50 trained on ImageNet. # Load pretrained weights for ResNet50 trained on ImageNet.
if FLAGS.imagenet_checkpoint is not None: if (FLAGS.imagenet_checkpoint is not None) and (not global_step_value):
logging.info('Attempting to load ImageNet pretrained weights.') logging.info('Attempting to load ImageNet pretrained weights.')
input_batch = next(train_iter) input_batch = next(train_iter)
_, _ = distributed_train_step(input_batch) _, _, _ = distributed_train_step(input_batch)
model.backbone.restore_weights(FLAGS.imagenet_checkpoint) model.backbone.restore_weights(FLAGS.imagenet_checkpoint)
logging.info('Done.') logging.info('Done.')
else: else:
...@@ -384,9 +445,9 @@ def main(argv): ...@@ -384,9 +445,9 @@ def main(argv):
if FLAGS.debug: if FLAGS.debug:
model.backbone.log_weights() model.backbone.log_weights()
global_step_value = optimizer.iterations.numpy() last_summary_step_value = None
last_summary_time = None
while global_step_value < max_iters: while global_step_value < max_iters:
# input_batch : images(b, h, w, c), labels(b,). # input_batch : images(b, h, w, c), labels(b,).
try: try:
input_batch = next(train_iter) input_batch = next(train_iter)
...@@ -396,24 +457,27 @@ def main(argv): ...@@ -396,24 +457,27 @@ def main(argv):
global_step_value) global_step_value)
break break
# Set learning rate for optimizer to use. # Set learning rate and run the training step over num_gpu gpus.
optimizer.learning_rate = _learning_rate_schedule(
optimizer.iterations.numpy(), max_iters, initial_lr)
desc_dist_loss, attn_dist_loss, recon_dist_loss = (
distributed_train_step(input_batch))
# Step number, to be used for summary/logging.
global_step = optimizer.iterations global_step = optimizer.iterations
global_step_value = global_step.numpy() global_step_value = global_step.numpy()
learning_rate = _learning_rate_schedule(global_step_value, max_iters, # LR, losses and accuracies summaries.
initial_lr)
optimizer.learning_rate = learning_rate
tf.summary.scalar( tf.summary.scalar(
'learning_rate', optimizer.learning_rate, step=global_step) 'learning_rate', optimizer.learning_rate, step=global_step)
# Run the training step over num_gpu gpus.
desc_dist_loss, attn_dist_loss = distributed_train_step(input_batch)
# Log losses and accuracies to tensorboard.
tf.summary.scalar( tf.summary.scalar(
'loss/desc/crossentropy', desc_dist_loss, step=global_step) 'loss/desc/crossentropy', desc_dist_loss, step=global_step)
tf.summary.scalar( tf.summary.scalar(
'loss/attn/crossentropy', attn_dist_loss, step=global_step) 'loss/attn/crossentropy', attn_dist_loss, step=global_step)
if FLAGS.use_autoencoder:
tf.summary.scalar(
'loss/recon/mse', recon_dist_loss, step=global_step)
tf.summary.scalar( tf.summary.scalar(
'train_accuracy/desc', 'train_accuracy/desc',
desc_train_accuracy.result(), desc_train_accuracy.result(),
...@@ -423,6 +487,19 @@ def main(argv): ...@@ -423,6 +487,19 @@ def main(argv):
attn_train_accuracy.result(), attn_train_accuracy.result(),
step=global_step) step=global_step)
# Summary for number of global steps taken per second.
current_time = time.time()
if (last_summary_step_value is not None and
last_summary_time is not None):
tf.summary.scalar(
'global_steps_per_sec',
(global_step_value - last_summary_step_value) /
(current_time - last_summary_time),
step=global_step)
if tf.summary.should_record_summaries().numpy():
last_summary_step_value = global_step_value
last_summary_time = current_time
# Print to console if running locally. # Print to console if running locally.
if FLAGS.debug: if FLAGS.debug:
if global_step_value % report_interval == 0: if global_step_value % report_interval == 0:
...@@ -455,12 +532,14 @@ def main(argv): ...@@ -455,12 +532,14 @@ def main(argv):
print('Validation: desc:', desc_validation_result.numpy()) print('Validation: desc:', desc_validation_result.numpy())
print(' : attn:', attn_validation_result.numpy()) print(' : attn:', attn_validation_result.numpy())
# Save checkpoint once (each save_interval*n, n \in N) steps. # Save checkpoint once (each save_interval*n, n \in N) steps, or if
# this is the last iteration.
# TODO(andrearaujo): save only in one of the two ways. They are # TODO(andrearaujo): save only in one of the two ways. They are
# identical, the only difference is that the manager adds some extra # identical, the only difference is that the manager adds some extra
# prefixes and variables (eg, optimizer variables). # prefixes and variables (eg, optimizer variables).
if global_step_value % save_interval == 0: if (global_step_value %
save_path = manager.save() save_interval == 0) or (global_step_value >= max_iters):
save_path = manager.save(checkpoint_number=global_step_value)
logging.info('Saved (%d) at %s', global_step_value, save_path) logging.info('Saved (%d) at %s', global_step_value, save_path)
file_path = '%s/delf_weights' % FLAGS.logdir file_path = '%s/delf_weights' % FLAGS.logdir
...@@ -476,9 +555,6 @@ def main(argv): ...@@ -476,9 +555,6 @@ def main(argv):
desc_validation_accuracy.reset_states() desc_validation_accuracy.reset_states()
attn_validation_accuracy.reset_states() attn_validation_accuracy.reset_states()
if global_step.numpy() > max_iters:
break
logging.info('Finished training for %d steps.', max_iters) logging.info('Finished training for %d steps.', max_iters)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment