Unverified Commit 4b8f7d47 authored by Dan Anghel's avatar Dan Anghel Committed by GitHub
Browse files

Autoencoder in the DELG model (#9555)



* Merged commit includes the following changes:
326369548  by Andre Araujo:

    Fix import issues.

--
326159826  by Andre Araujo:

    Changed the implementation of the cosine weights from Keras layer to tf.Variable to manually control for L2 normalization.

--
326139082  by Andre Araujo:

    Support local feature matching using ratio test.

    To allow for easily choosing which matching type to use, we rename a flag/argument and modify all related files to avoid breakages.

    Also include a small change when computing nearest neighbors for geometric matching, to parallelize computation, which saves a little bit of time during execution (argument "n_jobs=-1").

--
326119848  by Andre Araujo:

    Option to measure DELG latency taking binarization into account.

--
324316608  by Andre Araujo:

    DELG global features training.

--
323693131  by Andre Araujo:

    PY3 conversion for delf public lib.

--
321046157  by Andre Araujo:

    Purely Google refactor

--

PiperOrigin-RevId: 326369548

* Added export of delg_model module.

* README update to explain training DELG global features head

* Added guidelines for DELF hyperparameter values

* Fixed typo

* Added mention about remaining training flags.

* Merged commit includes the following changes:
334723489  by Andre Araujo:

    Backpropagate global and attention layers together.

--
334228310  by Andre Araujo:

    Enable scaling of local feature locations to the resized resolution.

--

PiperOrigin-RevId: 334723489

* Merged commit includes the following changes:
347032253  by Andre Araujo:

    Updated local and global_and_local model export scripts for exporting models trained with the autoencoder layer.

--
344312455  by Andre Araujo:

    Implement autoencoder in training pipeline.

--
341116593  by Andre Araujo:

    Reduce the default save_interval, to get more frequent checkpoints.

--
341111808  by Andre Araujo:

    Allow checkpoint restoration in DELF training, to enable resuming of training jobs.

--
340138315  by Andre Araujo:

    DELF training script: make it always save the last checkpoint.

--
338731551  by Andre Araujo:

    Add image_size flag in DELF/G OSS training script.

--
338684879  by Andre Araujo:

    Clean up summaries in DELF/G training script.

    - Previously, the call to tf.summary.record_if() was not working, which led to summaries being recorded at every step, resulting in excessively large event files. This is fixed.
    - Previously, some summaries were computed at iteration k, while others at iteration k+1. Now, we standardize summary computations to always run after backpropagation (which means that summaries are reported for step k+1, referring to the batch k).
    - Added a new summary: number of global steps per second; useful to see how fast training is making progress.

    Also a few other small modifications are included:
    - Improved description of the train.py script.
    - Some small automatic reformattings.

--

PiperOrigin-RevId: 347032253
Co-authored-by: default avatarAndre Araujo <andrearaujo@google.com>
parent a58cd931
......@@ -80,6 +80,38 @@ class AttentionModel(tf.keras.Model):
return feat, prob, score
class AutoencoderModel(tf.keras.Model):
  """Keras model implementing a two-layer convolutional autoencoder.

  The encoder (`conv1`) linearly projects input features down to
  `reduced_dimension` channels; the decoder (`conv2`) expands them back to
  `expand_dimension` channels through a ReLU activation.
  """

  def __init__(self, reduced_dimension, expand_dimension, kernel_size=1,
               name='autoencoder'):
    """Initialization of Autoencoder model.

    Args:
      reduced_dimension: int, the output dimension of the autoencoder layer.
      expand_dimension: int, the input dimension of the autoencoder layer.
      kernel_size: int or tuple, height and width of the 2D convolution window.
      name: str, name to identify model.
    """
    super(AutoencoderModel, self).__init__(name=name)
    # Encoder: dimensionality reduction. No activation, so the bottleneck is a
    # linear projection of the input features.
    self.conv1 = layers.Conv2D(
        reduced_dimension, kernel_size, padding='same', name='autoenc_conv1')
    # Decoder: expands the bottleneck back to the original channel count.
    self.conv2 = layers.Conv2D(
        expand_dimension,
        kernel_size,
        activation=tf.keras.activations.relu,
        padding='same',
        name='autoenc_conv2')

  def call(self, inputs):
    """Runs the autoencoder on a batch of features.

    Args:
      inputs: feature tensor fed to the 2D convolutions (assumed
        channels-last, as is the Keras Conv2D default — confirm with callers).

    Returns:
      Tuple of (dim_expanded_features, dim_reduced_features): the
      reconstruction and the bottleneck features, in that order.
    """
    reduced = self.conv1(inputs)
    expanded = self.conv2(reduced)
    return expanded, reduced
class Delf(tf.keras.Model):
"""Instantiates Keras DELF model using ResNet50 as backbone.
......@@ -95,7 +127,10 @@ class Delf(tf.keras.Model):
pooling='avg',
gem_power=3.0,
embedding_layer=False,
embedding_layer_dim=2048):
embedding_layer_dim=2048,
use_dim_reduction=False,
reduced_dimension=128,
dim_expand_channels=1024):
"""Initialization of DELF model.
Args:
......@@ -108,6 +143,14 @@ class Delf(tf.keras.Model):
embedding_layer: bool, whether to create an embedding layer (FC whitening
layer).
embedding_layer_dim: int, size of the embedding layer.
use_dim_reduction: Whether to integrate dimensionality reduction layers.
If True, extra layers are added to reduce the dimensionality of the
extracted features.
reduced_dimension: int, only used if use_dim_reduction is True. The output
dimension of the autoencoder layer.
dim_expand_channels: int, only used if use_dim_reduction is True. The
number of channels of the backbone block used. Default value 1024 is the
number of channels of backbone block 'block3'.
"""
super(Delf, self).__init__(name=name)
......@@ -126,6 +169,13 @@ class Delf(tf.keras.Model):
# Attention model.
self.attention = AttentionModel(name='attention')
# Autoencoder model.
self._use_dim_reduction = use_dim_reduction
if self._use_dim_reduction:
self.autoencoder = AutoencoderModel(reduced_dimension,
dim_expand_channels,
name='autoencoder')
def init_classifiers(self, num_classes, desc_classification=None):
"""Define classifiers for training backbone and attention models."""
self.num_classes = num_classes
......@@ -156,13 +206,24 @@ class Delf(tf.keras.Model):
# https://arxiv.org/abs/2001.05027.
block3 = backbone_blocks['block3'] # pytype: disable=key-error
block3 = tf.stop_gradient(block3)
if self._use_dim_reduction:
(dim_expanded_features, dim_reduced_features) = self.autoencoder(block3)
attn_prelogits, attn_scores, _ = self.attention(dim_expanded_features,
training=training)
else:
attn_prelogits, attn_scores, _ = self.attention(block3, training=training)
return desc_prelogits, attn_prelogits, attn_scores, backbone_blocks
dim_expanded_features = None
dim_reduced_features = None
return (desc_prelogits, attn_prelogits, attn_scores, backbone_blocks,
dim_expanded_features, dim_reduced_features)
def build_call(self, input_image, training=True):
(global_feature, _, attn_scores,
backbone_blocks) = self.global_and_local_forward_pass(input_image,
(global_feature, _, attn_scores, backbone_blocks, _,
dim_reduced_features) = self.global_and_local_forward_pass(input_image,
training)
if self._use_dim_reduction:
features = dim_reduced_features
else:
features = backbone_blocks['block3'] # pytype: disable=key-error
return global_feature, attn_scores, features
......
......@@ -88,7 +88,7 @@ class DelfTest(tf.test.TestCase, parameterized.TestCase):
per_example_loss, global_batch_size=batch_size)
with tf.GradientTape() as gradient_tape:
(desc_prelogits, attn_prelogits, _,
(desc_prelogits, attn_prelogits, _, _, _,
_) = model.global_and_local_forward_pass(images)
# Calculate global loss by applying the descriptor classifier.
desc_logits = model.desc_classification(desc_prelogits)
......
......@@ -46,7 +46,10 @@ class Delg(delf_model.Delf):
gem_power=3.0,
embedding_layer_dim=2048,
scale_factor_init=45.25, # sqrt(2048)
arcface_margin=0.1):
arcface_margin=0.1,
use_dim_reduction=False,
reduced_dimension=128,
dim_expand_channels=1024):
"""Initialization of DELG model.
Args:
......@@ -56,6 +59,14 @@ class Delg(delf_model.Delf):
embedding_layer_dim : int, dimension of the embedding layer.
scale_factor_init: float.
arcface_margin: float, ArcFace margin.
use_dim_reduction: Whether to integrate dimensionality reduction layers.
If True, extra layers are added to reduce the dimensionality of the
extracted features.
reduced_dimension: Only used if use_dim_reduction is True, the output
dimension of the dim_reduction layer.
dim_expand_channels: Only used if use_dim_reduction is True, the
number of channels of the backbone block used. Default value 1024 is the
number of channels of backbone block 'block3'.
"""
logging.info('Creating Delg model, gem_power %d, embedding_layer_dim %d',
gem_power, embedding_layer_dim)
......@@ -64,7 +75,10 @@ class Delg(delf_model.Delf):
pooling='gem',
gem_power=gem_power,
embedding_layer=True,
embedding_layer_dim=embedding_layer_dim)
embedding_layer_dim=embedding_layer_dim,
use_dim_reduction=use_dim_reduction,
reduced_dimension=reduced_dimension,
dim_expand_channels=dim_expand_channels)
self._embedding_layer_dim = embedding_layer_dim
self._scale_factor_init = scale_factor_init
self._arcface_margin = arcface_margin
......
......@@ -36,10 +36,11 @@ from delf.python.training.model import export_model_utils
FLAGS = flags.FLAGS
flags.DEFINE_string('ckpt_path', '/tmp/delf-logdir/delf-weights',
'Path to saved checkpoint.')
flags.DEFINE_string(
'ckpt_path', '/tmp/delf-logdir/delf-weights', 'Path to saved checkpoint.')
flags.DEFINE_string('export_path', None, 'Path where model will be exported.')
flags.DEFINE_boolean('delg_global_features', True,
flags.DEFINE_boolean(
'delg_global_features', True,
'Whether the model uses a DELG-like global feature head.')
flags.DEFINE_float(
'delg_gem_power', 3.0,
......@@ -52,8 +53,20 @@ flags.DEFINE_integer(
flags.DEFINE_boolean(
'block3_strides', True,
'Whether to apply strides after block3, used for local feature head.')
flags.DEFINE_float('iou', 1.0,
'IOU for non-max suppression used in local feature head.')
flags.DEFINE_float(
'iou', 1.0, 'IOU for non-max suppression used in local feature head.')
flags.DEFINE_boolean(
'use_autoencoder', True,
'Whether the exported model should use an autoencoder.')
flags.DEFINE_float(
'autoencoder_dimensions', 128,
'Number of dimensions of the autoencoder. Used only if'
'use_autoencoder=True.')
flags.DEFINE_float(
'local_feature_map_channels', 1024,
'Number of channels at backbone layer used for local feature extraction. '
'Default value 1024 is the number of channels of block3. Used only if'
'use_autoencoder=True.')
class _ExtractModule(tf.Module):
......@@ -86,9 +99,17 @@ class _ExtractModule(tf.Module):
block3_strides=block3_strides,
name='DELG',
gem_power=delg_gem_power,
embedding_layer_dim=delg_embedding_layer_dim)
embedding_layer_dim=delg_embedding_layer_dim,
use_dim_reduction=FLAGS.use_autoencoder,
reduced_dimension=FLAGS.autoencoder_dimensions,
dim_expand_channels=FLAGS.local_feature_map_channels)
else:
self._model = delf_model.Delf(block3_strides=block3_strides, name='DELF')
self._model = delf_model.Delf(
block3_strides=block3_strides,
name='DELF',
use_dim_reduction=FLAGS.use_autoencoder,
reduced_dimension=FLAGS.autoencoder_dimensions,
dim_expand_channels=FLAGS.local_feature_map_channels)
def LoadWeights(self, checkpoint_path):
self._model.load_weights(checkpoint_path)
......
......@@ -35,12 +35,24 @@ from delf.python.training.model import export_model_utils
FLAGS = flags.FLAGS
flags.DEFINE_string('ckpt_path', '/tmp/delf-logdir/delf-weights',
'Path to saved checkpoint.')
flags.DEFINE_string(
'ckpt_path', '/tmp/delf-logdir/delf-weights', 'Path to saved checkpoint.')
flags.DEFINE_string('export_path', None, 'Path where model will be exported.')
flags.DEFINE_boolean('block3_strides', False,
'Whether to apply strides after block3.')
flags.DEFINE_boolean(
'block3_strides', True, 'Whether to apply strides after block3.')
flags.DEFINE_float('iou', 1.0, 'IOU for non-max suppression.')
flags.DEFINE_boolean(
'use_autoencoder', True,
'Whether the exported model should use an autoencoder.')
flags.DEFINE_float(
'autoencoder_dimensions', 128,
'Number of dimensions of the autoencoder. Used only if'
'use_autoencoder=True.')
flags.DEFINE_float(
'local_feature_map_channels', 1024,
'Number of channels at backbone layer used for local feature extraction. '
'Default value 1024 is the number of channels of block3. Used only if'
'use_autoencoder=True.')
class _ExtractModule(tf.Module):
......@@ -56,7 +68,12 @@ class _ExtractModule(tf.Module):
self._stride_factor = 2.0 if block3_strides else 1.0
self._iou = iou
# Setup the DELF model for extraction.
self._model = delf_model.Delf(block3_strides=block3_strides, name='DELF')
self._model = delf_model.Delf(
block3_strides=block3_strides,
name='DELF',
use_dim_reduction=FLAGS.use_autoencoder,
reduced_dimension=FLAGS.autoencoder_dimensions,
dim_expand_channels=FLAGS.local_feature_map_channels)
def LoadWeights(self, checkpoint_path):
self._model.load_weights(checkpoint_path)
......
......@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training script for DELF on Google Landmarks Dataset.
"""Training script for DELF/G on Google Landmarks Dataset.
Script to train DELF using classification loss on Google Landmarks Dataset
using MirroredStrategy to so it can run on multiple GPUs.
Uses classification loss, with MirroredStrategy, to support running on multiple
GPUs.
"""
from __future__ import absolute_import
......@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import os
import time
from absl import app
from absl import flags
......@@ -40,36 +41,63 @@ FLAGS = flags.FLAGS
flags.DEFINE_boolean('debug', False, 'Debug mode.')
flags.DEFINE_string('logdir', '/tmp/delf', 'WithTensorBoard logdir.')
flags.DEFINE_string('train_file_pattern', '/tmp/data/train*',
flags.DEFINE_string(
'train_file_pattern', '/tmp/data/train*',
'File pattern of training dataset files.')
flags.DEFINE_string('validation_file_pattern', '/tmp/data/validation*',
flags.DEFINE_string(
'validation_file_pattern', '/tmp/data/validation*',
'File pattern of validation dataset files.')
flags.DEFINE_enum(
'dataset_version', 'gld_v1', ['gld_v1', 'gld_v2', 'gld_v2_clean'],
'Google Landmarks dataset version, used to determine the'
'number of classes.')
'Google Landmarks dataset version, used to determine the number of '
'classes.')
flags.DEFINE_integer('seed', 0, 'Seed to training dataset.')
flags.DEFINE_float('initial_lr', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('batch_size', 32, 'Global batch size.')
flags.DEFINE_integer('max_iters', 500000, 'Maximum iterations.')
flags.DEFINE_boolean('block3_strides', True, 'Whether to use block3_strides.')
flags.DEFINE_boolean('use_augmentation', True,
'Whether to use ImageNet style augmentation.')
flags.DEFINE_boolean(
'use_augmentation', True, 'Whether to use ImageNet style augmentation.')
flags.DEFINE_string(
'imagenet_checkpoint', None,
'ImageNet checkpoint for ResNet backbone. If None, no checkpoint is used.')
flags.DEFINE_float('attention_loss_weight', 1.0,
flags.DEFINE_float(
'attention_loss_weight', 1.0,
'Weight to apply to the attention loss when calculating the '
'total loss of the model.')
flags.DEFINE_boolean('delg_global_features', False,
'Whether to train a DELG model.')
flags.DEFINE_float('delg_gem_power', 3.0, 'Power for Generalized Mean pooling.')
flags.DEFINE_integer('delg_embedding_layer_dim', 2048,
'Size of the FC whitening layer (embedding layer).')
flags.DEFINE_float('delg_scale_factor_init', 45.25,
('Initial value of the scaling factor of the cosine logits.'
'The default value is sqrt(2048).'))
flags.DEFINE_float('delg_arcface_margin', 0.1, 'ArcFace margin.')
flags.DEFINE_boolean(
'delg_global_features', False, 'Whether to train a DELG model.')
flags.DEFINE_float(
'delg_gem_power', 3.0,
'Power for Generalized Mean pooling. Used only if '
'delg_global_features=True.')
flags.DEFINE_integer(
'delg_embedding_layer_dim', 2048,
'Size of the FC whitening layer (embedding layer). Used only if'
'delg_global_features:True.')
flags.DEFINE_float(
'delg_scale_factor_init', 45.25,
'Initial value of the scaling factor of the cosine logits. The default '
'value is sqrt(2048). Used only if delg_global_features=True.')
flags.DEFINE_float(
'delg_arcface_margin', 0.1,
'ArcFace margin. Used only if delg_global_features=True.')
flags.DEFINE_integer('image_size', 321, 'Size of each image side to use.')
flags.DEFINE_boolean(
'use_autoencoder', True, 'Whether to train an autoencoder.')
flags.DEFINE_float(
'reconstruction_loss_weight', 10.0,
'Weight to apply to the reconstruction loss from the autoencoder when'
'calculating total loss of the model. Used only if use_autoencoder=True.')
flags.DEFINE_float(
'autoencoder_dimensions', 128,
'Number of dimensions of the autoencoder. Used only if'
'use_autoencoder=True.')
flags.DEFINE_float(
'local_feature_map_channels', 1024,
'Number of channels at backbone layer used for local feature extraction. '
'Default value 1024 is the number of channels of block3. Used only if'
'use_autoencoder=True.')
def _record_accuracy(metric, logits, labels):
......@@ -104,14 +132,23 @@ def _attention_summaries(scores, global_step):
def create_model(num_classes):
"""Define DELF model, and initialize classifiers."""
if FLAGS.delg_global_features:
model = delg_model.Delg(block3_strides=FLAGS.block3_strides,
model = delg_model.Delg(
block3_strides=FLAGS.block3_strides,
name='DELG',
gem_power=FLAGS.delg_gem_power,
embedding_layer_dim=FLAGS.delg_embedding_layer_dim,
scale_factor_init=FLAGS.delg_scale_factor_init,
arcface_margin=FLAGS.delg_arcface_margin)
arcface_margin=FLAGS.delg_arcface_margin,
use_dim_reduction=FLAGS.use_autoencoder,
reduced_dimension=FLAGS.autoencoder_dimensions,
dim_expand_channels=FLAGS.local_feature_map_channels)
else:
model = delf_model.Delf(block3_strides=FLAGS.block3_strides, name='DELF')
model = delf_model.Delf(
block3_strides=FLAGS.block3_strides,
name='DELF',
use_dim_reduction=FLAGS.use_autoencoder,
reduced_dimension=FLAGS.autoencoder_dimensions,
dim_expand_channels=FLAGS.local_feature_map_channels)
model.init_classifiers(num_classes)
return model
......@@ -151,11 +188,11 @@ def main(argv):
max_iters = FLAGS.max_iters
global_batch_size = FLAGS.batch_size
image_size = 321
image_size = FLAGS.image_size
num_eval_batches = int(50000 / global_batch_size)
report_interval = 100
eval_interval = 1000
save_interval = 20000
save_interval = 1000
initial_lr = FLAGS.initial_lr
......@@ -167,7 +204,7 @@ def main(argv):
max_iters = 100
num_eval_batches = 1
save_interval = 1
report_interval = 1
report_interval = 10
# Determine the number of classes based on the version of the dataset.
gld_info = gld.GoogleLandmarksInfo()
......@@ -238,7 +275,12 @@ def main(argv):
# Setup checkpoint directory.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(
checkpoint, checkpoint_prefix, max_to_keep=3)
checkpoint,
checkpoint_prefix,
max_to_keep=10,
keep_checkpoint_every_n_hours=3)
# Restores the checkpoint, if existing.
checkpoint.restore(manager.latest_checkpoint)
# ------------------------------------------------------------
# Train step to run on one GPU.
......@@ -248,13 +290,6 @@ def main(argv):
# Temporary workaround to avoid some corrupted labels.
labels = tf.clip_by_value(labels, 0, model.num_classes)
global_step = optimizer.iterations
tf.summary.image('batch_images', (images + 1.0) / 2.0, step=global_step)
tf.summary.scalar(
'image_range/max', tf.reduce_max(images), step=global_step)
tf.summary.scalar(
'image_range/min', tf.reduce_min(images), step=global_step)
def _backprop_loss(tape, loss, weights):
"""Backpropogate losses using clipped gradients.
......@@ -270,8 +305,8 @@ def main(argv):
# Record gradients and loss through backbone.
with tf.GradientTape() as gradient_tape:
# Make a forward pass to calculate prelogits.
(desc_prelogits, attn_prelogits, attn_scores,
backbone_blocks) = model.global_and_local_forward_pass(images)
(desc_prelogits, attn_prelogits, attn_scores, backbone_blocks,
dim_expanded_features, _) = model.global_and_local_forward_pass(images)
# Calculate global loss by applying the descriptor classifier.
if FLAGS.delg_global_features:
......@@ -284,18 +319,36 @@ def main(argv):
attn_logits = model.attn_classification(attn_prelogits)
attn_loss = compute_loss(labels, attn_logits)
# Cumulate global loss and attention loss.
total_loss = desc_loss + FLAGS.attention_loss_weight * attn_loss
# Calculate reconstruction loss between the attention prelogits and the
# backbone.
if FLAGS.use_autoencoder:
block3 = tf.stop_gradient(backbone_blocks['block3'])
reconstruction_loss = tf.math.reduce_mean(
tf.keras.losses.MSE(block3, dim_expanded_features))
else:
reconstruction_loss = 0
# Perform backpropagation through the descriptor layer and attention layer
# together.
# Cumulate global loss, attention loss and reconstruction loss.
total_loss = (desc_loss
+ FLAGS.attention_loss_weight * attn_loss
+ FLAGS.reconstruction_loss_weight * reconstruction_loss)
# Perform backpropagation through the descriptor and attention layers
# together. Note that this will increment the number of iterations of
# "optimizer".
_backprop_loss(gradient_tape, total_loss, model.trainable_weights)
# Report scaling factor for cosine logits for a DELG model.
if FLAGS.delg_global_features:
tf.summary.scalar('desc/scale_factor', model.scale_factor,
step=global_step)
# Report attention and sparsity summaries.
# Step number, for summary purposes.
global_step = optimizer.iterations
# Input image-related summaries.
tf.summary.image('batch_images', (images + 1.0) / 2.0, step=global_step)
tf.summary.scalar(
'image_range/max', tf.reduce_max(images), step=global_step)
tf.summary.scalar(
'image_range/min', tf.reduce_min(images), step=global_step)
# Attention and sparsity summaries.
_attention_summaries(attn_scores, global_step)
activations_zero_fractions = {
'sparsity/%s' % k: tf.nn.zero_fraction(v)
......@@ -303,12 +356,17 @@ def main(argv):
}
for k, v in activations_zero_fractions.items():
tf.summary.scalar(k, v, step=global_step)
# Record descriptor train accuracy.
# Scaling factor summary for cosine logits for a DELG model.
if FLAGS.delg_global_features:
tf.summary.scalar(
'desc/scale_factor', model.scale_factor, step=global_step)
# Record train accuracies.
_record_accuracy(desc_train_accuracy, desc_logits, labels)
# Record attention train accuracy.
_record_accuracy(attn_train_accuracy, attn_logits, labels)
return desc_loss, attn_loss
return desc_loss, attn_loss, reconstruction_loss
# ------------------------------------------------------------
def validation_step(inputs):
......@@ -350,7 +408,7 @@ def main(argv):
def distributed_train_step(dataset_inputs):
"""Get the actual losses."""
# Each (desc, attn) is a list of 3 losses - crossentropy, reg, total.
desc_per_replica_loss, attn_per_replica_loss = (
desc_per_replica_loss, attn_per_replica_loss, recon_per_replica_loss = (
strategy.run(train_step, args=(dataset_inputs,)))
# Reduce over the replicas.
......@@ -358,8 +416,10 @@ def main(argv):
tf.distribute.ReduceOp.SUM, desc_per_replica_loss, axis=None)
attn_global_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, attn_per_replica_loss, axis=None)
recon_global_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, recon_per_replica_loss, axis=None)
return desc_global_loss, attn_global_loss
return desc_global_loss, attn_global_loss, recon_global_loss
@tf.function
def distributed_validation_step(dataset_inputs):
......@@ -368,15 +428,16 @@ def main(argv):
# ------------------------------------------------------------
# *** TRAIN LOOP ***
with summary_writer.as_default():
with tf.summary.record_if(
tf.math.equal(0, optimizer.iterations % report_interval)):
record_cond = lambda: tf.equal(optimizer.iterations % report_interval, 0)
with tf.summary.record_if(record_cond):
global_step_value = optimizer.iterations.numpy()
# TODO(dananghel): try to load pretrained weights at backbone creation.
# Load pretrained weights for ResNet50 trained on ImageNet.
if FLAGS.imagenet_checkpoint is not None:
if (FLAGS.imagenet_checkpoint is not None) and (not global_step_value):
logging.info('Attempting to load ImageNet pretrained weights.')
input_batch = next(train_iter)
_, _ = distributed_train_step(input_batch)
_, _, _ = distributed_train_step(input_batch)
model.backbone.restore_weights(FLAGS.imagenet_checkpoint)
logging.info('Done.')
else:
......@@ -384,9 +445,9 @@ def main(argv):
if FLAGS.debug:
model.backbone.log_weights()
global_step_value = optimizer.iterations.numpy()
last_summary_step_value = None
last_summary_time = None
while global_step_value < max_iters:
# input_batch : images(b, h, w, c), labels(b,).
try:
input_batch = next(train_iter)
......@@ -396,24 +457,27 @@ def main(argv):
global_step_value)
break
# Set learning rate for optimizer to use.
# Set learning rate and run the training step over num_gpu gpus.
optimizer.learning_rate = _learning_rate_schedule(
optimizer.iterations.numpy(), max_iters, initial_lr)
desc_dist_loss, attn_dist_loss, recon_dist_loss = (
distributed_train_step(input_batch))
# Step number, to be used for summary/logging.
global_step = optimizer.iterations
global_step_value = global_step.numpy()
learning_rate = _learning_rate_schedule(global_step_value, max_iters,
initial_lr)
optimizer.learning_rate = learning_rate
# LR, losses and accuracies summaries.
tf.summary.scalar(
'learning_rate', optimizer.learning_rate, step=global_step)
# Run the training step over num_gpu gpus.
desc_dist_loss, attn_dist_loss = distributed_train_step(input_batch)
# Log losses and accuracies to tensorboard.
tf.summary.scalar(
'loss/desc/crossentropy', desc_dist_loss, step=global_step)
tf.summary.scalar(
'loss/attn/crossentropy', attn_dist_loss, step=global_step)
if FLAGS.use_autoencoder:
tf.summary.scalar(
'loss/recon/mse', recon_dist_loss, step=global_step)
tf.summary.scalar(
'train_accuracy/desc',
desc_train_accuracy.result(),
......@@ -423,6 +487,19 @@ def main(argv):
attn_train_accuracy.result(),
step=global_step)
# Summary for number of global steps taken per second.
current_time = time.time()
if (last_summary_step_value is not None and
last_summary_time is not None):
tf.summary.scalar(
'global_steps_per_sec',
(global_step_value - last_summary_step_value) /
(current_time - last_summary_time),
step=global_step)
if tf.summary.should_record_summaries().numpy():
last_summary_step_value = global_step_value
last_summary_time = current_time
# Print to console if running locally.
if FLAGS.debug:
if global_step_value % report_interval == 0:
......@@ -455,12 +532,14 @@ def main(argv):
print('Validation: desc:', desc_validation_result.numpy())
print(' : attn:', attn_validation_result.numpy())
# Save checkpoint once (each save_interval*n, n \in N) steps.
# Save checkpoint once (each save_interval*n, n \in N) steps, or if
# this is the last iteration.
# TODO(andrearaujo): save only in one of the two ways. They are
# identical, the only difference is that the manager adds some extra
# prefixes and variables (eg, optimizer variables).
if global_step_value % save_interval == 0:
save_path = manager.save()
if (global_step_value %
save_interval == 0) or (global_step_value >= max_iters):
save_path = manager.save(checkpoint_number=global_step_value)
logging.info('Saved (%d) at %s', global_step_value, save_path)
file_path = '%s/delf_weights' % FLAGS.logdir
......@@ -476,9 +555,6 @@ def main(argv):
desc_validation_accuracy.reset_states()
attn_validation_accuracy.reset_states()
if global_step.numpy() > max_iters:
break
logging.info('Finished training for %d steps.', max_iters)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment