Commit 78c43ef1 authored by Gunho Park

Merge branch 'master' of https://github.com/tensorflow/models

parents 67cfc95b e3c7e300
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# import io
import os
import random
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.serving import video_classification
class VideoClassificationTest(tf.test.TestCase, parameterized.TestCase):
def _get_classification_module(self):
params = exp_factory.get_exp_config('video_classification_ucf101')
params.task.train_data.feature_shape = (8, 64, 64, 3)
params.task.validation_data.feature_shape = (8, 64, 64, 3)
params.task.model.backbone.resnet_3d.model_id = 50
classification_module = video_classification.VideoClassificationModule(
params, batch_size=1, input_image_size=[8, 64, 64])
return classification_module
def _export_from_module(self, module, input_type, save_directory):
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
tf.saved_model.save(module, save_directory, signatures=signatures)
def _get_dummy_input(self, input_type, module=None):
"""Get dummy input for the given input type."""
if input_type == 'image_tensor':
images = np.random.randint(
low=0, high=255, size=(1, 8, 64, 64, 3), dtype=np.uint8)
# images = np.zeros((1, 8, 64, 64, 3), dtype=np.uint8)
return images, images
elif input_type == 'tf_example':
example = tfexample_utils.make_video_test_example(
image_shape=(64, 64, 3),
audio_shape=(20, 128),
label=random.randint(0, 100)).SerializeToString()
images = tf.nest.map_structure(
tf.stop_gradient,
tf.map_fn(
module._decode_tf_example,
elems=tf.constant([example]),
fn_output_signature={
video_classification.video_input.IMAGE_KEY: tf.string,
}))
images = images[video_classification.video_input.IMAGE_KEY]
return [example], images
else:
raise ValueError(f'Unsupported input_type: {input_type}')
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'tf_example'},
)
def test_export(self, input_type):
tmp_dir = self.get_temp_dir()
module = self._get_classification_module()
self._export_from_module(module, input_type, tmp_dir)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(
os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(
os.path.exists(
os.path.join(tmp_dir, 'variables',
'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir)
classification_fn = imported.signatures['serving_default']
images, images_tensor = self._get_dummy_input(input_type, module)
processed_images = tf.nest.map_structure(
tf.stop_gradient,
tf.map_fn(
module._preprocess_image,
elems=images_tensor,
fn_output_signature={
'image': tf.float32,
}))
expected_logits = module.model(processed_images, training=False)
expected_prob = tf.nn.softmax(expected_logits)
out = classification_fn(tf.constant(images))
# Outputs from the serving signature should match those computed directly
# with the original model.
self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())
if __name__ == '__main__':
tf.test.main()
@@ -24,7 +24,7 @@ from official.modeling import tf_utils
 from official.vision.beta.configs import image_classification as exp_cfg
 from official.vision.beta.dataloaders import classification_input
 from official.vision.beta.dataloaders import input_reader_factory
-from official.vision.beta.dataloaders import tfds_classification_decoders
+from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.modeling import factory
@@ -89,11 +89,7 @@ class ImageClassificationTask(base_task.Task):
     is_multilabel = self.task_config.train_data.is_multilabel
     if params.tfds_name:
-      if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP:
-        decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[
-            params.tfds_name]()
-      else:
-        raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
+      decoder = tfds_factory.get_classification_decoder(params.tfds_name)
     else:
       decoder = classification_input.Decoder(
           image_field_key=image_field_key, label_field_key=label_field_key,
@@ -25,7 +25,7 @@ from official.vision.beta.configs import retinanet as exp_cfg
 from official.vision.beta.dataloaders import input_reader_factory
 from official.vision.beta.dataloaders import retinanet_input
 from official.vision.beta.dataloaders import tf_example_decoder
-from official.vision.beta.dataloaders import tfds_detection_decoders
+from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.dataloaders import tf_example_label_map_decoder
 from official.vision.beta.evaluation import coco_evaluator
 from official.vision.beta.modeling import factory
@@ -90,11 +90,7 @@ class RetinaNetTask(base_task.Task):
     """Build input dataset."""
     if params.tfds_name:
-      if params.tfds_name in tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP:
-        decoder = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP[
-            params.tfds_name]()
-      else:
-        raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
+      decoder = tfds_factory.get_detection_decoder(params.tfds_name)
     else:
       decoder_cfg = params.decoder.get()
       if params.decoder.type == 'simple_decoder':
@@ -23,7 +23,7 @@ from official.core import task_factory
 from official.vision.beta.configs import semantic_segmentation as exp_cfg
 from official.vision.beta.dataloaders import input_reader_factory
 from official.vision.beta.dataloaders import segmentation_input
-from official.vision.beta.dataloaders import tfds_segmentation_decoders
+from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.evaluation import segmentation_metrics
 from official.vision.beta.losses import segmentation_losses
 from official.vision.beta.modeling import factory
@@ -87,11 +87,7 @@ class SemanticSegmentationTask(base_task.Task):
     ignore_label = self.task_config.losses.ignore_label
     if params.tfds_name:
-      if params.tfds_name in tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP:
-        decoder = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP[
-            params.tfds_name]()
-      else:
-        raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
+      decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
     else:
       decoder = segmentation_input.Decoder()
@@ -66,4 +66,5 @@ def main(_):
 if __name__ == '__main__':
   tfm_flags.define_flags()
+  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
   app.run(main)
@@ -14,6 +14,7 @@
 # Lint as: python3
 """TensorFlow Model Garden Vision training driver with spatial partitioning."""
+from typing import Sequence
 from absl import app
 from absl import flags
@@ -33,19 +34,34 @@ from official.modeling import performance
 FLAGS = flags.FLAGS
-def get_computation_shape_for_model_parallelism(input_partition_dims):
-  """Return computation shape to be used for TPUStrategy spatial partition."""
+def get_computation_shape_for_model_parallelism(
+    input_partition_dims: Sequence[int]) -> Sequence[int]:
+  """Returns computation shape to be used for TPUStrategy spatial partition.
+
+  Args:
+    input_partition_dims: The number of partitions along each dimension.
+
+  Returns:
+    A list of integers specifying the computation shape.
+
+  Raises:
+    ValueError: If the number of logical devices is not supported.
+  """
   num_logical_devices = np.prod(input_partition_dims)
   if num_logical_devices == 1:
     return [1, 1, 1, 1]
-  if num_logical_devices == 2:
+  elif num_logical_devices == 2:
     return [1, 1, 1, 2]
-  if num_logical_devices == 4:
+  elif num_logical_devices == 4:
     return [1, 2, 1, 2]
-  if num_logical_devices == 8:
+  elif num_logical_devices == 8:
     return [2, 2, 1, 2]
-  if num_logical_devices == 16:
+  elif num_logical_devices == 16:
     return [4, 2, 1, 2]
+  else:
+    raise ValueError(
+        'The number of logical devices %d is not supported. Supported numbers '
+        'are 1, 2, 4, 8, 16' % num_logical_devices)
 def create_distribution_strategy(distribution_strategy,
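For context, the computation shape produced above is what feeds a TPU device assignment when spatial partitioning is enabled. The following is a minimal, hypothetical sketch of that wiring; the TPU address, partition dims, and replica arithmetic are placeholders and are not part of this commit.

# Hypothetical sketch only: passing a computation shape like the one above to
# a TPUStrategy device assignment. All concrete values are placeholders.
import numpy as np
import tensorflow as tf

input_partition_dims = [1, 2, 2, 1]  # example partitioning of the input
computation_shape = get_computation_shape_for_model_parallelism(
    input_partition_dims)  # 4 logical devices -> [1, 2, 1, 2]

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')  # placeholder address
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
num_replicas = (topology.num_tasks * topology.num_tpus_per_task //
                int(np.prod(computation_shape)))
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, computation_shape=computation_shape, num_replicas=num_replicas)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)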
# Image Classification

**Warning:** the features in the `image_classification/` folder have been fully
integrated into vision/beta. Please use the [new code base](../beta/README.md).

This folder contains TF 2.0 model examples for image classification:

* [MNIST](#mnist)
@@ -132,6 +132,9 @@ class IouSimilarity:
     Output shape:
       [M, N], or [B, M, N]
     """
+    boxes_1 = tf.cast(boxes_1, tf.float32)
+    boxes_2 = tf.cast(boxes_2, tf.float32)
+
     boxes_1_rank = len(boxes_1.shape)
     boxes_2_rank = len(boxes_2.shape)
     if boxes_1_rank < 2 or boxes_1_rank > 3:
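The two added casts matter because box tensors often arrive as integer pixel coordinates. Below is a standalone pairwise-IoU sketch (not the library implementation) that shows where float32 inputs are required.

# Minimal pairwise-IoU sketch (not the library implementation) illustrating
# why casting boxes to tf.float32 up front matters: the area arithmetic and
# tf.math.divide_no_nan below expect floating-point inputs.
import tensorflow as tf

def pairwise_iou(boxes_1, boxes_2):
  """boxes_*: [M, 4] / [N, 4] in (ymin, xmin, ymax, xmax); returns [M, N]."""
  boxes_1 = tf.cast(boxes_1, tf.float32)
  boxes_2 = tf.cast(boxes_2, tf.float32)
  ymin1, xmin1, ymax1, xmax1 = tf.split(boxes_1, 4, axis=-1)
  ymin2, xmin2, ymax2, xmax2 = tf.split(boxes_2, 4, axis=-1)
  inter_h = tf.maximum(0.0, tf.minimum(ymax1, tf.transpose(ymax2)) -
                       tf.maximum(ymin1, tf.transpose(ymin2)))
  inter_w = tf.maximum(0.0, tf.minimum(xmax1, tf.transpose(xmax2)) -
                       tf.maximum(xmin1, tf.transpose(xmin2)))
  intersection = inter_h * inter_w
  area_1 = (ymax1 - ymin1) * (xmax1 - xmin1)
  area_2 = (ymax2 - ymin2) * (xmax2 - xmin2)
  union = area_1 + tf.transpose(area_2) - intersection
  return tf.math.divide_no_nan(intersection, union)

# Integer pixel coordinates work once the cast is in place (IoU ~ 0.143 here).
print(pairwise_iou(tf.constant([[0, 0, 10, 10]]), tf.constant([[5, 5, 15, 15]])))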
@@ -14,24 +14,32 @@
 """Provides the `ExportSavedModel` action and associated helper classes."""
+import re
+
 from typing import Callable, Optional
 import tensorflow as tf
+
+def _id_key(filename):
+  _, id_num = filename.rsplit('-', maxsplit=1)
+  return int(id_num)
+
+
+def _find_managed_files(base_name):
+  r"""Returns all files matching '{base_name}-\d+', in sorted order."""
+  managed_file_regex = re.compile(rf'{re.escape(base_name)}-\d+$')
+  filenames = tf.io.gfile.glob(f'{base_name}-*')
+  filenames = filter(managed_file_regex.match, filenames)
+  return sorted(filenames, key=_id_key)
+
+
 class _CounterIdFn:
   """Implements a counter-based ID function for `ExportFileManager`."""

   def __init__(self, base_name: str):
-    filenames = tf.io.gfile.glob(f'{base_name}-*')
-    max_counter = -1
-    for filename in filenames:
-      try:
-        _, file_number = filename.rsplit('-', maxsplit=1)
-        max_counter = max(max_counter, int(file_number))
-      except ValueError:
-        continue
-    self.value = max_counter + 1
+    managed_files = _find_managed_files(base_name)
+    self.value = _id_key(managed_files[-1]) + 1 if managed_files else 0

   def __call__(self):
     output = self.value
@@ -82,13 +90,7 @@ class ExportFileManager:
     `ExportFileManager` instance, sorted in increasing integer order of the
     IDs returned by `next_id_fn`.
     """
-
-    def id_key(name):
-      _, id_num = name.rsplit('-', maxsplit=1)
-      return int(id_num)
-
-    filenames = tf.io.gfile.glob(f'{self._base_name}-*')
-    return sorted(filenames, key=id_key)
+    return _find_managed_files(self._base_name)

   def clean_up(self):
     """Cleans up old files matching `{base_name}-*`.
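A small illustrative sketch of the new `_find_managed_files` ordering rule, run on in-memory names rather than a `tf.io.gfile.glob` result (the export prefix is a placeholder):

# Illustrative only: the regex keeps names of the form '{base_name}-<digits>'
# and sorts them numerically by the trailing id, exactly as in the diff above.
import re

def _id_key(filename):
  _, id_num = filename.rsplit('-', maxsplit=1)
  return int(id_num)

base_name = '/tmp/export/basename'  # hypothetical export prefix
candidates = [
    f'{base_name}-1000', f'{base_name}-5', f'{base_name}-10-suffix',
    f'{base_name}-9', f'{base_name}-50']
managed_file_regex = re.compile(rf'{re.escape(base_name)}-\d+$')
managed = sorted(filter(managed_file_regex.match, candidates), key=_id_key)
print(managed)  # ...-5, ...-9, ...-50, ...-1000; the '-10-suffix' name is ignored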
@@ -105,6 +105,23 @@ class ExportSavedModelTest(tf.test.TestCase):
         _id_sorted_file_base_names(directory.full_path),
         ['basename-200', 'basename-1000'])

+  def test_export_file_manager_managed_files(self):
+    directory = self.create_tempdir()
+    directory.create_file('basename-5')
+    directory.create_file('basename-10')
+    directory.create_file('basename-50')
+    directory.create_file('basename-1000')
+    directory.create_file('basename-9')
+    directory.create_file('basename-10-suffix')
+    base_name = os.path.join(directory.full_path, 'basename')
+    manager = actions.ExportFileManager(base_name, max_to_keep=3)
+    self.assertLen(manager.managed_files, 5)
+    self.assertEqual(manager.next_name(), f'{base_name}-1001')
+    manager.clean_up()
+    self.assertEqual(
+        manager.managed_files,
+        [f'{base_name}-10', f'{base_name}-50', f'{base_name}-1000'])
+
   def test_export_saved_model(self):
     directory = self.create_tempdir()
     base_name = os.path.join(directory.full_path, 'basename')
# Lint as: python3
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for generic image dataset creation."""
import os
from delf.python.datasets import utils
class ImagesFromList():
"""A generic data loader that loads images from a list.
Supports images of different sizes.
"""
def __init__(self, root, image_paths, imsize=None, bounding_boxes=None,
loader=utils.default_loader):
"""ImagesFromList object initialization.
Args:
root: String, root directory path.
image_paths: List, relative image paths as strings.
imsize: Integer, defines the maximum size of longer image side.
bounding_boxes: List of (x1,y1,x2,y2) tuples to crop the query images.
loader: Callable, a function to load an image given its path.
Raises:
ValueError: Raised if `image_paths` list is empty.
"""
# List of the full image filenames.
images_filenames = [os.path.join(root, image_path) for image_path in
image_paths]
if not images_filenames:
raise ValueError("Dataset contains 0 images.")
self.root = root
self.images = image_paths
self.imsize = imsize
self.images_filenames = images_filenames
self.bounding_boxes = bounding_boxes
self.loader = loader
def __getitem__(self, index):
"""Called to load an image at the given `index`.
Args:
index: Integer, image index.
Returns:
image: Tensor, loaded image.
"""
path = self.images_filenames[index]
if self.bounding_boxes is not None:
img = self.loader(path, self.imsize, self.bounding_boxes[index])
else:
img = self.loader(path, self.imsize)
return img
def __len__(self):
"""Implements the built-in function len().
Returns:
len: Number of images in the dataset.
"""
return len(self.images_filenames)
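A brief, hypothetical usage sketch for `ImagesFromList`; the directory and file names below are placeholders, and loading goes through `utils.default_loader` as imported above.

# Hypothetical usage sketch; the root directory and relative paths are
# placeholders and must point to real images before indexing the dataset.
from delf.python.datasets import generic_dataset

dataset = generic_dataset.ImagesFromList(
    root='/tmp/images',                          # placeholder root directory
    image_paths=['query_0.jpg', 'query_1.jpg'],  # placeholder relative paths
    imsize=1024)                                 # cap the longer image side at 1024

print(len(dataset))       # number of images in the list
first_image = dataset[0]  # image tensor loaded via the default loader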
# Lint as: python3
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for generic dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
import numpy as np
from PIL import Image
import tensorflow as tf
from delf.python.datasets import generic_dataset
FLAGS = flags.FLAGS
class GenericDatasetTest(tf.test.TestCase):
"""Test functions for generic dataset."""
def testGenericDataset(self):
"""Tests loading dummy images from list."""
# Number of images to be created.
n = 2
image_names = []
# Create and save `n` dummy images.
for i in range(n):
dummy_image = np.random.rand(1024, 750, 3) * 255
img_out = Image.fromarray(dummy_image.astype('uint8')).convert('RGB')
filename = os.path.join(FLAGS.test_tmpdir,
'test_image_{}.jpg'.format(i))
img_out.save(filename)
image_names.append('test_image_{}.jpg'.format(i))
data = generic_dataset.ImagesFromList(root=FLAGS.test_tmpdir,
image_paths=image_names,
imsize=1024)
self.assertLen(data, n)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module exposing Sfm120k dataset for training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from delf.python.datasets.sfm120k import sfm120k
# pylint: enable=unused-import
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Structure-from-Motion dataset (Sfm120k) download function."""
import os
import tensorflow as tf
def download_train(data_dir):
"""Checks, and, if required, downloads the necessary files for the training.
Checks if the data necessary for running the example training script exist.
If not, it downloads them into the following folder structure:
DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db
files.
DATA_ROOT/train/retrieval-SfM-30k/ : folder with rsfm30k images and db
files.
"""
# Create data folder if does not exist.
if not tf.io.gfile.exists(data_dir):
tf.io.gfile.mkdir(data_dir)
# Create datasets folder if does not exist.
datasets_dir = os.path.join(data_dir, 'train')
if not tf.io.gfile.exists(datasets_dir):
tf.io.gfile.mkdir(datasets_dir)
# Download folder train/retrieval-SfM-120k/.
src_dir = 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/train/ims'
dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
download_file = 'ims.tar.gz'
if not tf.io.gfile.exists(dst_dir):
src_file = os.path.join(src_dir, download_file)
dst_file = os.path.join(dst_dir, download_file)
print('>> Image directory does not exist. Creating: {}'.format(dst_dir))
tf.io.gfile.makedirs(dst_dir)
print('>> Downloading ims.tar.gz...')
os.system('wget {} -O {}'.format(src_file, dst_file))
print('>> Extracting {}...'.format(dst_file))
os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir))
print('>> Extracted, deleting {}...'.format(dst_file))
os.system('rm {}'.format(dst_file))
# Create symlink for train/retrieval-SfM-30k/.
dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims')
if not (tf.io.gfile.exists(dst_dir) or os.path.islink(dst_dir)):
tf.io.gfile.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k'))
os.system('ln -s {} {}'.format(dst_dir_old, dst_dir))
print(
'>> Created symbolic link from retrieval-SfM-120k/ims to '
'retrieval-SfM-30k/ims')
# Download db files.
src_dir = 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/train/dbs'
datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k']
for dataset in datasets:
dst_dir = os.path.join(datasets_dir, dataset)
if dataset == 'retrieval-SfM-120k':
download_files = ['{}.pkl'.format(dataset),
'{}-whiten.pkl'.format(dataset)]
download_eccv2020 = '{}-val-eccv2020.pkl'.format(dataset)
elif dataset == 'retrieval-SfM-30k':
download_files = ['{}-whiten.pkl'.format(dataset)]
download_eccv2020 = None
if not tf.io.gfile.exists(dst_dir):
print('>> Dataset directory does not exist. Creating: {}'.format(
dst_dir))
tf.io.gfile.mkdir(dst_dir)
for i in range(len(download_files)):
src_file = os.path.join(src_dir, download_files[i])
dst_file = os.path.join(dst_dir, download_files[i])
if not os.path.isfile(dst_file):
print('>> DB file {} does not exist. Downloading...'.format(
download_files[i]))
os.system('wget {} -O {}'.format(src_file, dst_file))
if download_eccv2020:
eccv2020_dst_file = os.path.join(dst_dir, download_eccv2020)
if not os.path.isfile(eccv2020_dst_file):
eccv2020_src_dir = \
"http://ptak.felk.cvut.cz/personal/toliageo/share/how/dataset/"
eccv2020_dst_file = os.path.join(dst_dir, download_eccv2020)
eccv2020_src_file = os.path.join(eccv2020_src_dir,
download_eccv2020)
os.system('wget {} -O {}'.format(eccv2020_src_file,
eccv2020_dst_file))
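For reference, a hypothetical invocation of the download helper; note that it shells out to `wget`, `tar`, and `ln`, so it assumes a POSIX environment with those tools on PATH.

# Hypothetical invocation sketch; the data_dir is a placeholder and the
# download fetches several gigabytes of images and db files.
from delf.python.datasets.sfm120k import dataset_download

data_dir = '/tmp/delf_data'  # placeholder
dataset_download.download_train(data_dir)
# Expected layout afterwards (per the docstring above):
#   /tmp/delf_data/train/retrieval-SfM-120k/ims/ ... plus the *.pkl db files
#   /tmp/delf_data/train/retrieval-SfM-30k/ims -> symlink to the 120k images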
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Structure-from-Motion dataset (Sfm120k) module.
[1] From Single Image Query to Detailed 3D Reconstruction.
Johannes L. Schonberger, Filip Radenovic, Ondrej Chum, Jan-Michael Frahm.
The related paper can be found at: https://ieeexplore.ieee.org/document/7299148.
"""
import os
import pickle
import tensorflow as tf
from delf.python.datasets import tuples_dataset
from delf.python.datasets import utils
def id2filename(image_id, prefix):
"""Creates a training image path out of its id name.
Used for the image mapping in the Sfm120k dataset.
Args:
image_id: String, image id.
prefix: String, root directory where images are saved.
Returns:
filename: String, full image filename.
"""
if prefix:
return os.path.join(prefix, image_id[-2:], image_id[-4:-2], image_id[-6:-4],
image_id)
else:
return os.path.join(image_id[-2:], image_id[-4:-2], image_id[-6:-4],
image_id)
class _Sfm120k(tuples_dataset.TuplesDataset):
"""Structure-from-Motion (Sfm120k) dataset instance.
The dataset contains the image names lists for training and validation,
the cluster ID (3D model ID) for each image and indices forming
query-positive pairs of images. The images are loaded per epoch and resized
on the fly to the desired dimensionality.
"""
def __init__(self, mode, data_root, imsize=None, num_negatives=5,
num_queries=2000, pool_size=20000, loader=utils.default_loader,
eccv2020=False):
"""Structure-from-Motion (Sfm120k) dataset initialization.
Args:
mode: Either 'train' or 'val'.
data_root: Path to the root directory of the dataset.
imsize: Integer, defines the maximum size of longer image side.
num_negatives: Integer, number of negative images per one query.
num_queries: Integer, number of query images.
pool_size: Integer, size of the negative image pool, from where the
hard-negative images are chosen.
loader: Callable, a function to load an image given its path.
eccv2020: Bool, whether to use a new validation dataset used with ECCV
2020 paper (https://arxiv.org/abs/2007.13172).
Raises:
ValueError: Raised if `mode` is not one of 'train' or 'val'.
"""
if mode not in ['train', 'val']:
raise ValueError(
"`mode` argument should be either 'train' or 'val', passed as a "
"String.")
# Setting up the paths for the dataset.
if eccv2020:
name = "retrieval-SfM-120k-val-eccv2020"
else:
name = "retrieval-SfM-120k"
db_root = os.path.join(data_root, 'train/retrieval-SfM-120k')
ims_root = os.path.join(db_root, 'ims/')
# Loading the dataset db file.
db_filename = os.path.join(db_root, '{}.pkl'.format(name))
with tf.io.gfile.GFile(db_filename, 'rb') as f:
db = pickle.load(f)[mode]
# Setting full paths for the dataset images.
self.images = [id2filename(img_name, None) for
img_name in db['cids']]
# Initializing tuples dataset.
super().__init__(name, mode, db_root, imsize, num_negatives, num_queries,
pool_size, loader, ims_root)
def Sfm120kInfo(self):
"""Metadata for the Sfm120k dataset.
The dataset contains the image names lists for training and
validation, the cluster ID (3D model ID) for each image and indices
forming query-positive pairs of images. The images are loaded per epoch
and resized on the fly to the desired dimensionality.
Returns:
info: dictionary with the dataset parameters.
"""
info = {'train': {'clusters': 91642, 'pidxs': 181697, 'qidxs': 181697},
'val': {'clusters': 6403, 'pidxs': 1691, 'qidxs': 1691}}
return info
def CreateDataset(mode, data_root, imsize=None, num_negatives=5,
num_queries=2000, pool_size=20000,
loader=utils.default_loader, eccv2020=False):
'''Creates Structure-from-Motion (Sfm120k) dataset.
Args:
mode: String, either 'train' or 'val'.
data_root: Path to the root directory of the dataset.
imsize: Integer, defines the maximum size of longer image side.
num_negatives: Integer, number of negative images per one query.
num_queries: Integer, number of query images.
pool_size: Integer, size of the negative image pool, from where the
hard-negative images are chosen.
loader: Callable, a function to load an image given its path.
eccv2020: Bool, whether to use a new validation dataset used with ECCV
2020 paper (https://arxiv.org/abs/2007.13172).
Returns:
sfm120k: Sfm120k dataset instance.
'''
return _Sfm120k(mode, data_root, imsize, num_negatives, num_queries,
pool_size, loader, eccv2020)
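A short, hypothetical sketch of the two entry points above: the id-to-path mapping and dataset construction. It assumes the Sfm120k db and image files already exist under `data_root` (e.g. after `dataset_download.download_train`).

# Sketch only; `data_root` is a placeholder and must contain the downloaded
# retrieval-SfM-120k db/image files for CreateDataset to succeed.
from delf.python.datasets.sfm120k import sfm120k

print(sfm120k.id2filename('29fdc243aeb939388cfdf2d081dc080e',
                          'train/retrieval-SfM-120k/ims/'))
# -> train/retrieval-SfM-120k/ims/0e/08/dc/29fdc243aeb939388cfdf2d081dc080e

train_dataset = sfm120k.CreateDataset(
    mode='train',
    data_root='/tmp/delf_data',  # placeholder, see the download sketch above
    imsize=1024,
    num_negatives=5,
    num_queries=2000,
    pool_size=20000)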
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Sfm120k dataset module."""
import tensorflow as tf
from delf.python.datasets.sfm120k import sfm120k
class Sfm120kTest(tf.test.TestCase):
"""Tests for Sfm120k dataset module."""
def testId2Filename(self):
"""Tests conversion of image id to full path mapping."""
image_id = "29fdc243aeb939388cfdf2d081dc080e"
prefix = "train/retrieval-SfM-120k/ims/"
path = sfm120k.id2filename(image_id, prefix)
expected_path = "train/retrieval-SfM-120k/ims/0e/08/dc" \
"/29fdc243aeb939388cfdf2d081dc080e"
self.assertEqual(path, expected_path)
if __name__ == '__main__':
tf.test.main()
# Lint as: python3
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tuple dataset module.
Based on the Radenovic et al. ECCV16: CNN image retrieval learns from BoW.
For more information refer to https://arxiv.org/abs/1604.02426.
"""
import os
import pickle
import numpy as np
import tensorflow as tf
from delf.python.datasets import utils as image_loading_utils
from delf.python.training import global_features_utils
from delf.python.training.model import global_model
class TuplesDataset():
"""Data loader that loads training and validation tuples.
After initialization, create_epoch_tuples() should be called to create the
dataset tuples. After that, the dataset can be iterated over with the next()
function.
Tuples are based on Radenovic et al. ECCV16 work: CNN image retrieval
learns from BoW. For more information refer to
https://arxiv.org/abs/1604.02426.
"""
def __init__(self, name, mode, data_root, imsize=None, num_negatives=5,
num_queries=2000, pool_size=20000,
loader=image_loading_utils.default_loader, ims_root=None):
"""TuplesDataset object initialization.
Args:
name: String, dataset name. I.e. 'retrieval-sfm-120k'.
mode: 'train' or 'val' for training and validation parts of dataset.
data_root: Path to the root directory of the dataset.
imsize: Integer, defines the maximum size of the longer image side.
num_negatives: Integer, number of negative images for a query image in a
training tuple.
num_queries: Integer, number of query images to be processed in one epoch.
pool_size: Integer, size of the negative image pool, from where the
hard-negative images are re-mined.
loader: Callable, a function to load an image given its path.
ims_root: String, image root directory.
Raises:
ValueError: If mode is not either 'train' or 'val'.
"""
if mode not in ['train', 'val']:
raise ValueError(
"`mode` argument should be either 'train' or 'val', passed as a "
"String.")
# Loading db.
db_filename = os.path.join(data_root, '{}.pkl'.format(name))
with tf.io.gfile.GFile(db_filename, 'rb') as f:
db = pickle.load(f)[mode]
# Initializing tuples dataset.
self._ims_root = data_root if ims_root is None else ims_root
self._name = name
self._mode = mode
self._imsize = imsize
self._clusters = db['cluster']
self._query_pool = db['qidxs']
self._positive_pool = db['pidxs']
if not hasattr(self, 'images'):
self.images = db['ids']
# Size of training subset for an epoch.
self._num_negatives = num_negatives
self._num_queries = min(num_queries, len(self._query_pool))
self._pool_size = min(pool_size, len(self.images))
self._qidxs = None
self._pidxs = None
self._nidxs = None
self._loader = loader
self._print_freq = 10
# Indexer for the iterator.
self._n = 0
def __iter__(self):
"""Function for making TupleDataset an iterator.
Returns:
iter: The iterator object itself (TupleDataset).
"""
return self
def __next__(self):
"""Function for making TupleDataset an iterator.
Returns:
next: The next item in the sequence (next dataset image tuple).
"""
if self._n < len(self._qidxs):
result = self.__getitem__(self._n)
self._n += 1
return result
else:
raise StopIteration
def _img_names_to_full_path(self, image_list):
"""Converts list of image names to the list of full paths to the images.
Args:
image_list: Image names, either a list or a single image path.
Returns:
image_full_paths: List of full paths to the images.
"""
if not isinstance(image_list, list):
return os.path.join(self._ims_root, image_list)
return [os.path.join(self._ims_root, img_name) for img_name in image_list]
def __getitem__(self, index):
"""Called to load an image tuple at the given `index`.
Args:
index: Integer, index.
Returns:
output: Tuple [q,p,n1,...,nN, target], loaded 'train'/'val' tuple at
index of qidxs. `q` is the query image tensor, `p` is the
corresponding positive image tensor, `n1`,...,`nN` are the negatives
associated with the query. `target` is a tensor (with the shape [2+N])
of integer labels corresponding to the tuple list: query (-1),
positive (1), negative (0).
Raises:
ValueError: Raised if the query indexes list `qidxs` is empty.
"""
if self.__len__() == 0:
raise ValueError(
"List `qidxs` is empty. Run `dataset.create_epoch_tuples(net)` "
"method to create subset for `train`/`val`.")
output = []
# Query image.
output.append(self._loader(
self._img_names_to_full_path(self.images[self._qidxs[index]]),
self._imsize))
# Positive image.
output.append(self._loader(
self._img_names_to_full_path(self.images[self._pidxs[index]]),
self._imsize))
# Negative images.
for nidx in self._nidxs[index]:
output.append(self._loader(
self._img_names_to_full_path(self.images[nidx]),
self._imsize))
# Labels for the query (-1), positive (1), negative (0) images in the tuple.
target = tf.convert_to_tensor([-1, 1] + [0] * self._num_negatives)
output.append(target)
return tuple(output)
def __len__(self):
"""Called to implement the built-in function len().
Returns:
len: Integer, number of query images.
"""
if self._qidxs is None:
return 0
return len(self._qidxs)
def __repr__(self):
"""Metadata for the TupleDataset.
Returns:
meta: String, containing TupleDataset meta.
"""
fmt_str = self.__class__.__name__ + '\n'
fmt_str += '\tName and mode: {} {}\n'.format(self._name, self._mode)
fmt_str += '\tNumber of images: {}\n'.format(len(self.images))
fmt_str += '\tNumber of training tuples: {}\n'.format(len(self._query_pool))
fmt_str += '\tNumber of negatives per tuple: {}\n'.format(
self._num_negatives)
fmt_str += '\tNumber of tuples processed in an epoch: {}\n'.format(
self._num_queries)
fmt_str += '\tPool size for negative remining: {}\n'.format(self._pool_size)
return fmt_str
def create_epoch_tuples(self, net):
"""Creates epoch tuples with the hard-negative re-mining.
Negative examples are selected from clusters different than the cluster
of the query image, as the clusters are ideally non-overlapping. For
every query image we choose hard-negatives, that is, non-matching images
with the most similar descriptor. Hard-negatives depend on the current
CNN parameters. K-nearest neighbors from all non-matching images are
selected. Query images are selected randomly. Positive examples are
fixed for the related query image during the whole training process.
Args:
net: Model, network to be used for negative re-mining.
Raises:
ValueError: If the pool_size is smaller than the number of negative
images per tuple.
Returns:
avg_l2: Float, average negative L2-distance.
"""
self._n = 0
if self._num_negatives >= self._pool_size:
raise ValueError("Unable to create epoch tuples. Negative pool_size "
"should be larger than the number of negative images "
"per tuple.")
global_features_utils.debug_and_log(
'>> Creating tuples for an epoch of {}-{}...'.format(self._name,
self._mode),
True)
global_features_utils.debug_and_log(">> Used network: ", True)
global_features_utils.debug_and_log(net.meta_repr(), True)
## Selecting queries.
# Draw `num_queries` random queries for the tuples.
idx_list = np.arange(len(self._query_pool))
np.random.shuffle(idx_list)
idxs2query_pool = idx_list[:self._num_queries]
self._qidxs = [self._query_pool[i] for i in idxs2query_pool]
## Selecting positive pairs.
# Positives examples are fixed for each query during the whole training
# process.
self._pidxs = [self._positive_pool[i] for i in idxs2query_pool]
## Selecting negative pairs.
# If `num_negatives` = 0 create dummy nidxs.
# Useful when only positives used for training.
if self._num_negatives == 0:
self._nidxs = [[] for _ in range(len(self._qidxs))]
return 0
# Draw pool_size random images for pool of negatives images.
neg_idx_list = np.arange(len(self.images))
np.random.shuffle(neg_idx_list)
neg_images_idxs = neg_idx_list[:self._pool_size]
global_features_utils.debug_and_log(
'>> Extracting descriptors for query images...', debug=True)
img_list = self._img_names_to_full_path([self.images[i] for i in
self._qidxs])
qvecs = global_model.extract_global_descriptors_from_list(
net,
images=img_list,
image_size=self._imsize,
print_freq=self._print_freq)
global_features_utils.debug_and_log(
'>> Extracting descriptors for negative pool...', debug=True)
poolvecs = global_model.extract_global_descriptors_from_list(
net,
images=self._img_names_to_full_path([self.images[i] for i in
neg_images_idxs]),
image_size=self._imsize,
print_freq=self._print_freq)
global_features_utils.debug_and_log('>> Searching for hard negatives...',
debug=True)
# Compute dot product scores and ranks.
scores = tf.linalg.matmul(poolvecs, qvecs, transpose_a=True)
ranks = tf.argsort(scores, axis=0, direction='DESCENDING')
sum_ndist = 0.
n_ndist = 0.
# Selection of negative examples.
self._nidxs = []
for q, qidx in enumerate(self._qidxs):
# We are not using the query cluster, those images are potentially
# positive.
qcluster = self._clusters[qidx]
clusters = [qcluster]
nidxs = []
rank = 0
while len(nidxs) < self._num_negatives:
if rank >= tf.shape(ranks)[0]:
raise ValueError("Unable to create epoch tuples. Number of required "
"negative images is larger than the number of "
"clusters in the dataset.")
potential = neg_images_idxs[ranks[rank, q]]
# Take at most one image from the same cluster.
if not self._clusters[potential] in clusters:
nidxs.append(potential)
clusters.append(self._clusters[potential])
dist = tf.norm(qvecs[:, q] - poolvecs[:, ranks[rank, q]],
axis=0).numpy()
sum_ndist += dist
n_ndist += 1
rank += 1
self._nidxs.append(nidxs)
global_features_utils.debug_and_log(
'>> Average negative l2-distance: {:.2f}'.format(
sum_ndist / n_ndist))
# Return average negative L2-distance.
return sum_ndist / n_ndist
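To make the protocol from the class docstring concrete, here is a minimal, self-contained sketch: write a `{name}.pkl` db (same keys as the unit test that follows), create the dataset, re-mine hard negatives with a model, then iterate. All paths and values are placeholders, not part of this commit.

# Minimal driving sketch for TuplesDataset; placeholders throughout.
import os
import pickle
import numpy as np
from PIL import Image
import tensorflow as tf
from delf.python.datasets import tuples_dataset
from delf.python.training.model import global_model

data_root = '/tmp/tuples_demo'  # placeholder working directory
tf.io.gfile.makedirs(data_root)
# Tiny db in the format TuplesDataset reads (same keys as the unit test below).
db = {'train': {'ids': ['0.png', '1.png', '2.png', '3.png'],
                'cluster': [0, 0, 1, 2],
                'qidxs': [0], 'pidxs': [1]}}
with tf.io.gfile.GFile(os.path.join(data_root, 'demo.pkl'), 'wb') as f:
  pickle.dump(db, f)
# Matching dummy images so the loader has something to read.
for name in db['train']['ids']:
  Image.fromarray((np.random.rand(256, 192, 3) * 255).astype('uint8')).save(
      os.path.join(data_root, name))

dataset = tuples_dataset.TuplesDataset(
    name='demo', mode='train', data_root=data_root, imsize=256,
    num_negatives=1, num_queries=1, pool_size=4)
net = global_model.GlobalFeatureNet(architecture='ResNet101', pooling='gem',
                                    whitening=False, pretrained=True)
avg_neg_l2 = dataset.create_epoch_tuples(net)  # hard-negative re-mining
for query, positive, negative, target in dataset:  # one tuple per query
  pass  # feed the tuple to the training / validation loop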
# Lint as: python3
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"Tests for the tuples dataset module."
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
import numpy as np
from PIL import Image
import tensorflow as tf
import pickle
from delf.python.datasets import tuples_dataset
from delf.python.training.model import global_model
FLAGS = flags.FLAGS
class TuplesDatasetTest(tf.test.TestCase):
"""Tests for tuples dataset module."""
def testCreateEpochTuples(self):
"""Tests epoch tuple creation."""
# Create a tuples dataset instance.
name = 'test_dataset'
num_queries = 1
pool_size = 5
num_negatives = 2
# Create a ground truth .pkl file.
gnd = {
'train': {'ids': [str(i) + '.png' for i in range(2 * num_queries + pool_size)],
'cluster': [0, 0, 1, 2, 3, 4, 5],
'qidxs': [0], 'pidxs': [1]}}
gnd_name = name + '.pkl'
with tf.io.gfile.GFile(os.path.join(FLAGS.test_tmpdir, gnd_name),
'wb') as gnd_file:
pickle.dump(gnd, gnd_file)
# Create random images for the dataset.
for i in range(2 * num_queries + pool_size):
dummy_image = np.random.rand(1024, 750, 3) * 255
img_out = Image.fromarray(dummy_image.astype('uint8')).convert('RGB')
filename = os.path.join(FLAGS.test_tmpdir, '{}.png'.format(i))
img_out.save(filename)
dataset = tuples_dataset.TuplesDataset(
name=name,
data_root=FLAGS.test_tmpdir,
mode='train',
imsize=1024,
num_negatives=num_negatives,
num_queries=num_queries,
pool_size=pool_size
)
# Assert that initially no negative images are set.
self.assertIsNone(dataset._nidxs)
# Initialize a network for negative re-mining.
model_params = {'architecture': 'ResNet101', 'pooling': 'gem',
'whitening': False, 'pretrained': True}
model = global_model.GlobalFeatureNet(**model_params)
avg_neg_distance = dataset.create_epoch_tuples(model)
# Check that an appropriate number of negative images has been chosen per
# query.
self.assertAllEqual(tf.shape(dataset._nidxs), [num_queries, num_negatives])
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Global model training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training script for Global Features model."""
import math
import os
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from delf.python.datasets.sfm120k import dataset_download
from delf.python.datasets.sfm120k import sfm120k
from delf.python.training import global_features_utils
from delf.python.training import tensorboard_utils
from delf.python.training.global_features import train_utils
from delf.python.training.losses import ranking_losses
from delf.python.training.model import global_model
_LOSS_NAMES = ['contrastive', 'triplet']
_MODEL_NAMES = global_features_utils.get_standard_keras_models()
_OPTIMIZER_NAMES = ['sgd', 'adam']
_POOL_NAMES = ['mac', 'spoc', 'gem']
_PRECOMPUTE_WHITEN_NAMES = ['retrieval-SfM-30k', 'retrieval-SfM-120k']
_TEST_DATASET_NAMES = ['roxford5k', 'rparis6k']
_TRAINING_DATASET_NAMES = ['retrieval-SfM-120k']
_VALIDATION_TYPES = ['standard', 'eccv2020']
FLAGS = flags.FLAGS
flags.DEFINE_boolean('debug', False, 'Debug mode.')
# Export directory, training and val datasets, test datasets.
flags.DEFINE_string('data_root', "data",
'Absolute path to the folder containing training data.')
flags.DEFINE_string('directory', "data",
'Destination where trained network should be saved.')
flags.DEFINE_enum('training_dataset', 'retrieval-SfM-120k',
_TRAINING_DATASET_NAMES, 'Training dataset: ' +
' | '.join(_TRAINING_DATASET_NAMES) + '.')
flags.DEFINE_enum('validation_type', None, _VALIDATION_TYPES,
'Type of the evaluation to use. Either `None`, `standard` '
'or `eccv2020`.')
flags.DEFINE_list('test_datasets', 'roxford5k,rparis6k',
'Comma separated list of test datasets: ' +
' | '.join(_TEST_DATASET_NAMES) + '.')
flags.DEFINE_enum('precompute_whitening', None, _PRECOMPUTE_WHITEN_NAMES,
'Dataset used to learn whitening: ' +
' | '.join(_PRECOMPUTE_WHITEN_NAMES) + '.')
flags.DEFINE_integer('test_freq', 5,
'Run test evaluation every N epochs.')
flags.DEFINE_list('multiscale', [1.],
'Use multiscale vectors for testing, ' +
' examples: 1 | 1,1/2**(1/2),1/2 | 1,2**(1/2),1/2**(1/2). '
'Pass as a string of comma separated values.')
# Network architecture and initialization options.
flags.DEFINE_enum('arch', 'ResNet101', _MODEL_NAMES,
'Model architecture: ' + ' | '.join(_MODEL_NAMES) + '.')
flags.DEFINE_enum('pool', 'gem', _POOL_NAMES,
'Pooling options: ' + ' | '.join(_POOL_NAMES) + '.')
flags.DEFINE_bool('whitening', False,
'Whether to train model with learnable whitening ('
'linear layer) after the pooling.')
flags.DEFINE_bool('pretrained', True,
'Whether to initialize the model with ImageNet-pretrained '
'weights (default) instead of random weights.')
flags.DEFINE_enum('loss', 'contrastive', _LOSS_NAMES,
'Training loss options: ' + ' | '.join(_LOSS_NAMES) + '.')
flags.DEFINE_float('loss_margin', 0.7, 'Loss margin.')
# train/val options specific for image retrieval learning.
flags.DEFINE_integer('image_size', 1024,
'Maximum size of longer image side used for training.')
flags.DEFINE_integer('neg_num', 5, 'Number of negative images per train/val '
'tuple.')
flags.DEFINE_integer('query_size', 2000,
'Number of queries randomly drawn per one training epoch.')
flags.DEFINE_integer('pool_size', 20000,
'Size of the pool for hard negative mining.')
# Standard training/validation options.
flags.DEFINE_string('gpu_id', '0', 'GPU id used for training.')
flags.DEFINE_integer('epochs', 100, 'Number of total epochs to run.')
flags.DEFINE_integer('batch_size', 5,
'Number of (q,p,n1,...,nN) tuples in a mini-batch.')
flags.DEFINE_integer('update_every', 1,
'Update model weights every N batches, used to handle '
'relatively large batches, batch_size effectively '
'becomes update_every `x` batch_size.')
flags.DEFINE_enum('optimizer', 'adam', _OPTIMIZER_NAMES,
'Optimizer options: ' + ' | '.join(_OPTIMIZER_NAMES) + '.')
flags.DEFINE_float('lr', 1e-6, 'Initial learning rate.')
flags.DEFINE_float('momentum', 0.9, 'Momentum.')
flags.DEFINE_float('weight_decay', 1e-6, 'Weight decay.')
flags.DEFINE_bool('resume', False,
'Whether to start from the latest checkpoint in the logdir.')
flags.DEFINE_bool('launch_tensorboard', False, 'Whether to launch tensorboard.')
def main(argv):
if len(argv) > 1:
raise RuntimeError('Too many command-line arguments.')
# Manually check if there are unknown test datasets and if the dataset
# ground truth files are downloaded.
for dataset in FLAGS.test_datasets:
if dataset not in _TEST_DATASET_NAMES:
raise ValueError('Unsupported or unknown test dataset: {}.'.format(
dataset))
test_data_config = os.path.join(FLAGS.data_root,
'gnd_{}.pkl'.format(dataset))
if not tf.io.gfile.exists(test_data_config):
raise ValueError(
'{} ground truth file at {} not found. Please download it '
'according to '
'the DELG instructions.'.format(dataset, FLAGS.data_root))
# Check if train dataset is downloaded and download it if not found.
dataset_download.download_train(FLAGS.data_root)
# Creating model export directory if it does not exist.
model_directory = global_features_utils.create_model_directory(
FLAGS.training_dataset, FLAGS.arch, FLAGS.pool, FLAGS.whitening,
FLAGS.pretrained, FLAGS.loss, FLAGS.loss_margin, FLAGS.optimizer,
FLAGS.lr, FLAGS.weight_decay, FLAGS.neg_num, FLAGS.query_size,
FLAGS.pool_size, FLAGS.batch_size, FLAGS.update_every,
FLAGS.image_size, FLAGS.directory)
# Setting up logging directory, same as where the model is stored.
logging.get_absl_handler().use_absl_log_file('absl_logging', model_directory)
# Set cuda visible device.
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_id
global_features_utils.debug_and_log('>> Num GPUs Available: {}'.format(
len(tf.config.experimental.list_physical_devices('GPU'))),
FLAGS.debug)
# Set random seeds.
tf.random.set_seed(0)
np.random.seed(0)
# Initialize the model.
if FLAGS.pretrained:
global_features_utils.debug_and_log(
'>> Using pre-trained model \'{}\''.format(FLAGS.arch))
else:
global_features_utils.debug_and_log(
'>> Using model from scratch (random weights) \'{}\'.'.format(
FLAGS.arch))
model_params = {'architecture': FLAGS.arch, 'pooling': FLAGS.pool,
'whitening': FLAGS.whitening, 'pretrained': FLAGS.pretrained,
'data_root': FLAGS.data_root}
model = global_model.GlobalFeatureNet(**model_params)
# Freeze running mean and std in batch normalization layers.
# We train one image at a time to reduce the memory requirements of the
# network; the computed statistics would therefore not correspond to a full
# batch. Instead, we freeze them - setting the parameters of all batch norm
# layers in the network to non-trainable (i.e., using the original ImageNet
# statistics).
for layer in model.feature_extractor.layers:
if isinstance(layer, tf.keras.layers.BatchNormalization):
layer.trainable = False
global_features_utils.debug_and_log('>> Network initialized.')
global_features_utils.debug_and_log('>> Loss: {}.'.format(FLAGS.loss))
# Define the loss function.
if FLAGS.loss == 'contrastive':
criterion = ranking_losses.ContrastiveLoss(margin=FLAGS.loss_margin)
elif FLAGS.loss == 'triplet':
criterion = ranking_losses.TripletLoss(margin=FLAGS.loss_margin)
else:
raise ValueError('Loss {} not available.'.format(FLAGS.loss))
# Defining parameters for the training.
# When pre-computing whitening, we run evaluation before the network training
# and the `start_epoch` is set to 0. In other cases, we start from epoch 1.
start_epoch = 1
exp_decay = math.exp(-0.01)
decay_steps = FLAGS.query_size / FLAGS.batch_size
# Define learning rate decay schedule.
lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=FLAGS.lr,
decay_steps=decay_steps,
decay_rate=exp_decay)
# Define the optimizer.
if FLAGS.optimizer == 'sgd':
opt = tfa.optimizers.extend_with_decoupled_weight_decay(
tf.keras.optimizers.SGD)
optimizer = opt(weight_decay=FLAGS.weight_decay,
learning_rate=lr_scheduler, momentum=FLAGS.momentum)
elif FLAGS.optimizer == 'adam':
opt = tfa.optimizers.extend_with_decoupled_weight_decay(
tf.keras.optimizers.Adam)
optimizer = opt(weight_decay=FLAGS.weight_decay, learning_rate=lr_scheduler)
else:
raise ValueError('Optimizer {} not available.'.format(FLAGS.optimizer))
# Initializing logging.
writer = tf.summary.create_file_writer(model_directory)
tf.summary.experimental.set_step(1)
# Setting up the checkpoint manager.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(
checkpoint,
model_directory,
max_to_keep=10,
keep_checkpoint_every_n_hours=3)
if FLAGS.resume:
# Restores the checkpoint, if existing.
global_features_utils.debug_and_log('>> Continuing from a checkpoint.')
checkpoint.restore(manager.latest_checkpoint)
# Launching tensorboard if required.
if FLAGS.launch_tensorboard:
tensorboard = tf.keras.callbacks.TensorBoard(model_directory)
tensorboard.set_model(model=model)
tensorboard_utils.launch_tensorboard(log_dir=model_directory)
# Log flags used.
global_features_utils.debug_and_log('>> Running training script with:')
global_features_utils.debug_and_log('>> logdir = {}'.format(model_directory))
if FLAGS.training_dataset.startswith('retrieval-SfM-120k'):
train_dataset = sfm120k.CreateDataset(
data_root=FLAGS.data_root,
mode='train',
imsize=FLAGS.image_size,
num_negatives=FLAGS.neg_num,
num_queries=FLAGS.query_size,
pool_size=FLAGS.pool_size
)
if FLAGS.validation_type is not None:
val_dataset = sfm120k.CreateDataset(
data_root=FLAGS.data_root,
mode='val',
imsize=FLAGS.image_size,
num_negatives=FLAGS.neg_num,
num_queries=float('Inf'),
pool_size=float('Inf'),
eccv2020=True if FLAGS.validation_type == 'eccv2020' else False
)
train_dataset_output_types = [tf.float32 for i in range(2 + FLAGS.neg_num)]
train_dataset_output_types.append(tf.int32)
global_features_utils.debug_and_log(
'>> Training the {} network'.format(model_directory))
global_features_utils.debug_and_log('>> GPU ids: {}'.format(FLAGS.gpu_id))
with writer.as_default():
# Precompute whitening if needed.
if FLAGS.precompute_whitening is not None:
epoch = 0
train_utils.test_retrieval(
FLAGS.test_datasets, model, writer=writer,
epoch=epoch, model_directory=model_directory,
precompute_whitening=FLAGS.precompute_whitening,
data_root=FLAGS.data_root,
multiscale=FLAGS.multiscale)
for epoch in range(start_epoch, FLAGS.epochs + 1):
# Set manual seeds per epoch.
np.random.seed(epoch)
tf.random.set_seed(epoch)
# Find hard-negatives.
# While hard-positive examples are fixed during the whole training
# process and are randomly chosen from every epoch; hard-negatives
# depend on the current CNN parameters and are re-mined once per epoch.
avg_neg_distance = train_dataset.create_epoch_tuples(model)
def _train_gen():
return (inst for inst in train_dataset)
train_loader = tf.data.Dataset.from_generator(
_train_gen,
output_types=tuple(train_dataset_output_types))
loss = train_utils.train_val_one_epoch(
loader=iter(train_loader), model=model,
criterion=criterion, optimizer=optimizer, epoch=epoch,
batch_size=FLAGS.batch_size, query_size=FLAGS.query_size,
neg_num=FLAGS.neg_num, update_every=FLAGS.update_every,
debug=FLAGS.debug)
# Write a scalar summary.
tf.summary.scalar('train_epoch_loss', loss, step=epoch)
# Forces summary writer to send any buffered data to storage.
writer.flush()
# Evaluate on validation set.
if FLAGS.validation_type is not None and (epoch % FLAGS.test_freq == 0 or
epoch == 1):
avg_neg_distance = val_dataset.create_epoch_tuples(model)
def _val_gen():
return (inst for inst in val_dataset)
val_loader = tf.data.Dataset.from_generator(
_val_gen, output_types=tuple(train_dataset_output_types))
loss = train_utils.train_val_one_epoch(
loader=iter(val_loader), model=model,
criterion=criterion, optimizer=None,
epoch=epoch, train=False, batch_size=FLAGS.batch_size,
query_size=FLAGS.query_size, neg_num=FLAGS.neg_num,
update_every=FLAGS.update_every, debug=FLAGS.debug)
tf.summary.scalar('val_epoch_loss', loss, step=epoch)
writer.flush()
# Evaluate on test datasets every test_freq epochs.
if epoch == 1 or epoch % FLAGS.test_freq == 0:
train_utils.test_retrieval(
FLAGS.test_datasets, model, writer=writer, epoch=epoch,
model_directory=model_directory,
precompute_whitening=FLAGS.precompute_whitening,
data_root=FLAGS.data_root, multiscale=FLAGS.multiscale)
# Saving checkpoints and model weights.
try:
save_path = manager.save(checkpoint_number=epoch)
global_features_utils.debug_and_log(
'Saved ({}) at {}'.format(epoch, save_path))
filename = os.path.join(model_directory,
'checkpoint_epoch_{}.h5'.format(epoch))
model.save_weights(filename, save_format='h5')
global_features_utils.debug_and_log(
'Saved weights ({}) at {}'.format(epoch, filename))
except Exception as ex:
global_features_utils.debug_and_log(
'Could not save checkpoint: {}'.format(ex))
if __name__ == '__main__':
app.run(main)