Unverified commit 85f52928, authored by Dan Anghel, committed by GitHub
Browse files

Backpropagate global and attention layers together (#9335)



* Merged commit includes the following changes:
326369548  by Andre Araujo:

    Fix import issues.

--
326159826  by Andre Araujo:

    Changed the implementation of the cosine weights from Keras layer to tf.Variable to manually control for L2 normalization.

--
326139082  by Andre Araujo:

    Support local feature matching using ratio test.

    To allow for easily choosing which matching type to use, we rename a flag/argument and modify all related files to avoid breakages.

    Also include a small change when computing nearest neighbors for geometric matching, to parallelize computation, which saves a little bit of time during execution (argument "n_jobs=-1").

--
326119848  by Andre Araujo:

    Option to measure DELG latency taking binarization into account.

--
324316608  by Andre Araujo:

    DELG global features training.

--
323693131  by Andre Araujo:

    PY3 conversion for delf public lib.

--
321046157  by Andre Araujo:

    Purely Google refactor

--

PiperOrigin-RevId: 326369548

* Added export of delg_model module.

* README update to explain training DELG global features head

* Added guidelines for DELF hyperparameter values

* Fixed typo

* Added mention about remaining training flags.

* Merged commit includes the following changes:
334723489  by Andre Araujo:

    Backpropagate global and attention layers together.

--
334228310  by Andre Araujo:

    Enable scaling of local feature locations to the resized resolution.

--

PiperOrigin-RevId: 334723489
Co-authored-by: Andre Araujo <andrearaujo@google.com>
parent f97074b0
......@@ -49,6 +49,12 @@ message DelfLocalFeatureConfig {
// PCA parameters for DELF local feature. This is used only if use_pca is
// true.
optional DelfPcaParameters pca_parameters = 6;
// If true, the returned keypoint locations are grounded to coordinates of the
// resized image used for extraction. If false (default), the returned
// keypoint locations are grounded to coordinates of the original image that
// is fed into feature extraction.
optional bool use_resized_coordinates = 7 [default = false];
}
message DelfGlobalFeatureConfig {
......
......@@ -245,6 +245,7 @@ def MakeExtractor(config):
feature_extractor.DelfFeaturePostProcessing(
boxes, raw_local_descriptors, config.delf_local_config.use_pca,
local_pca_parameters))
if not config.delf_local_config.use_resized_coordinates:
locations /= scale_factors
extracted_features.update({
......
......@@ -137,28 +137,34 @@ class Delf(tf.keras.Model):
self.attn_classification = layers.Dense(
num_classes, activation=None, kernel_regularizer=None, name='att_fc')
@property
def desc_trainable_weights(self):
"""Weights to optimize for descriptor fine tuning."""
return (self.backbone.trainable_weights +
self.desc_classification.trainable_weights)
@property
def attn_trainable_weights(self):
"""Weights to optimize for attention model training."""
return (self.attention.trainable_weights +
self.attn_classification.trainable_weights)
def global_and_local_forward_pass(self, images, training=True):
"""Run a forward to calculate global descriptor and attention prelogits.
def build_call(self, input_image, training=True):
blocks = {}
global_feature = self.backbone.build_call(
input_image, intermediates_dict=blocks, training=training)
Args:
images: Tensor containing the dataset on which to run the forward pass.
training: Indicator of whether the forward pass is running in training mode
or not.
features = blocks['block3'] # pytype: disable=key-error
_, probs, _ = self.attention(features, training=training)
Returns:
Global descriptor prelogits, attention prelogits, attention scores,
backbone weights.
"""
backbone_blocks = {}
desc_prelogits = self.backbone.build_call(
images, intermediates_dict=backbone_blocks, training=training)
# Prevent gradients from propagating into the backbone. See DELG paper:
# https://arxiv.org/abs/2001.05027.
block3 = backbone_blocks['block3'] # pytype: disable=key-error
block3 = tf.stop_gradient(block3)
attn_prelogits, attn_scores, _ = self.attention(block3, training=training)
return desc_prelogits, attn_prelogits, attn_scores, backbone_blocks
return global_feature, probs, features
def build_call(self, input_image, training=True):
(global_feature, _, attn_scores,
backbone_blocks) = self.global_and_local_forward_pass(input_image,
training)
features = backbone_blocks['block3'] # pytype: disable=key-error
return global_feature, attn_scores, features
def call(self, input_image, training=True):
_, probs, features = self.build_call(input_image, training=training)
......
......@@ -87,28 +87,21 @@ class DelfTest(tf.test.TestCase, parameterized.TestCase):
return tf.nn.compute_average_loss(
per_example_loss, global_batch_size=batch_size)
with tf.GradientTape() as desc_tape:
blocks = {}
desc_prelogits = model.backbone(
images, intermediates_dict=blocks, training=False)
desc_logits = model.desc_classification(desc_prelogits)
with tf.GradientTape() as gradient_tape:
(desc_prelogits, attn_prelogits, _,
_) = model.global_and_local_forward_pass(images)
# Calculate global loss by applying the descriptor classifier.
desc_logits = model.desc_classification(desc_prelogits)
desc_loss = compute_loss(labels, desc_logits)
gradients = desc_tape.gradient(desc_loss, model.desc_trainable_weights)
clipped, _ = tf.clip_by_global_norm(gradients, clip_norm=clip_val)
optimizer.apply_gradients(zip(clipped, model.desc_trainable_weights))
with tf.GradientTape() as attn_tape:
block3 = blocks['block3']
block3 = tf.stop_gradient(block3)
attn_prelogits, _, _ = model.attention(block3, training=True)
# Calculate attention loss by applying the attention block classifier.
attn_logits = model.attn_classification(attn_prelogits)
attn_loss = compute_loss(labels, attn_logits)
gradients = attn_tape.gradient(attn_loss, model.attn_trainable_weights)
# Cumulate global loss and attention loss and backpropagate through the
# descriptor layer and attention layer together.
total_loss = desc_loss + attn_loss
gradients = gradient_tape.gradient(total_loss, model.trainable_weights)
clipped, _ = tf.clip_by_global_norm(gradients, clip_norm=clip_val)
optimizer.apply_gradients(zip(clipped, model.attn_trainable_weights))
optimizer.apply_gradients(zip(clipped, model.trainable_weights))
if __name__ == '__main__':
......
......@@ -58,6 +58,9 @@ flags.DEFINE_boolean('use_augmentation', True,
flags.DEFINE_string(
'imagenet_checkpoint', None,
'ImageNet checkpoint for ResNet backbone. If None, no checkpoint is used.')
flags.DEFINE_float('attention_loss_weight', 1.0,
'Weight to apply to the attention loss when calculating the '
'total loss of the model.')
flags.DEFINE_boolean('delg_global_features', False,
'Whether to train a DELG model.')
flags.DEFINE_float('delg_gem_power', 3.0, 'Power for Generalized Mean pooling.')
......@@ -252,8 +255,6 @@ def main(argv):
tf.summary.scalar(
'image_range/min', tf.reduce_min(images), step=global_step)
# TODO(andrearaujo): we should try to unify the backprop into a single
# function, instead of applying once to descriptor then to attention.
def _backprop_loss(tape, loss, weights):
"""Backpropagate losses using clipped gradients.
......@@ -267,57 +268,45 @@ def main(argv):
optimizer.apply_gradients(zip(clipped, weights))
# Record gradients and loss through backbone.
with tf.GradientTape() as desc_tape:
with tf.GradientTape() as gradient_tape:
# Make a forward pass to calculate prelogits.
(desc_prelogits, attn_prelogits, attn_scores,
backbone_blocks) = model.global_and_local_forward_pass(images)
blocks = {}
prelogits = model.backbone(
images, intermediates_dict=blocks, training=True)
# Calculate global loss by applying the descriptor classifier.
if FLAGS.delg_global_features:
desc_logits = model.desc_classification(desc_prelogits, labels)
else:
desc_logits = model.desc_classification(desc_prelogits)
desc_loss = compute_loss(labels, desc_logits)
# Calculate attention loss by applying the attention block classifier.
attn_logits = model.attn_classification(attn_prelogits)
attn_loss = compute_loss(labels, attn_logits)
# Cumulate global loss and attention loss.
total_loss = desc_loss + FLAGS.attention_loss_weight * attn_loss
# Perform backpropagation through the descriptor layer and attention layer
# together.
_backprop_loss(gradient_tape, total_loss, model.trainable_weights)
# Report sparsity.
# Report scaling factor for cosine logits for a DELG model.
if FLAGS.delg_global_features:
tf.summary.scalar('desc/scale_factor', model.scale_factor,
step=global_step)
# Report attention and sparsity summaries.
_attention_summaries(attn_scores, global_step)
activations_zero_fractions = {
'sparsity/%s' % k: tf.nn.zero_fraction(v)
for k, v in blocks.items()
for k, v in backbone_blocks.items()
}
for k, v in activations_zero_fractions.items():
tf.summary.scalar(k, v, step=global_step)
# Apply descriptor classifier and report scale factor.
if FLAGS.delg_global_features:
logits = model.desc_classification(prelogits, labels)
tf.summary.scalar('desc/scale_factor', model.scale_factor,
step=global_step)
else:
logits = model.desc_classification(prelogits)
desc_loss = compute_loss(labels, logits)
# Backprop only through backbone weights.
_backprop_loss(desc_tape, desc_loss, model.desc_trainable_weights)
# Record descriptor train accuracy.
_record_accuracy(desc_train_accuracy, logits, labels)
# Record gradients and loss through attention block.
with tf.GradientTape() as attn_tape:
block3 = blocks['block3'] # pytype: disable=key-error
# Stopping gradients according to DELG paper:
# (https://arxiv.org/abs/2001.05027).
block3 = tf.stop_gradient(block3)
prelogits, scores, _ = model.attention(block3, training=True)
_attention_summaries(scores, global_step)
# Apply attention block classifier.
logits = model.attn_classification(prelogits)
attn_loss = compute_loss(labels, logits)
# Backprop only through attention weights.
_backprop_loss(attn_tape, attn_loss, model.attn_trainable_weights)
_record_accuracy(desc_train_accuracy, desc_logits, labels)
# Record attention train accuracy.
_record_accuracy(attn_train_accuracy, logits, labels)
_record_accuracy(attn_train_accuracy, attn_logits, labels)
return desc_loss, attn_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment