"tests/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "d29b312c13088f14bd28360c230e039940381ec3"
Unverified commit 23837ecf authored by Jekaterina Jaroslavceva, committed by GitHub
Browse files

Global model (#9902)

* Dataset utilities added.

* Global model definition

* Dataset modules added.

* Dataset modules fix.

* global features model training added

* global features fix

* Test dataset update

* PR fixes

* repo sync

* repo sync

* Syncing 2

* Syncing 2

* Global model definition added

* Global model definition added, synced

* Adding global model dataset related modules

* Global model training

* tensorboard module added

* linting issues fixed

* linting fixes.

* linting fixes.

* Fix for previous PR

* PR fixes

* Minor fixes

* Minor fixes

* Dataset download fix

* Comments fix

* sfm120k fix

* sfm120k fix

* names fix

* Update

* Update

* Merge branch 'global_model_training'

# Conflicts:
#	research/delf/delf/python/datasets/generic_dataset.py
#	research/delf/delf/python/datasets/generic_dataset_test.py
#	research/delf/delf/python/datasets/sfm120k/__init__.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k_test.py
#	research/delf/delf/python/datasets/tuples_dataset.py
#	research/delf/delf/python/datasets/tuples_dataset_test.py
#	research/delf/delf/python/training/global_features/__init__.py
#	research/delf/delf/python/training/global_features/train.py
#	research/delf/delf/python/training/model/global_model.py
#	research/delf/delf/python/training/model/global_model_test.py
#	research/delf/delf/python/training/tensorboard_utils.py

* Merge branch 'global_model_training'

# Conflicts:
#	research/delf/delf/python/datasets/generic_dataset.py
#	research/delf/delf/python/datasets/generic_dataset_test.py
#	research/delf/delf/python/datasets/sfm120k/__init__.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k_test.py
#	research/delf/delf/python/datasets/tuples_dataset.py
#	research/delf/delf/python/datasets/tuples_dataset_test.py
#	research/delf/delf/python/training/global_features/__init__.py
#	research/delf/delf/python/training/global_features/train.py
#	research/delf/delf/python/training/model/global_model.py
#	research/delf/delf/python/training/model/global_model_test.py
#	research/delf/delf/python/training/tensorboard_utils.py

* PR fixes global model

* Merge branch 'global_model_training'

# Conflicts:
#	research/delf/delf/python/datasets/generic_dataset.py
#	research/delf/delf/python/datasets/generic_dataset_test.py
#	research/delf/delf/python/datasets/sfm120k/__init__.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k_test.py
#	research/delf/delf/python/datasets/tuples_dataset.py
#	research/delf/delf/python/datasets/tuples_dataset_test.py
#	research/delf/delf/python/training/global_features/__init__.py
#	research/delf/delf/python/training/global_features/train.py
#	research/delf/delf/python/training/model/global_model.py
#	research/delf/delf/python/training/model/global_model_test.py
#	research/delf/delf/python/training/tensorboard_utils.py

* Merge branch 'global_model_training'

# Conflicts:
#	research/delf/delf/python/datasets/generic_dataset.py
#	research/delf/delf/python/datasets/generic_dataset_test.py
#	research/delf/delf/python/datasets/sfm120k/__init__.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k_test.py
#	research/delf/delf/python/datasets/tuples_dataset.py
#	research/delf/delf/python/datasets/tuples_dataset_test.py
#	research/delf/delf/python/training/global_features/__init__.py
#	research/delf/delf/python/training/global_features/train.py
#	research/delf/delf/python/training/model/global_model.py
#	research/delf/delf/python/training/model/global_model_test.py
#	research/delf/delf/python/training/tensorboard_utils.py

* Merge branch 'global_model_training'

# Conflicts:
#	research/delf/delf/python/datasets/generic_dataset.py
#	research/delf/delf/python/datasets/generic_dataset_test.py
#	research/delf/delf/python/datasets/sfm120k/__init__.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k_test.py
#	research/delf/delf/python/datasets/tuples_dataset.py
#	research/delf/delf/python/datasets/tuples_dataset_test.py
#	research/delf/delf/python/training/global_features/__init__.py
#	research/delf/delf/python/training/global_features/train.py
#	research/delf/delf/python/training/model/global_model.py
#	research/delf/delf/python/training/model/global_model_test.py
#	research/delf/delf/python/training/tensorboard_utils.py

* Merge branch 'global_model_training'

# Conflicts:
#	research/delf/delf/python/datasets/generic_dataset.py
#	research/delf/delf/python/datasets/generic_dataset_test.py
#	research/delf/delf/python/datasets/sfm120k/__init__.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k.py
#	research/delf/delf/python/datasets/sfm120k/sfm120k_test.py
#	research/delf/delf/python/datasets/tuples_dataset.py
#	research/delf/delf/python/datasets/tuples_dataset_test.py
#	research/delf/delf/python/training/global_features/__init__.py
#	research/delf/delf/python/training/global_features/train.py
#	research/delf/delf/python/training/model/global_model.py
#	research/delf/delf/python/training/model/global_model_test.py
#	research/delf/delf/python/training/tensorboard_utils.py
parent 5bac2ced
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for the global model training."""
import os
from absl import logging
import numpy as np
import tensorflow as tf
from delf.python.datasets.revisited_op import dataset as revisited_dataset
class AverageMeter():
  """Tracks a running average of a scalar quantity (e.g. training loss).

  Exposes four attributes, all zeroed by `reset()`:
    val: the most recently observed value.
    sum: weighted sum of all observed values.
    count: total number of observed instances.
    avg: current running average (sum / count).
  """

  def __init__(self):
    """Creates the meter with all statistics zeroed."""
    self.reset()

  def reset(self):
    """Zeroes the current value, running sum, count and average."""
    self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

  def update(self, val, n=1):
    """Records `val` observed `n` times and refreshes the running average.

    Args:
      val: Float, loss value.
      n: Integer, number of instances.
    """
    self.val = val
    self.sum = self.sum + val * n
    self.count = self.count + n
    self.avg = self.sum / self.count
def compute_metrics_and_print(dataset_name,
                              sorted_index_ids,
                              ground_truth,
                              desired_pr_ranks=None,
                              log=True):
  """Computes and logs ground-truth metrics for Revisited datasets.

  Args:
    dataset_name: String, name of the dataset.
    sorted_index_ids: Integer NumPy array of shape [#queries, #index_images].
      For each query, contains an array denoting the most relevant index
      images, sorted from most to least relevant.
    ground_truth: List containing ground-truth information for dataset. Each
      entry is a dict corresponding to the ground-truth information for a
      query. The dict has keys 'ok' and 'junk', mapping to a NumPy array of
      integers.
    desired_pr_ranks: List of integers containing the desired precision/recall
      ranks to be reported. E.g., if precision@1/recall@1 and
      precision@10/recall@10 are desired, this should be set to [1, 10]. The
      largest item should be <= #sorted_index_ids. Default: [1, 5, 10].
    log: Whether to log results using logging.info().

  Returns:
    mAP: (metricsE, metricsM, metricsH) Tuple of the metrics for different
      levels of complexity. Each metrics is a list containing:
      mean_average_precision (float), mean_precisions (NumPy array of
      floats, with shape [len(desired_pr_ranks)]), mean_recalls (NumPy array
      of floats, with shape [len(desired_pr_ranks)]), average_precisions
      (NumPy array of floats, with shape [#queries]), precisions (NumPy array
      of floats, with shape [#queries, len(desired_pr_ranks)]), recalls
      (NumPy array of floats, with shape [#queries, len(desired_pr_ranks)]).

  Raises:
    ValueError: If an unknown dataset name is provided as an argument.
  """
  if dataset_name not in revisited_dataset.DATASET_NAMES:
    # Bug fix: the message previously formatted an undefined name `dataset`,
    # so this path raised NameError instead of the intended ValueError.
    raise ValueError('Unknown dataset: {}!'.format(dataset_name))

  if desired_pr_ranks is None:
    desired_pr_ranks = [1, 5, 10]

  # Split the ground truth into the Easy/Medium/Hard evaluation protocols.
  (easy_ground_truth, medium_ground_truth,
   hard_ground_truth) = revisited_dataset.ParseEasyMediumHardGroundTruth(
       ground_truth)

  metrics_easy = revisited_dataset.ComputeMetrics(sorted_index_ids,
                                                  easy_ground_truth,
                                                  desired_pr_ranks)
  metrics_medium = revisited_dataset.ComputeMetrics(sorted_index_ids,
                                                    medium_ground_truth,
                                                    desired_pr_ranks)
  metrics_hard = revisited_dataset.ComputeMetrics(sorted_index_ids,
                                                  hard_ground_truth,
                                                  desired_pr_ranks)

  # Report mAP and mP@k as percentages, rounded to two decimals.
  debug_and_log(
      '>> {}: mAP E: {}, M: {}, H: {}'.format(
          dataset_name, np.around(metrics_easy[0] * 100, decimals=2),
          np.around(metrics_medium[0] * 100, decimals=2),
          np.around(metrics_hard[0] * 100, decimals=2)),
      log=log)
  debug_and_log(
      '>> {}: mP@k{} E: {}, M: {}, H: {}'.format(
          dataset_name, desired_pr_ranks,
          np.around(metrics_easy[1] * 100, decimals=2),
          np.around(metrics_medium[1] * 100, decimals=2),
          np.around(metrics_hard[1] * 100, decimals=2)),
      log=log)

  return metrics_easy, metrics_medium, metrics_hard
def htime(time_difference):
  """Formats an elapsed time as a compact human-readable string.

  Depending on the magnitude of `time_difference`, the output includes
  days, hours, minutes and seconds components as appropriate.

  Args:
    time_difference: Float, time difference between the two events.

  Returns:
    time: String representing time in an appropriate time format.
  """
  total = round(time_difference)
  # Peel off each unit with divmod (equivalent to floor-division/modulo).
  days, remainder = divmod(total, 86400)
  hours, remainder = divmod(remainder, 3600)
  minutes, seconds = divmod(remainder, 60)
  if days > 0:
    return '{:d}d {:d}h {:d}m {:d}s'.format(days, hours, minutes, seconds)
  if hours > 0:
    return '{:d}h {:d}m {:d}s'.format(hours, minutes, seconds)
  if minutes > 0:
    return '{:d}m {:d}s'.format(minutes, seconds)
  return '{:d}s'.format(seconds)
def debug_and_log(msg, debug=True, log=True, debug_on_the_same_line=False):
  """Outputs `msg` to both stdout (if in the debug mode) and the log file.

  Args:
    msg: String, message to be logged.
    debug: Bool, if True, will print `msg` to stdout.
    log: Bool, if True, will redirect `msg` to the logfile.
    debug_on_the_same_line: Bool, if True, will print `msg` to stdout without
      a new line. When using this mode, logging to a logfile is disabled.
  """
  if debug_on_the_same_line:
    # Same-line mode only echoes to stdout; the logfile is skipped by design.
    print(msg, end='')
  else:
    if debug:
      print(msg)
    if log:
      logging.info(msg)
def get_standard_keras_models():
  """Gets the standard keras model names.

  Scans `tf.keras.applications` for public callables (model constructors
  such as ResNet50, VGG16, ...).

  Returns:
    model_names: List, names of the standard keras models, sorted
      alphabetically.
  """
  applications = tf.keras.applications.__dict__
  return sorted(name for name in applications
                if not name.startswith('__') and callable(applications[name]))
def create_model_directory(training_dataset, arch, pool, whitening, pretrained,
                           loss, loss_margin, optimizer, lr, weight_decay,
                           neg_num, query_size, pool_size, batch_size,
                           update_every, image_size, directory):
  """Based on the model parameters, creates the model directory.

  If the model directory does not exist, the directory is created.

  Args:
    training_dataset: String, training dataset name.
    arch: String, model architecture.
    pool: String, pooling option.
    whitening: Bool, whether the model is trained with global whitening.
    pretrained: Bool, whether the model is initialized with the precomputed
      weights.
    loss: String, training loss type.
    loss_margin: Float, loss margin.
    optimizer: Sting, used optimizer.
    lr: Float, initial learning rate.
    weight_decay: Float, weight decay.
    neg_num: Integer, Number of negative images per train/val tuple.
    query_size: Integer, number of queries per one training epoch.
    pool_size: Integer, size of the pool for hard negative mining.
    batch_size: Integer, batch size.
    update_every: Integer, frequency of the model weights update.
    image_size: Integer, maximum size of longer image side used for training.
    directory: String, destination where trained network should be saved.

  Returns:
    folder: String, path to the model folder.
  """
  # Encode the main hyper-parameters into the folder name so runs with
  # different settings never collide.
  folder = '{}_{}_{}'.format(training_dataset, arch, pool)
  if whitening:
    folder += '_whiten'
  if not pretrained:
    folder += '_notpretrained'
  folder += ('_{}_m{:.2f}_{}_lr{:.1e}_wd{:.1e}_nnum{}_qsize{}_psize{}_bsize{}'
             '_uevery{}_imsize{}').format(loss, loss_margin, optimizer, lr,
                                          weight_decay, neg_num, query_size,
                                          pool_size, batch_size, update_every,
                                          image_size)
  folder = os.path.join(directory, folder)
  debug_and_log(
      '>> Creating directory if does not exist:\n>> \'{}\''.format(folder))
  # Bug fix: `exist_ok=True` replaces the former exists()+makedirs() pair,
  # which could race with another process creating the directory in between.
  os.makedirs(folder, exist_ok=True)
  return folder
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CNN Image Retrieval model implementation based on the following papers:
[1] Fine-tuning CNN Image Retrieval with No Human Annotation,
Radenović F., Tolias G., Chum O., TPAMI 2018 [arXiv]
https://arxiv.org/abs/1711.02512
[2] CNN Image Retrieval Learns from BoW: Unsupervised Fine-Tuning with Hard
Examples, Radenović F., Tolias G., Chum O., ECCV 2016 [arXiv]
https://arxiv.org/abs/1604.02426
"""
import os
import pickle
import tensorflow as tf
from delf.python.datasets import generic_dataset
from delf.python.normalization_layers import normalization
from delf.python.pooling_layers import pooling as pooling_layers
from delf.python.training import global_features_utils
# Pre-computed global whitening, for most commonly used architectures.
# Using pre-computed whitening improves the speed of the convergence and the
# performance. Keys are backbone names; values are URLs of pickled
# tf.keras.layers.Dense configurations learned on SfM-120k.
_WHITENING_CONFIG = {
    'ResNet50': 'http://cmp.felk.cvut.cz/cnnimageretrieval_tf'
                '/SFM120k_ResNet50_gem_learned_whitening_config.pkl',
    'ResNet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval_tf'
                 '/SFM120k_ResNet101_gem_learned_whitening_config.pkl',
    'ResNet152': 'http://cmp.felk.cvut.cz/cnnimageretrieval_tf'
                 '/SFM120k_ResNet152_gem_learned_whitening_config.pkl',
    'VGG19': 'http://cmp.felk.cvut.cz/cnnimageretrieval_tf'
             '/SFM120k_VGG19_gem_learned_whitening_config.pkl'
}

# Possible global pooling layers: maximum activation (MAC), sum-pooling
# (SPoC) and generalized mean (GeM).
_POOLING = {
    'mac': pooling_layers.MAC,
    'spoc': pooling_layers.SPoC,
    'gem': pooling_layers.GeM
}

# Output dimensionality for supported architectures (number of channels of
# the last convolutional feature map; becomes the global descriptor size).
_OUTPUT_DIM = {
    'VGG16': 512,
    'VGG19': 512,
    'ResNet50': 2048,
    'ResNet101': 2048,
    'ResNet101V2': 2048,
    'ResNet152': 2048,
    'DenseNet121': 1024,
    'DenseNet169': 1664,
    'DenseNet201': 1920,
    'EfficientNetB5': 2048,
    'EfficientNetB7': 2560
}
class GlobalFeatureNet(tf.keras.Model):
  """Instantiates global model for image retrieval.

  This class implements the [GlobalFeatureNet](
  https://arxiv.org/abs/1711.02512) for image retrieval. The model uses a
  user-defined model as a backbone, followed by a global pooling layer,
  L2 normalization, and an optional learned whitening layer.
  """

  def __init__(self, architecture='ResNet101', pooling='gem',
               whitening=False, pretrained=True, data_root=''):
    """GlobalFeatureNet network initialization.

    Args:
      architecture: Network backbone; must be a key of `_OUTPUT_DIM`.
      pooling: Pooling method used 'mac'/'spoc'/'gem'.
      whitening: Bool, whether to use whitening.
      pretrained: Bool, whether to initialize the network with the weights
        pretrained on ImageNet.
      data_root: String, path to the data folder where the precomputed
        whitening is/will be saved in case `whitening` is True.

    Raises:
      ValueError: If `architecture` is not supported.
    """
    if architecture not in _OUTPUT_DIM.keys():
      raise ValueError("Architecture {} is not supported.".format(architecture))

    super(GlobalFeatureNet, self).__init__()

    # Get standard output dimensionality size.
    dim = _OUTPUT_DIM[architecture]

    if pretrained:
      # Initialize with network pretrained on imagenet.
      net_in = getattr(tf.keras.applications, architecture)(include_top=False,
                                                            weights="imagenet")
    else:
      # Initialize with random weights.
      net_in = getattr(tf.keras.applications, architecture)(include_top=False,
                                                            weights=None)

    # Initialize `feature_extractor`. Take only convolutions for
    # `feature_extractor`, always end with ReLU to make last activations
    # non-negative.
    if architecture.lower().startswith('densenet'):
      # DenseNet backbones are wrapped into a Sequential so a ReLU can be
      # appended (their final activations are otherwise unbounded below).
      tmp_model = tf.keras.Sequential()
      tmp_model.add(net_in)
      net_in = tmp_model
      net_in.add(tf.keras.layers.ReLU())

    # Initialize pooling.
    self.pool = _POOLING[pooling]()

    # Initialize whitening.
    if whitening:
      if pretrained and architecture in _WHITENING_CONFIG:
        # If precomputed whitening for the architecture exists,
        # the fully-connected layer is going to be initialized according to
        # the precomputed layer configuration.
        global_features_utils.debug_and_log(
            ">> {}: for '{}' custom computed whitening '{}' is used."
            .format(os.getcwd(), architecture,
                    os.path.basename(_WHITENING_CONFIG[architecture])))
        # The layer configuration is downloaded to the `data_root` folder.
        whiten_dir = os.path.join(data_root, architecture)
        path = tf.keras.utils.get_file(fname=whiten_dir,
                                       origin=_WHITENING_CONFIG[architecture])
        # Whitening configuration is loaded.
        # NOTE(review): pickle.load on a downloaded file can execute
        # arbitrary code if the origin is compromised; the URL is a trusted
        # upstream host here.
        with tf.io.gfile.GFile(path, 'rb') as learned_whitening_file:
          whitening_config = pickle.load(learned_whitening_file)
        # Whitening layer is initialized according to the configuration.
        self.whiten = tf.keras.layers.Dense.from_config(whitening_config)
      else:
        # In case if no precomputed whitening exists for the chosen
        # architecture, the fully-connected whitening layer is initialized
        # with the random weights.
        self.whiten = tf.keras.layers.Dense(dim, activation=None, use_bias=True)
        global_features_utils.debug_and_log(
            ">> There is either no whitening computed for the "
            "used network architecture or pretrained is False,"
            " random weights are used.")
    else:
      self.whiten = None

    # Create meta information to be stored in the network.
    self.meta = {
        'architecture': architecture,
        'pooling': pooling,
        'whitening': whitening,
        'outputdim': dim
    }

    self.feature_extractor = net_in
    self.normalize = normalization.L2Normalization()

  def call(self, x, training=False):
    """Invokes the GlobalFeatureNet instance.

    Args:
      x: [B, H, W, C] Tensor with a batch of images.
      training: Indicator of whether the forward pass is running in training
        mode or not.

    Returns:
      out: [B, out_dim] Global descriptor.
    """
    # Forward pass through the fully-convolutional backbone.
    o = self.feature_extractor(x, training)
    # Pooling.
    o = self.pool(o)
    # Normalization.
    o = self.normalize(o)
    # If whitening exists: the pooled global descriptor is whitened and
    # re-normalized.
    if self.whiten is not None:
      o = self.whiten(o)
      o = self.normalize(o)
    return o

  def meta_repr(self):
    """Provides high-level information about the network.

    Returns:
      meta: string with the information about the network (used
        architecture, pooling type, whitening, outputdim).
    """
    tmpstr = '(meta):\n'
    tmpstr += '\tarchitecture: {}\n'.format(self.meta['architecture'])
    tmpstr += '\tpooling: {}\n'.format(self.meta['pooling'])
    tmpstr += '\twhitening: {}\n'.format(self.meta['whitening'])
    tmpstr += '\toutputdim: {}\n'.format(self.meta['outputdim'])
    return tmpstr
def extract_global_descriptors_from_list(net, images, image_size,
                                         bounding_boxes=None, scales=None,
                                         multi_scale_power=1., print_freq=10):
  """Extracting global descriptors from a list of images.

  Args:
    net: Model object, network for the forward pass.
    images: Absolute image paths as strings.
    image_size: Integer, defines the maximum size of longer image side.
    bounding_boxes: List of (x1,y1,x2,y2) tuples to crop the query images.
    scales: List of float scales. Defaults to [1.] (single-scale).
    multi_scale_power: Float, multi-scale normalization power parameter.
    print_freq: Printing frequency for debugging.

  Returns:
    descriptors: Global descriptors for the input images.
  """
  # Bug fix: the former `scales=[1.]` was a mutable default argument; use a
  # None sentinel instead (behavior is unchanged for all callers).
  if scales is None:
    scales = [1.]

  # Creating dataset loader.
  data = generic_dataset.ImagesFromList(root='', image_paths=images,
                                        imsize=image_size,
                                        bounding_boxes=bounding_boxes)

  def _data_gen():
    return (inst for inst in data)

  loader = tf.data.Dataset.from_generator(_data_gen, output_types=(tf.float32))
  # Batch size 1: images may have different spatial dimensions.
  loader = loader.batch(1)

  # Extracting vectors.
  descriptors = tf.zeros((0, net.meta['outputdim']))
  # `input_batch` renamed from `input` to avoid shadowing the builtin.
  for i, input_batch in enumerate(loader):
    if len(scales) == 1 and scales[0] == 1:
      # Single-scale fast path: one forward pass at the original resolution.
      descriptors = tf.concat([descriptors, net(input_batch)], 0)
    else:
      descriptors = tf.concat(
          [descriptors, extract_multi_scale_descriptor(
              net, input_batch, scales, multi_scale_power)], 0)

    if (i + 1) % print_freq == 0 or (i + 1) == len(images):
      global_features_utils.debug_and_log(
          '\r>>>> {}/{} done...'.format((i + 1), len(images)),
          debug_on_the_same_line=True)
  global_features_utils.debug_and_log('', log=False)

  # Transpose to the [outputdim, #images] layout expected by callers.
  descriptors = tf.transpose(descriptors, perm=[1, 0])
  return descriptors
def extract_multi_scale_descriptor(net, input, scales, multi_scale_power):
  """Extracts the global descriptor multi scale.

  Runs the network at every scale in `scales`, fuses the per-scale
  descriptors with a generalized mean of power `multi_scale_power`, and
  L2-normalizes the result.

  Args:
    net: Model object, network for the forward pass.
    input: [B, H, W, C] input tensor in channel-last (BHWC) configuration.
    scales: List of float scales.
    multi_scale_power: Float, multi-scale normalization power parameter.

  Returns:
    descriptors: Multi-scale global descriptors for the input images.
  """
  descriptors = tf.zeros(net.meta['outputdim'])

  for s in scales:
    if s == 1:
      # Scale 1 reuses the original input unchanged.
      input_t = input
    else:
      # NOTE(review): `.numpy()` assumes eager execution — this would fail
      # inside a traced tf.function; confirm callers run eagerly.
      output_shape = s * tf.shape(input)[1:3].numpy()
      input_t = tf.image.resize(input, output_shape,
                                method='bilinear',
                                preserve_aspect_ratio=True)
    # Accumulate the p-th power of each per-scale descriptor.
    descriptors += tf.pow(net(input_t), multi_scale_power)

  # Generalized mean across scales: mean of p-th powers, then the 1/p root.
  descriptors /= len(scales)
  descriptors = tf.pow(descriptors, 1. / multi_scale_power)
  # Final L2 re-normalization of the fused descriptor.
  descriptors /= tf.norm(descriptors)

  return descriptors
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the GlobalFeatureNet backbone."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
import numpy as np
from PIL import Image
import tensorflow as tf
from delf.python.training.model import global_model
# Global absl flag holder; the tests below read FLAGS.test_tmpdir for
# scratch image files.
FLAGS = flags.FLAGS
class GlobalFeatureNetTest(tf.test.TestCase):
  """Tests for the GlobalFeatureNet backbone."""

  def testInitModel(self):
    """Testing GlobalFeatureNet initialization."""
    # Testing GlobalFeatureNet initialization.
    model_params = {'architecture': 'ResNet101', 'pooling': 'gem',
                    'whitening': False, 'pretrained': True}
    model = global_model.GlobalFeatureNet(**model_params)
    # The meta dict must echo the constructor arguments plus the backbone's
    # output dimensionality (2048 for ResNet101).
    expected_meta = {'architecture': 'ResNet101', 'pooling': 'gem',
                     'whitening': False, 'outputdim': 2048}
    self.assertEqual(expected_meta, model.meta)

  def testExtractVectors(self):
    """Tests extraction of global descriptors from list."""
    # Initializing network for testing.
    model_params = {'architecture': 'ResNet101', 'pooling': 'gem',
                    'whitening': False, 'pretrained': True}
    model = global_model.GlobalFeatureNet(**model_params)

    # Number of images to be created.
    n = 2
    image_paths = []
    # Create `n` dummy images.
    for i in range(n):
      dummy_image = np.random.rand(1024, 750, 3) * 255
      img_out = Image.fromarray(dummy_image.astype('uint8')).convert('RGB')
      filename = os.path.join(FLAGS.test_tmpdir, 'test_image_{}.jpg'.format(i))
      img_out.save(filename)
      image_paths.append(filename)

    descriptors = global_model.extract_global_descriptors_from_list(
        model, image_paths, image_size=1024, bounding_boxes=None,
        scales=[1., 3.], multi_scale_power=2, print_freq=1)
    # Descriptors from a list come back transposed: [outputdim, #images].
    self.assertAllEqual([2048, 2], tf.shape(descriptors))

  def testExtractMultiScale(self):
    """Tests multi-scale global descriptor extraction."""
    # Initializing network for testing.
    model_params = {'architecture': 'ResNet101', 'pooling': 'gem',
                    'whitening': False, 'pretrained': True}
    model = global_model.GlobalFeatureNet(**model_params)

    input = tf.random.uniform([2, 1024, 750, 3], dtype=tf.float32, seed=0)
    descriptors = global_model.extract_multi_scale_descriptor(
        model, input, scales=[1., 3.], multi_scale_power=2)
    # Multi-scale extraction keeps the batch-major [#images, outputdim] shape.
    self.assertAllEqual([2, 2048], tf.shape(descriptors))
# Script entry point: run the TensorFlow test runner when executed directly.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment