Unverified Commit 7ff3ebcc authored by Daniel Ron, committed by GitHub

Fix attention application in DELG (#9906)

* Fix attention application in DELG

* Adding DELG unit tests

* Formatting for review

* Formatting for review

* Formatting for review
parent fcd681d2
delf/python/training/model/delg_model.py:

@@ -35,6 +35,8 @@ class AttentionModel(tf.keras.Model):
   Uses two [kernel_size x kernel_size] convolutions and softplus as activation
   to compute an attention map with the same resolution as the featuremap.
   Features l2-normalized and aggregated using attention probabilities as weights.
+  The features (targets) to be aggregated can be the input featuremap, or a
+  different one with the same resolution.
   """

   def __init__(self, kernel_size=1, decay=_DECAY, name='attention'):
@@ -65,7 +67,7 @@ class AttentionModel(tf.keras.Model):
         name='attn_conv2')
     self.activation_layer = layers.Activation('softplus')

-  def call(self, inputs, training=True):
+  def call(self, inputs, targets=None, training=True):
     x = self.conv1(inputs)
     x = self.bn_conv1(x, training=training)
     x = tf.nn.relu(x)
@@ -73,9 +75,13 @@ class AttentionModel(tf.keras.Model):
     score = self.conv2(x)
     prob = self.activation_layer(score)

+    # Aggregate inputs if targets is None.
+    if targets is None:
+      targets = inputs
+
     # L2-normalize the featuremap before pooling.
-    inputs = tf.nn.l2_normalize(inputs, axis=-1)
-    feat = tf.reduce_mean(tf.multiply(inputs, prob), [1, 2], keepdims=False)
+    targets = tf.nn.l2_normalize(targets, axis=-1)
+    feat = tf.reduce_mean(tf.multiply(targets, prob), [1, 2], keepdims=False)

     return feat, prob, score

@@ -208,8 +214,10 @@ class Delf(tf.keras.Model):
       block3 = tf.stop_gradient(block3)
     if self._use_dim_reduction:
       (dim_expanded_features, dim_reduced_features) = self.autoencoder(block3)
-      attn_prelogits, attn_scores, _ = self.attention(dim_expanded_features,
-                                                      training=training)
+      attn_prelogits, attn_scores, _ = self.attention(
+          block3,
+          targets=dim_expanded_features,
+          training=training)
     else:
       attn_prelogits, attn_scores, _ = self.attention(block3, training=training)
       dim_expanded_features = None
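In short: when dimensionality reduction is enabled, the attention probabilities
were previously computed from the autoencoder's dim-expanded features; with the
new `targets` argument they are computed from the backbone's block3 featuremap,
while the dim-expanded features are the ones that get aggregated. A minimal
sketch of the corrected pooling step, assuming DELG-like shapes (the helper
name `aggregate_with_attention` and the concrete sizes are illustrative, not
part of the delf package):

import tensorflow as tf

def aggregate_with_attention(prob, targets):
  """Pools `targets` weighted by attention probabilities computed elsewhere.

  `prob` may come from a different featuremap than `targets` (e.g. attention
  scored on block3, aggregation over the autoencoder's expanded features), as
  long as the spatial resolutions match.
  """
  # Same two steps as AttentionModel.call above: l2-normalize the targets,
  # then average-pool them weighted by the attention probabilities.
  targets = tf.nn.l2_normalize(targets, axis=-1)
  return tf.reduce_mean(tf.multiply(targets, prob), [1, 2], keepdims=False)

# Illustrative shapes: a 20x20 attention map scored on block3, applied to
# 1024-d dim-expanded features (matching the 1024-d attn_prelogits asserted
# in the tests below).
prob = tf.random.uniform((2, 20, 20, 1))
dim_expanded_features = tf.random.uniform((2, 20, 20, 1024))
feat = aggregate_with_attention(prob, dim_expanded_features)
assert feat.shape == (2, 1024)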
delf/python/training/model/delg_model_test.py (new file):

# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the DELG model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from delf.python.training.model import delg_model
class DelgTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('block3_stridesTrue', True),
('block3_stridesFalse', False),
)
def test_forward_pass(self, block3_strides):
image_size = 321
num_classes = 1000
batch_size = 2
input_shape = (batch_size, image_size, image_size, 3)
local_feature_dim = 64
feature_map_size = image_size // 16 # reduction factor for resnet50.
if block3_strides:
feature_map_size //= 2
model = delg_model.Delg(block3_strides=block3_strides,
use_dim_reduction=True,
reduced_dimension=local_feature_dim)
model.init_classifiers(num_classes)
images = tf.random.uniform(input_shape, minval=-1.0, maxval=1.0, seed=0)
# Run a complete forward pass of the model.
global_feature, attn_scores, local_features = model.build_call(images)
self.assertAllEqual(global_feature.shape, (batch_size, 2048))
self.assertAllEqual(
attn_scores.shape,
(batch_size, feature_map_size, feature_map_size, 1))
self.assertAllEqual(
local_features.shape,
(batch_size, feature_map_size, feature_map_size, local_feature_dim))

  @parameterized.named_parameters(
('block3_stridesTrue', True),
('block3_stridesFalse', False),
)
def test_build_model(self, block3_strides):
image_size = 321
num_classes = 1000
batch_size = 2
input_shape = (batch_size, image_size, image_size, 3)
model = delg_model.Delg(
block3_strides=block3_strides,
use_dim_reduction=True)
model.init_classifiers(num_classes)
images = tf.random.uniform(input_shape, minval=-1.0, maxval=1.0, seed=0)
labels = tf.random.uniform((batch_size,),
minval=0,
maxval=model.num_classes - 1,
dtype=tf.int64)
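    # The backbone call below fills `blocks` with intermediate featuremaps;
    # 'block3' is read back further down to drive the attention branch.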
blocks = {}
desc_prelogits = model.backbone(
images, intermediates_dict=blocks, training=False)
desc_logits = model.desc_classification(desc_prelogits, labels)
self.assertAllEqual(desc_prelogits.shape, (batch_size, 2048))
self.assertAllEqual(desc_logits.shape, (batch_size, num_classes))
features = blocks['block3']
attn_prelogits, _, _ = model.attention(features)
attn_logits = model.attn_classification(attn_prelogits)
self.assertAllEqual(attn_prelogits.shape, (batch_size, 1024))
self.assertAllEqual(attn_logits.shape, (batch_size, num_classes))

  @parameterized.named_parameters(
('block3_stridesTrue', True),
('block3_stridesFalse', False),
)
def test_train_step(self, block3_strides):
image_size = 321
num_classes = 1000
batch_size = 2
clip_val = 10.0
input_shape = (batch_size, image_size, image_size, 3)
model = delg_model.Delg(
block3_strides=block3_strides,
use_dim_reduction=True)
model.init_classifiers(num_classes)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)
images = tf.random.uniform(input_shape, minval=0.0, maxval=1.0, seed=0)
labels = tf.random.uniform((batch_size,),
minval=0,
maxval=model.num_classes - 1,
dtype=tf.int64)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
def compute_loss(labels, predictions):
per_example_loss = loss_object(labels, predictions)
return tf.nn.compute_average_loss(
per_example_loss, global_batch_size=batch_size)
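    # Run the global and local branches under a single tape so the descriptor,
    # attention and autoencoder reconstruction losses backpropagate together.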
with tf.GradientTape() as gradient_tape:
(desc_prelogits, attn_prelogits, _, backbone_blocks,
dim_expanded_features, _) = model.global_and_local_forward_pass(images)
# Calculate global loss by applying the descriptor classifier.
desc_logits = model.desc_classification(desc_prelogits, labels)
desc_loss = compute_loss(labels, desc_logits)
# Calculate attention loss by applying the attention block classifier.
attn_logits = model.attn_classification(attn_prelogits)
attn_loss = compute_loss(labels, attn_logits)
# Calculate reconstruction loss between the attention prelogits and the
# backbone.
block3 = tf.stop_gradient(backbone_blocks['block3'])
reconstruction_loss = tf.math.reduce_mean(
tf.keras.losses.MSE(block3, dim_expanded_features))
# Cumulate global loss and attention loss and backpropagate through the
# descriptor layer and attention layer together.
total_loss = desc_loss + attn_loss + reconstruction_loss
gradients = gradient_tape.gradient(total_loss, model.trainable_weights)
clipped, _ = tf.clip_by_global_norm(gradients, clip_norm=clip_val)
optimizer.apply_gradients(zip(clipped, model.trainable_weights))


if __name__ == '__main__':
tf.test.main()
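Since the file ends in tf.test.main(), these tests should be runnable directly
once the delf package is installed, e.g. with `python delg_model_test.py`.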