Unverified Commit a003b7c1 authored by Dan Anghel's avatar Dan Anghel Committed by GitHub
Browse files

DELF updates (#9095)



* Merged commit includes the following changes:
326369548  by Andre Araujo:

    Fix import issues.

--
326159826  by Andre Araujo:

    Changed the implementation of the cosine weights from Keras layer to tf.Variable to manually control for L2 normalization.

--
326139082  by Andre Araujo:

    Support local feature matching using ratio test.

    To allow for easily choosing which matching type to use, we rename a flag/argument and modify all related files to avoid breakages.

    Also include a small change when computing nearest neighbors for geometric matching, to parallelize computation, which saves a little bit of time during execution (argument "n_jobs=-1").

--
326119848  by Andre Araujo:

    Option to measure DELG latency taking binarization into account.

--
324316608  by Andre Araujo:

    DELG global features training.

--
323693131  by Andre Araujo:

    PY3 conversion for delf public lib.

--
321046157  by Andre Araujo:

    Purely Google refactor

--

PiperOrigin-RevId: 326369548

* Added export of delg_model module.
Co-authored-by: default avatarAndre Araujo <andrearaujo@google.com>
parent b4c4a534
# Lint as: python3
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
......@@ -42,6 +42,11 @@ flags.DEFINE_string('list_images_path', '/tmp/list_images.txt',
'Path to list of images whose features will be extracted.')
flags.DEFINE_integer('repeat_per_image', 10,
'Number of times to repeat extraction per image.')
flags.DEFINE_boolean(
'binary_local_features', False,
'Whether to binarize local features after extraction, and take this extra '
'latency into account. This should only be used if use_local_features is '
'set in the input DelfConfig from `delf_config_path`.')
# Interval (in number of images) at which extraction progress is logged.
_STATUS_CHECK_ITERATIONS = 100
......@@ -103,6 +108,12 @@ def main(argv):
# Extract and save features.
extracted_features = extractor_fn(im)
# Binarize local features, if desired (and if there are local features).
if (config.use_local_features and FLAGS.binary_local_features and
extracted_features['local_features']['attention'].size):
packed_descriptors = np.packbits(
extracted_features['local_features']['descriptors'] > 0, axis=1)
if __name__ == '__main__':
app.run(main)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -44,15 +45,19 @@ flags.DEFINE_boolean(
'If True, performs re-ranking using local feature-based geometric '
'verification.')
flags.DEFINE_float(
'local_feature_distance_threshold', 1.0,
'local_descriptor_matching_threshold', 1.0,
'Optional, only used if `use_geometric_verification` is True. '
'Distance threshold below which a pair of local descriptors is considered '
'Threshold below which a pair of local descriptors is considered '
'a potential match, and will be fed into RANSAC.')
flags.DEFINE_float(
'ransac_residual_threshold', 20.0,
'Optional, only used if `use_geometric_verification` is True. '
'Residual error threshold for considering matches as inliers, used in '
'RANSAC algorithm.')
flags.DEFINE_boolean(
'use_ratio_test', False,
'Optional, only used if `use_geometric_verification` is True. '
'Whether to use ratio test for local feature matching.')
flags.DEFINE_string(
'output_dir', '/tmp/retrieval',
'Directory where retrieval output will be written to. A file containing '
......@@ -152,8 +157,10 @@ def main(argv):
junk_ids=set(medium_ground_truth[i]['junk']),
local_feature_extension=_DELG_LOCAL_EXTENSION,
ransac_seed=0,
feature_distance_threshold=FLAGS.local_feature_distance_threshold,
ransac_residual_threshold=FLAGS.ransac_residual_threshold)
descriptor_matching_threshold=FLAGS
.local_descriptor_matching_threshold,
ransac_residual_threshold=FLAGS.ransac_residual_threshold,
use_ratio_test=FLAGS.use_ratio_test)
hard_ranks_after_gv[i] = image_reranking.RerankByGeometricVerification(
input_ranks=ranks_before_gv[i],
initial_scores=similarities,
......@@ -164,8 +171,10 @@ def main(argv):
junk_ids=set(hard_ground_truth[i]['junk']),
local_feature_extension=_DELG_LOCAL_EXTENSION,
ransac_seed=0,
feature_distance_threshold=FLAGS.local_feature_distance_threshold,
ransac_residual_threshold=FLAGS.ransac_residual_threshold)
descriptor_matching_threshold=FLAGS
.local_descriptor_matching_threshold,
ransac_residual_threshold=FLAGS.ransac_residual_threshold,
use_ratio_test=FLAGS.use_ratio_test)
elapsed = (time.time() - start)
print('done! Retrieval for query %d took %f seconds' % (i, elapsed))
......
# Lint as: python3
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
# Lint as: python3
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
......@@ -47,12 +47,13 @@ def MatchFeatures(query_locations,
index_image_locations,
index_image_descriptors,
ransac_seed=None,
feature_distance_threshold=0.9,
descriptor_matching_threshold=0.9,
ransac_residual_threshold=10.0,
query_im_array=None,
index_im_array=None,
query_im_scale_factors=None,
index_im_scale_factors=None):
index_im_scale_factors=None,
use_ratio_test=False):
"""Matches local features using geometric verification.
First, finds putative local feature matches by matching `query_descriptors`
......@@ -70,8 +71,10 @@ def MatchFeatures(query_locations,
index_image_descriptors: Descriptors of local features for index image.
NumPy array of shape [#index_image_features, depth].
ransac_seed: Seed used by RANSAC. If None (default), no seed is provided.
feature_distance_threshold: Distance threshold below which a pair of
features is considered a potential match, and will be fed into RANSAC.
descriptor_matching_threshold: Threshold below which a pair of local
descriptors is considered a potential match, and will be fed into RANSAC.
If use_ratio_test==False, this is a simple distance threshold. If
use_ratio_test==True, this is Lowe's ratio test threshold.
ransac_residual_threshold: Residual error threshold for considering matches
as inliers, used in RANSAC algorithm.
query_im_array: Optional. If not None, contains a NumPy array with the query
......@@ -83,6 +86,8 @@ def MatchFeatures(query_locations,
(ie, feature locations are not scaled).
index_im_scale_factors: Optional. Same as `query_im_scale_factors`, but for
index image.
use_ratio_test: If True, descriptor matching is performed via ratio test,
instead of distance-based threshold.
Returns:
score: Number of inliers of match. If no match is found, returns 0.
......@@ -105,10 +110,26 @@ def MatchFeatures(query_locations,
'Local feature dimensionality is not consistent for query and index '
'images.')
# Find nearest-neighbor matches using a KD tree.
# Construct KD-tree used to find nearest neighbors.
index_image_tree = spatial.cKDTree(index_image_descriptors)
if use_ratio_test:
distances, indices = index_image_tree.query(
query_descriptors, k=2, n_jobs=-1)
query_locations_to_use = np.array([
query_locations[i,]
for i in range(num_features_query)
if distances[i][0] < descriptor_matching_threshold * distances[i][1]
])
index_image_locations_to_use = np.array([
index_image_locations[indices[i][0],]
for i in range(num_features_query)
if distances[i][0] < descriptor_matching_threshold * distances[i][1]
])
else:
_, indices = index_image_tree.query(
query_descriptors, distance_upper_bound=feature_distance_threshold)
query_descriptors,
distance_upper_bound=descriptor_matching_threshold,
n_jobs=-1)
# Select feature locations for putative matches.
query_locations_to_use = np.array([
......@@ -175,8 +196,9 @@ def RerankByGeometricVerification(input_ranks,
junk_ids,
local_feature_extension=_DELF_EXTENSION,
ransac_seed=None,
feature_distance_threshold=0.9,
ransac_residual_threshold=10.0):
descriptor_matching_threshold=0.9,
ransac_residual_threshold=10.0,
use_ratio_test=False):
"""Re-ranks retrieval results using geometric verification.
Args:
......@@ -195,10 +217,11 @@ def RerankByGeometricVerification(input_ranks,
local_feature_extension: String, extension to use for loading local feature
files.
ransac_seed: Seed used by RANSAC. If None (default), no seed is provided.
feature_distance_threshold: Distance threshold below which a pair of local
features is considered a potential match, and will be fed into RANSAC.
descriptor_matching_threshold: Threshold used for local descriptor matching.
ransac_residual_threshold: Residual error threshold for considering matches
as inliers, used in RANSAC algorithm.
use_ratio_test: If True, descriptor matching is performed via ratio test,
instead of distance-based threshold.
Returns:
output_ranks: 1D NumPy array with index image indices, sorted from the most
......@@ -258,8 +281,9 @@ def RerankByGeometricVerification(input_ranks,
index_image_locations,
index_image_descriptors,
ransac_seed=ransac_seed,
feature_distance_threshold=feature_distance_threshold,
ransac_residual_threshold=ransac_residual_threshold)
descriptor_matching_threshold=descriptor_matching_threshold,
ransac_residual_threshold=ransac_residual_threshold,
use_ratio_test=use_ratio_test)
# Sort based on (inliers_score, initial_score).
def _InliersInitialScoresSorting(k):
......
# Lint as: python3
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
# Lint as: python3
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
# pylint: disable=unused-import
from delf.python.training.model import delf_model
from delf.python.training.model import delg_model
from delf.python.training.model import export_model_utils
from delf.python.training.model import resnet50
# pylint: enable=unused-import
......@@ -89,12 +89,20 @@ class Delf(tf.keras.Model):
from conv_4 are used to compute an attention map of the same resolution.
"""
def __init__(self, block3_strides=True, name='DELF'):
def __init__(self, block3_strides=True, name='DELF', pooling='avg',
gem_power=3.0, embedding_layer=False, embedding_layer_dim=2048):
"""Initialization of DELF model.
Args:
block3_strides: bool, whether to add strides to the output of block3.
name: str, name to identify model.
pooling: str, pooling mode for global feature extraction; possible values
are 'None', 'avg', 'max', 'gem.'
gem_power: float, GeM power for GeM pooling. Only used if
pooling == 'gem'.
embedding_layer: bool, whether to create an embedding layer (FC whitening
layer).
embedding_layer_dim: int, size of the embedding layer.
"""
super(Delf, self).__init__(name=name)
......@@ -103,31 +111,38 @@ class Delf(tf.keras.Model):
'channels_last',
name='backbone',
include_top=False,
pooling='avg',
pooling=pooling,
block3_strides=block3_strides,
average_pooling=False)
average_pooling=False,
gem_power=gem_power,
embedding_layer=embedding_layer,
embedding_layer_dim=embedding_layer_dim)
# Attention model.
self.attention = AttentionModel(name='attention')
# Define classifiers for training backbone and attention models.
def init_classifiers(self, num_classes):
def init_classifiers(self, num_classes, desc_classification=None):
"""Define classifiers for training backbone and attention models."""
self.num_classes = num_classes
self.desc_classification = layers.Dense(
num_classes, activation=None, kernel_regularizer=None, name='desc_fc')
if desc_classification is None:
self.desc_classification = layers.Dense(num_classes,
activation=None,
kernel_regularizer=None,
name='desc_fc')
else:
self.desc_classification = desc_classification
self.attn_classification = layers.Dense(
num_classes, activation=None, kernel_regularizer=None, name='att_fc')
# Weights to optimize for descriptor fine tuning.
@property
def desc_trainable_weights(self):
"""Weights to optimize for descriptor fine tuning."""
return (self.backbone.trainable_weights +
self.desc_classification.trainable_weights)
# Weights to optimize for attention model training.
@property
def attn_trainable_weights(self):
"""Weights to optimize for attention model training."""
return (self.attention.trainable_weights +
self.attn_classification.trainable_weights)
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DELG model implementation based on the following paper.
Unifying Deep Local and Global Features for Image Search
https://arxiv.org/abs/2001.05027
"""
import functools
import math
from absl import logging
import tensorflow as tf
from delf.python.training.model import delf_model
layers = tf.keras.layers
class Delg(delf_model.Delf):
  """Instantiates Keras DELG model using ResNet50 as backbone.

  This class implements the [DELG](https://arxiv.org/abs/2001.05027) model for
  extracting local and global features from images. The same attention layer
  is trained as in the DELF model. In addition, the extraction of global
  features is trained using GeMPooling, a FC whitening layer also called
  "embedding layer" and ArcFace loss.
  """

  def __init__(self,
               block3_strides=True,
               name='DELG',
               gem_power=3.0,
               embedding_layer_dim=2048,
               scale_factor_init=45.25,  # sqrt(2048)
               arcface_margin=0.1):
    """Initialization of DELG model.

    Args:
      block3_strides: bool, whether to add strides to the output of block3.
      name: str, name to identify model.
      gem_power: float, GeM power parameter.
      embedding_layer_dim: int, dimension of the embedding layer.
      scale_factor_init: float, initial value for the scale factor applied to
        the cosine logits.
      arcface_margin: float, ArcFace margin.
    """
    # NOTE: gem_power is a float, so it must be logged with %f; %d would
    # silently truncate (e.g. 3.0 -> 3) and hide non-integer powers.
    logging.info('Creating Delg model, gem_power %f, embedding_layer_dim %d',
                 gem_power, embedding_layer_dim)
    super(Delg, self).__init__(block3_strides=block3_strides,
                               name=name,
                               pooling='gem',
                               gem_power=gem_power,
                               embedding_layer=True,
                               embedding_layer_dim=embedding_layer_dim)
    self._embedding_layer_dim = embedding_layer_dim
    self._scale_factor_init = scale_factor_init
    self._arcface_margin = arcface_margin

  def init_classifiers(self, num_classes):
    """Define classifiers for training backbone and attention models."""
    logging.info('Initializing Delg backbone and attention models classifiers')
    backbone_classifier_func = self._create_backbone_classifier(num_classes)
    super(Delg, self).init_classifiers(
        num_classes, desc_classification=backbone_classifier_func)

  def _create_backbone_classifier(self, num_classes):
    """Define the classifier for training the backbone model.

    Args:
      num_classes: int, number of classes.

    Returns:
      Callable with the signature of `cosine_classifier_logits`, with the
      cosine weights, scale factor and ArcFace margin bound. The callable
      additionally exposes a `trainable_weights` attribute so descriptor
      fine-tuning can optimize the cosine weights.
    """
    logging.info('Creating cosine classifier')
    # A raw tf.Variable (rather than a Keras Dense layer) is used so the
    # weights can be manually L2-normalized before the cosine similarity.
    self.cosine_weights = tf.Variable(
        initial_value=tf.initializers.GlorotUniform()(
            shape=[self._embedding_layer_dim, num_classes]),
        name='cosine_weights',
        trainable=True)
    # The scale factor is kept fixed during training.
    self.scale_factor = tf.Variable(self._scale_factor_init,
                                    name='scale_factor',
                                    trainable=False)
    classifier_func = functools.partial(cosine_classifier_logits,
                                        num_classes=num_classes,
                                        cosine_weights=self.cosine_weights,
                                        scale_factor=self.scale_factor,
                                        arcface_margin=self._arcface_margin)
    classifier_func.trainable_weights = [self.cosine_weights]
    return classifier_func
def cosine_classifier_logits(prelogits,
                             labels,
                             num_classes,
                             cosine_weights,
                             scale_factor,
                             arcface_margin,
                             training=True):
  """Compute cosine classifier logits using ArcFace margin.

  Args:
    prelogits: float tensor of shape [batch_size, 1, 1, embedding_layer_dim].
    labels: int tensor of shape [batch_size].
    num_classes: int, number of classes.
    cosine_weights: float tensor of shape [embedding_layer_dim, num_classes].
    scale_factor: float.
    arcface_margin: float. Only used if greater than zero, and training is True.
    training: bool, True if training, False if eval.

  Returns:
    logits: Float tensor [batch_size, num_classes].
  """
  # Drop the singleton spatial dims: [batch_size, 1, 1, D] -> [batch_size, D].
  embeddings = tf.squeeze(prelogits, [1, 2])

  # Cosine similarity is the dot product of L2-normalized vectors.
  normalized_embeddings = tf.math.l2_normalize(embeddings, axis=1)
  normalized_weights = tf.math.l2_normalize(cosine_weights, axis=0)
  cosine_sim = tf.matmul(normalized_embeddings, normalized_weights)

  # ArcFace margin is only applied during training, and only if enabled.
  if training and arcface_margin > 0.0:
    # One-hot encode labels: [batch_size] -> [batch_size, num_classes].
    one_hot_labels = tf.one_hot(labels, num_classes)
    cosine_sim = apply_arcface_margin(cosine_sim, one_hot_labels,
                                      arcface_margin)

  # Scale the similarities to obtain the final logits.
  return scale_factor * cosine_sim
def apply_arcface_margin(cosine_sim, one_hot_labels, arcface_margin):
  """Applies ArcFace margin to cosine similarity inputs.

  For a reference, see https://arxiv.org/pdf/1801.07698.pdf. The ArcFace
  margin is added to the angles of the correct classes (as per the ArcFace
  paper), but only when the angle is <= (pi - margin); beyond that, adding
  the margin would wrap around and could actually increase the cosine
  similarity.

  Args:
    cosine_sim: float tensor with shape [batch_size, num_classes].
    one_hot_labels: int tensor with shape [batch_size, num_classes].
    arcface_margin: float.

  Returns:
    cosine_sim_with_margin: Float tensor with shape [batch_size, num_classes].
  """
  # Recover the angles from the cosine similarities.
  angles = tf.acos(cosine_sim, name='acos')
  # Keep only the correct-class entries whose angle is small enough for the
  # margin to be meaningful; everything else is zeroed out.
  margin_mask = tf.where(tf.greater(angles, math.pi - arcface_margin),
                         tf.zeros_like(one_hot_labels),
                         one_hot_labels,
                         name='selected_labels')
  # Add the margin to the selected angles, leave the rest unchanged.
  margined_angles = tf.where(tf.cast(margin_mask, dtype=tf.bool),
                             angles + arcface_margin,
                             angles,
                             name='final_theta')
  # Map back from angles to cosine similarities.
  return tf.cos(margined_angles, name='cosine_sim_with_margin')
......@@ -183,13 +183,16 @@ class ResNet50(tf.keras.Model):
output of the last convolutional layer. 'avg' means that global average
pooling will be applied to the output of the last convolutional layer, and
thus the output of the model will be a 2D tensor. 'max' means that global
max pooling will be applied.
max pooling will be applied. 'gem' means GeM pooling will be applied.
block3_strides: whether to add a stride of 2 to block3 to make it compatible
with tf.slim ResNet implementation.
average_pooling: whether to do average pooling of block4 features before
global pooling.
classes: optional number of classes to classify images into, only to be
specified if `include_top` is True.
gem_power: GeM power for GeM pooling. Only used if pooling == 'gem'.
embedding_layer: whether to create an embedding layer (FC whitening layer).
embedding_layer_dim: size of the embedding layer.
Raises:
ValueError: in case of invalid argument for data_format.
......@@ -202,7 +205,10 @@ class ResNet50(tf.keras.Model):
pooling=None,
block3_strides=False,
average_pooling=True,
classes=1000):
classes=1000,
gem_power=3.0,
embedding_layer=False,
embedding_layer_dim=2048):
super(ResNet50, self).__init__(name=name)
valid_channel_values = ('channels_first', 'channels_last')
......@@ -286,8 +292,19 @@ class ResNet50(tf.keras.Model):
elif pooling == 'max':
self.global_pooling = functools.partial(
tf.reduce_max, axis=reduction_indices, keepdims=False)
elif pooling == 'gem':
logging.info('Adding GeMPooling layer with power %f', gem_power)
self.global_pooling = functools.partial(
gem_pooling, axis=reduction_indices, power=gem_power)
else:
self.global_pooling = None
if embedding_layer:
logging.info('Adding embedding layer with dimension %d',
embedding_layer_dim)
self.embedding_layer = layers.Dense(embedding_layer_dim,
name='embedding_layer')
else:
self.embedding_layer = None
def build_call(self, inputs, training=True, intermediates_dict=None):
"""Building the ResNet50 model.
......@@ -358,7 +375,10 @@ class ResNet50(tf.keras.Model):
if self.include_top:
return self.fc1000(self.flatten(x))
elif self.global_pooling:
return self.global_pooling(x)
x = self.global_pooling(x)
if self.embedding_layer:
x = self.embedding_layer(x)
return x
else:
return x
......@@ -417,7 +437,7 @@ class ResNet50(tf.keras.Model):
g = f[inlayer.name]
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
weight_values = [g[weight_name] for weight_name in weight_names]
print('Setting the weights for layer %s' % (inlayer.name))
logging.info('Setting the weights for layer %s', inlayer.name)
inlayer.set_weights(weight_values)
finally:
# Clean up the temporary file.
......@@ -437,3 +457,28 @@ class ResNet50(tf.keras.Model):
else:
logging.info('Layer %s does not have inner layers.',
layer.name)
def gem_pooling(feature_map, axis, power, threshold=1e-6):
  """Performs GeM (Generalized Mean) pooling.

  See https://arxiv.org/abs/1711.02512 for a reference.

  Args:
    feature_map: Tensor of shape [batch, height, width, channels] for
      the "channels_last" format or [batch, channels, height, width] for the
      "channels_first" format.
    axis: Dimensions to reduce.
    power: Float, GeM power parameter.
    threshold: Optional float, threshold to use for activations.

  Returns:
    pooled_feature_map: Tensor of shape [batch, 1, 1, channels] for the
      "channels_last" format or [batch, channels, 1, 1] for the
      "channels_first" format.
  """
  # Clamp activations from below so the fractional power is well-defined.
  clamped = tf.maximum(feature_map, threshold)
  # Mean of the p-th powers over the spatial dimensions (kept as size-1 dims).
  mean_of_powers = tf.reduce_mean(tf.pow(clamped, power),
                                  axis=axis,
                                  keepdims=True)
  # Undo the power to obtain the generalized mean.
  return tf.pow(mean_of_powers, 1.0 / power)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the ResNet backbone."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from delf.python.training.model import resnet50
class Resnet50Test(tf.test.TestCase):
  """Unit tests for the ResNet50 backbone helpers."""

  def test_gem_pooling_works(self):
    """Checks GeM pooling (power 2, threshold 0) against hand-computed RMS."""
    # Two feature maps in a batch, each of spatial size 1x2 with depth 2.
    feature_map = tf.constant([[[[.0, 2.0], [1.0, -1.0]]],
                               [[[1.0, 100.0], [1.0, .0]]]],
                              dtype=tf.float32)
    power = 2.0
    threshold = .0

    # With power 2 and threshold 0, GeM is the RMS of the clamped activations
    # over the spatial axes.
    pooled_feature_map = resnet50.gem_pooling(feature_map=feature_map,
                                              axis=[1, 2],
                                              power=power,
                                              threshold=threshold)

    expected_pooled_feature_map = np.array([[[[0.707107, 1.414214]]],
                                            [[[1.0, 70.710678]]]],
                                           dtype=float)

    self.assertAllClose(pooled_feature_map, expected_pooled_feature_map)


if __name__ == '__main__':
  tf.test.main()
......@@ -34,6 +34,7 @@ import tensorflow_probability as tfp
# Placeholder for internal import. Do not remove this line.
from delf.python.training.datasets import googlelandmarks as gld
from delf.python.training.model import delf_model
from delf.python.training.model import delg_model
FLAGS = flags.FLAGS
......@@ -57,6 +58,15 @@ flags.DEFINE_boolean('use_augmentation', True,
flags.DEFINE_string(
'imagenet_checkpoint', None,
'ImageNet checkpoint for ResNet backbone. If None, no checkpoint is used.')
flags.DEFINE_boolean('delg_global_features', False,
'Whether to train a DELG model.')
flags.DEFINE_float('delg_gem_power', 3.0, 'Power for Generalized Mean pooling.')
flags.DEFINE_integer('delg_embedding_layer_dim', 2048,
'Size of the FC whitening layer (embedding layer).')
flags.DEFINE_float('delg_scale_factor_init', 45.25,
('Initial value of the scaling factor of the cosine logits.'
'The default value is sqrt(2048).'))
flags.DEFINE_float('delg_arcface_margin', 0.1, 'ArcFace margin.')
def _record_accuracy(metric, logits, labels):
......@@ -90,6 +100,14 @@ def _attention_summaries(scores, global_step):
def create_model(num_classes):
"""Define DELF model, and initialize classifiers."""
if FLAGS.delg_global_features:
model = delg_model.Delg(block3_strides=FLAGS.block3_strides,
name='DELG',
gem_power=FLAGS.delg_gem_power,
embedding_layer_dim=FLAGS.delg_embedding_layer_dim,
scale_factor_init=FLAGS.delg_scale_factor_init,
arcface_margin=FLAGS.delg_arcface_margin)
else:
model = delf_model.Delf(block3_strides=FLAGS.block3_strides, name='DELF')
model.init_classifiers(num_classes)
return model
......@@ -263,7 +281,12 @@ def main(argv):
for k, v in activations_zero_fractions.items():
tf.summary.scalar(k, v, step=global_step)
# Apply descriptor classifier.
# Apply descriptor classifier and report scale factor.
if FLAGS.delg_global_features:
logits = model.desc_classification(prelogits, labels)
tf.summary.scalar('desc/scale_factor', model.scale_factor,
step=global_step)
else:
logits = model.desc_classification(prelogits)
desc_loss = compute_loss(labels, logits)
......@@ -308,6 +331,9 @@ def main(argv):
blocks = {}
prelogits = model.backbone(
images, intermediates_dict=blocks, training=False)
if FLAGS.delg_global_features:
logits = model.desc_classification(prelogits, labels, training=False)
else:
logits = model.desc_classification(prelogits, training=False)
softmax_probabilities = tf.keras.layers.Softmax()(logits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment