Unverified Commit afdf2599 authored by BasiaFusinska's avatar BasiaFusinska Committed by GitHub
Browse files

Moved code from TF1 to TF2 for DELF logging, testing directories and feature...


Moved code from TF1 to TF2 for DELF logging, testing directories and feature extraction scripts (#8591)

* Merged commit includes the following changes:

FolderOrigin-RevId: /google/src/cloud/bfusinska/delf

* Added the import for utils
Co-authored-by: default avatarAndre Araujo <andrearaujo@google.com>
parent 8a13ca4e
...@@ -30,6 +30,7 @@ from delf.python import feature_aggregation_extractor ...@@ -30,6 +30,7 @@ from delf.python import feature_aggregation_extractor
from delf.python import feature_aggregation_similarity from delf.python import feature_aggregation_similarity
from delf.python import feature_extractor from delf.python import feature_extractor
from delf.python import feature_io from delf.python import feature_io
from delf.python import utils
from delf.python.examples import detector from delf.python.examples import detector
from delf.python.examples import extractor from delf.python.examples import extractor
from delf.python import detect_to_retrieve from delf.python import detect_to_retrieve
......
...@@ -67,7 +67,7 @@ class DelfV1(object): ...@@ -67,7 +67,7 @@ class DelfV1(object):
""" """
def __init__(self, target_layer_type=_SUPPORTED_TARGET_LAYER[0]): def __init__(self, target_layer_type=_SUPPORTED_TARGET_LAYER[0]):
tf.compat.v1.logging.info('Creating model %s ', target_layer_type) print('Creating model %s ' % target_layer_type)
self._target_layer_type = target_layer_type self._target_layer_type = target_layer_type
if self._target_layer_type not in _SUPPORTED_TARGET_LAYER: if self._target_layer_type not in _SUPPORTED_TARGET_LAYER:
......
...@@ -33,14 +33,13 @@ import time ...@@ -33,14 +33,13 @@ import time
from absl import app from absl import app
from absl import flags from absl import flags
import numpy as np import numpy as np
from PIL import Image
from PIL import ImageFile
import tensorflow as tf import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
from delf import delf_config_pb2 from delf import delf_config_pb2
from delf import datum_io from delf import datum_io
from delf import feature_io from delf import feature_io
from delf import utils
from delf.python.detect_to_retrieve import dataset from delf.python.detect_to_retrieve import dataset
from delf import extractor from delf import extractor
...@@ -71,27 +70,10 @@ _DELG_GLOBAL_EXTENSION = '.delg_global' ...@@ -71,27 +70,10 @@ _DELG_GLOBAL_EXTENSION = '.delg_global'
_DELG_LOCAL_EXTENSION = '.delg_local' _DELG_LOCAL_EXTENSION = '.delg_local'
_IMAGE_EXTENSION = '.jpg' _IMAGE_EXTENSION = '.jpg'
# To avoid PIL crashing for truncated (corrupted) images.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Pace to report extraction log. # Pace to report extraction log.
_STATUS_CHECK_ITERATIONS = 50 _STATUS_CHECK_ITERATIONS = 50
def _PilLoader(path):
"""Helper function to read image with PIL.
Args:
path: Path to image to be loaded.
Returns:
PIL image in RGB format.
"""
with tf.io.gfile.GFile(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def main(argv): def main(argv):
if len(argv) > 1: if len(argv) > 1:
raise RuntimeError('Too many command-line arguments.') raise RuntimeError('Too many command-line arguments.')
...@@ -155,7 +137,7 @@ def main(argv): ...@@ -155,7 +137,7 @@ def main(argv):
print('Skipping %s' % image_name) print('Skipping %s' % image_name)
continue continue
pil_im = _PilLoader(input_image_filename) pil_im = utils.RgbLoader(input_image_filename)
resize_factor = 1.0 resize_factor = 1.0
if FLAGS.image_set == 'query': if FLAGS.image_set == 'query':
# Crop query image according to bounding box. # Crop query image according to bounding box.
......
...@@ -24,14 +24,13 @@ import os ...@@ -24,14 +24,13 @@ import os
import time import time
import numpy as np import numpy as np
from PIL import Image
from PIL import ImageFile
import tensorflow as tf import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
from delf import delf_config_pb2 from delf import delf_config_pb2
from delf import box_io from delf import box_io
from delf import feature_io from delf import feature_io
from delf import utils
from delf import detector from delf import detector
from delf import extractor from delf import extractor
...@@ -42,23 +41,6 @@ _DELF_EXTENSION = '.delf' ...@@ -42,23 +41,6 @@ _DELF_EXTENSION = '.delf'
# Pace to report extraction log. # Pace to report extraction log.
_STATUS_CHECK_ITERATIONS = 100 _STATUS_CHECK_ITERATIONS = 100
# To avoid crashing for truncated (corrupted) images.
ImageFile.LOAD_TRUNCATED_IMAGES = True
def _PilLoader(path):
"""Helper function to read image with PIL.
Args:
path: Path to image to be loaded.
Returns:
PIL image in RGB format.
"""
with tf.io.gfile.GFile(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def _WriteMappingBasenameToIds(index_names_ids_and_boxes, output_path): def _WriteMappingBasenameToIds(index_names_ids_and_boxes, output_path):
"""Helper function to write CSV mapping from DELF file name to IDs. """Helper function to write CSV mapping from DELF file name to IDs.
...@@ -157,7 +139,7 @@ def ExtractBoxesAndFeaturesToFiles(image_names, image_paths, delf_config_path, ...@@ -157,7 +139,7 @@ def ExtractBoxesAndFeaturesToFiles(image_names, image_paths, delf_config_path,
output_box_filename = os.path.join(output_boxes_dir, output_box_filename = os.path.join(output_boxes_dir,
image_name + _BOX_EXTENSION) image_name + _BOX_EXTENSION)
pil_im = _PilLoader(image_paths[i]) pil_im = utils.RgbLoader(image_paths[i])
width, height = pil_im.size width, height = pil_im.size
# Extract and save boxes. # Extract and save boxes.
......
...@@ -20,11 +20,14 @@ from __future__ import print_function ...@@ -20,11 +20,14 @@ from __future__ import print_function
import os import os
from absl import flags
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from delf.python.detect_to_retrieve import dataset from delf.python.detect_to_retrieve import dataset
FLAGS = flags.FLAGS
class DatasetTest(tf.test.TestCase): class DatasetTest(tf.test.TestCase):
...@@ -206,7 +209,7 @@ class DatasetTest(tf.test.TestCase): ...@@ -206,7 +209,7 @@ class DatasetTest(tf.test.TestCase):
'medium': np.array([0.5, 1.0]) 'medium': np.array([0.5, 1.0])
} }
pr_ranks = [1, 5] pr_ranks = [1, 5]
output_path = os.path.join(tf.compat.v1.test.get_temp_dir(), 'metrics.txt') output_path = os.path.join(FLAGS.test_tmpdir, 'metrics.txt')
# Run tested function. # Run tested function.
dataset.SaveMetricsFile(mean_average_precision, mean_precisions, dataset.SaveMetricsFile(mean_average_precision, mean_precisions,
...@@ -240,7 +243,7 @@ class DatasetTest(tf.test.TestCase): ...@@ -240,7 +243,7 @@ class DatasetTest(tf.test.TestCase):
'medium': np.array([0.5, 1.0]) 'medium': np.array([0.5, 1.0])
} }
pr_ranks = [1, 5] pr_ranks = [1, 5]
output_path = os.path.join(tf.compat.v1.test.get_temp_dir(), 'metrics.txt') output_path = os.path.join(FLAGS.test_tmpdir, 'metrics.txt')
# Run tested functions. # Run tested functions.
dataset.SaveMetricsFile(mean_average_precision, mean_precisions, dataset.SaveMetricsFile(mean_average_precision, mean_precisions,
...@@ -261,7 +264,7 @@ class DatasetTest(tf.test.TestCase): ...@@ -261,7 +264,7 @@ class DatasetTest(tf.test.TestCase):
def testReadMetricsWithRepeatedProtocolFails(self): def testReadMetricsWithRepeatedProtocolFails(self):
# Define inputs. # Define inputs.
input_path = os.path.join(tf.compat.v1.test.get_temp_dir(), 'metrics.txt') input_path = os.path.join(FLAGS.test_tmpdir, 'metrics.txt')
with tf.io.gfile.GFile(input_path, 'w') as f: with tf.io.gfile.GFile(input_path, 'w') as f:
f.write('hard\n' f.write('hard\n'
' mAP=70.0\n' ' mAP=70.0\n'
......
...@@ -31,14 +31,13 @@ import sys ...@@ -31,14 +31,13 @@ import sys
import time import time
import numpy as np import numpy as np
from PIL import Image
from PIL import ImageFile
import tensorflow as tf import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
from tensorflow.python.platform import app from tensorflow.python.platform import app
from delf import delf_config_pb2 from delf import delf_config_pb2
from delf import feature_io from delf import feature_io
from delf import utils
from delf.python.detect_to_retrieve import dataset from delf.python.detect_to_retrieve import dataset
from delf import extractor from delf import extractor
...@@ -48,37 +47,17 @@ cmd_args = None ...@@ -48,37 +47,17 @@ cmd_args = None
_DELF_EXTENSION = '.delf' _DELF_EXTENSION = '.delf'
_IMAGE_EXTENSION = '.jpg' _IMAGE_EXTENSION = '.jpg'
# To avoid PIL crashing for truncated (corrupted) images.
ImageFile.LOAD_TRUNCATED_IMAGES = True
def _PilLoader(path):
"""Helper function to read image with PIL.
Args:
path: Path to image to be loaded.
Returns:
PIL image in RGB format.
"""
with tf.io.gfile.GFile(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def main(argv): def main(argv):
if len(argv) > 1: if len(argv) > 1:
raise RuntimeError('Too many command-line arguments.') raise RuntimeError('Too many command-line arguments.')
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
# Read list of query images from dataset file. # Read list of query images from dataset file.
tf.compat.v1.logging.info( print('Reading list of query images and boxes from dataset file...')
'Reading list of query images and boxes from dataset file...')
query_list, _, ground_truth = dataset.ReadDatasetFile( query_list, _, ground_truth = dataset.ReadDatasetFile(
cmd_args.dataset_file_path) cmd_args.dataset_file_path)
num_images = len(query_list) num_images = len(query_list)
tf.compat.v1.logging.info('done! Found %d images', num_images) print(f'done! Found {num_images} images')
# Parse DelfConfig proto. # Parse DelfConfig proto.
config = delf_config_pb2.DelfConfig() config = delf_config_pb2.DelfConfig()
...@@ -104,12 +83,12 @@ def main(argv): ...@@ -104,12 +83,12 @@ def main(argv):
output_feature_filename = os.path.join( output_feature_filename = os.path.join(
cmd_args.output_features_dir, query_image_name + _DELF_EXTENSION) cmd_args.output_features_dir, query_image_name + _DELF_EXTENSION)
if tf.io.gfile.exists(output_feature_filename): if tf.io.gfile.exists(output_feature_filename):
tf.compat.v1.logging.info('Skipping %s', query_image_name) print(f'Skipping {query_image_name}')
continue continue
# Crop query image according to bounding box. # Crop query image according to bounding box.
bbox = [int(round(b)) for b in ground_truth[i]['bbx']] bbox = [int(round(b)) for b in ground_truth[i]['bbx']]
im = np.array(_PilLoader(input_image_filename).crop(bbox)) im = np.array(utils.RgbLoader(input_image_filename).crop(bbox))
# Extract and save features. # Extract and save features.
extracted_features = extractor_fn(im) extracted_features = extractor_fn(im)
......
...@@ -34,6 +34,7 @@ import tensorflow as tf ...@@ -34,6 +34,7 @@ import tensorflow as tf
from tensorflow.python.platform import app from tensorflow.python.platform import app
from delf import box_io from delf import box_io
from delf import utils
from delf import detector from delf import detector
cmd_args = None cmd_args = None
...@@ -130,13 +131,11 @@ def main(argv): ...@@ -130,13 +131,11 @@ def main(argv):
if len(argv) > 1: if len(argv) > 1:
raise RuntimeError('Too many command-line arguments.') raise RuntimeError('Too many command-line arguments.')
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
# Read list of images. # Read list of images.
tf.compat.v1.logging.info('Reading list of images...') print('Reading list of images...')
image_paths = _ReadImageList(cmd_args.list_images_path) image_paths = _ReadImageList(cmd_args.list_images_path)
num_images = len(image_paths) num_images = len(image_paths)
tf.compat.v1.logging.info('done! Found %d images', num_images) print(f'done! Found {num_images} images')
# Create output directories if necessary. # Create output directories if necessary.
if not tf.io.gfile.exists(cmd_args.output_dir): if not tf.io.gfile.exists(cmd_args.output_dir):
...@@ -147,48 +146,36 @@ def main(argv): ...@@ -147,48 +146,36 @@ def main(argv):
# Tell TensorFlow that the model will be built into the default Graph. # Tell TensorFlow that the model will be built into the default Graph.
with tf.Graph().as_default(): with tf.Graph().as_default():
# Reading list of images.
filename_queue = tf.compat.v1.train.string_input_producer(
image_paths, shuffle=False)
reader = tf.compat.v1.WholeFileReader()
_, value = reader.read(filename_queue)
image_tf = tf.io.decode_jpeg(value, channels=3)
image_tf = tf.expand_dims(image_tf, 0)
with tf.compat.v1.Session() as sess: with tf.compat.v1.Session() as sess:
init_op = tf.compat.v1.global_variables_initializer() init_op = tf.compat.v1.global_variables_initializer()
sess.run(init_op) sess.run(init_op)
detector_fn = detector.MakeDetector(sess, cmd_args.detector_path) detector_fn = detector.MakeDetector(sess, cmd_args.detector_path)
# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
start = time.clock() start = time.clock()
for i, image_path in enumerate(image_paths): for i, image_path in enumerate(image_paths):
# Write to log-info once in a while. # Write to log-info once in a while.
if i == 0: if i == 0:
tf.compat.v1.logging.info('Starting to detect objects in images...') print('Starting to detect objects in images...')
elif i % _STATUS_CHECK_ITERATIONS == 0: elif i % _STATUS_CHECK_ITERATIONS == 0:
elapsed = (time.clock() - start) elapsed = (time.clock() - start)
tf.compat.v1.logging.info( print(
'Processing image %d out of %d, last %d ' f'Processing image {i} out of {num_images}, last '
'images took %f seconds', i, num_images, _STATUS_CHECK_ITERATIONS, f'{_STATUS_CHECK_ITERATIONS} images took {elapsed} seconds'
elapsed) )
start = time.clock() start = time.clock()
# # Get next image.
im = sess.run(image_tf)
# If descriptor already exists, skip its computation. # If descriptor already exists, skip its computation.
base_boxes_filename, _ = os.path.splitext(os.path.basename(image_path)) base_boxes_filename, _ = os.path.splitext(os.path.basename(image_path))
out_boxes_filename = base_boxes_filename + _BOX_EXT out_boxes_filename = base_boxes_filename + _BOX_EXT
out_boxes_fullpath = os.path.join(cmd_args.output_dir, out_boxes_fullpath = os.path.join(cmd_args.output_dir,
out_boxes_filename) out_boxes_filename)
if tf.io.gfile.exists(out_boxes_fullpath): if tf.io.gfile.exists(out_boxes_fullpath):
tf.compat.v1.logging.info('Skipping %s', image_path) print(f'Skipping {image_path}')
continue continue
im = np.expand_dims(np.array(utils.RgbLoader(image_paths[i])), 0)
# Extract and save boxes. # Extract and save boxes.
(boxes_out, scores_out, class_indices_out) = detector_fn(im) (boxes_out, scores_out, class_indices_out) = detector_fn(im)
(selected_boxes, selected_scores, (selected_boxes, selected_scores,
...@@ -205,10 +192,6 @@ def main(argv): ...@@ -205,10 +192,6 @@ def main(argv):
out_viz_filename) out_viz_filename)
_PlotBoxesAndSaveImage(im[0], selected_boxes, out_viz_fullpath) _PlotBoxesAndSaveImage(im[0], selected_boxes, out_viz_fullpath)
# Finalize enqueue threads.
coord.request_stop()
coord.join(threads)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
...@@ -27,6 +27,7 @@ import os ...@@ -27,6 +27,7 @@ import os
import sys import sys
import time import time
import numpy as np
from six.moves import range from six.moves import range
import tensorflow as tf import tensorflow as tf
...@@ -34,6 +35,7 @@ from google.protobuf import text_format ...@@ -34,6 +35,7 @@ from google.protobuf import text_format
from tensorflow.python.platform import app from tensorflow.python.platform import app
from delf import delf_config_pb2 from delf import delf_config_pb2
from delf import feature_io from delf import feature_io
from delf import utils
from delf import extractor from delf import extractor
cmd_args = None cmd_args = None
...@@ -61,13 +63,11 @@ def _ReadImageList(list_path): ...@@ -61,13 +63,11 @@ def _ReadImageList(list_path):
def main(unused_argv): def main(unused_argv):
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
# Read list of images. # Read list of images.
tf.compat.v1.logging.info('Reading list of images...') print('Reading list of images...')
image_paths = _ReadImageList(cmd_args.list_images_path) image_paths = _ReadImageList(cmd_args.list_images_path)
num_images = len(image_paths) num_images = len(image_paths)
tf.compat.v1.logging.info('done! Found %d images', num_images) print(f'done! Found {num_images} images')
# Parse DelfConfig proto. # Parse DelfConfig proto.
config = delf_config_pb2.DelfConfig() config = delf_config_pb2.DelfConfig()
...@@ -80,47 +80,35 @@ def main(unused_argv): ...@@ -80,47 +80,35 @@ def main(unused_argv):
# Tell TensorFlow that the model will be built into the default Graph. # Tell TensorFlow that the model will be built into the default Graph.
with tf.Graph().as_default(): with tf.Graph().as_default():
# Reading list of images.
filename_queue = tf.compat.v1.train.string_input_producer(
image_paths, shuffle=False)
reader = tf.compat.v1.WholeFileReader()
_, value = reader.read(filename_queue)
image_tf = tf.io.decode_jpeg(value, channels=3)
with tf.compat.v1.Session() as sess: with tf.compat.v1.Session() as sess:
init_op = tf.compat.v1.global_variables_initializer() init_op = tf.compat.v1.global_variables_initializer()
sess.run(init_op) sess.run(init_op)
extractor_fn = extractor.MakeExtractor(sess, config) extractor_fn = extractor.MakeExtractor(sess, config)
# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
start = time.clock() start = time.clock()
for i in range(num_images): for i in range(num_images):
# Write to log-info once in a while. # Write to log-info once in a while.
if i == 0: if i == 0:
tf.compat.v1.logging.info( print('Starting to extract DELF features from images...')
'Starting to extract DELF features from images...')
elif i % _STATUS_CHECK_ITERATIONS == 0: elif i % _STATUS_CHECK_ITERATIONS == 0:
elapsed = (time.clock() - start) elapsed = (time.clock() - start)
tf.compat.v1.logging.info( print(
'Processing image %d out of %d, last %d ' f'Processing image {i} out of {num_images}, last '
'images took %f seconds', i, num_images, _STATUS_CHECK_ITERATIONS, f'{_STATUS_CHECK_ITERATIONS} images took {elapsed} seconds'
elapsed) )
start = time.clock() start = time.clock()
# # Get next image.
im = sess.run(image_tf)
# If descriptor already exists, skip its computation. # If descriptor already exists, skip its computation.
out_desc_filename = os.path.splitext(os.path.basename( out_desc_filename = os.path.splitext(os.path.basename(
image_paths[i]))[0] + _DELF_EXT image_paths[i]))[0] + _DELF_EXT
out_desc_fullpath = os.path.join(cmd_args.output_dir, out_desc_filename) out_desc_fullpath = os.path.join(cmd_args.output_dir, out_desc_filename)
if tf.io.gfile.exists(out_desc_fullpath): if tf.io.gfile.exists(out_desc_fullpath):
tf.compat.v1.logging.info('Skipping %s', image_paths[i]) print(f'Skipping {image_paths[i]}')
continue continue
im = np.array(utils.RgbLoader(image_paths[i]))
# Extract and save features. # Extract and save features.
extracted_features = extractor_fn(im) extracted_features = extractor_fn(im)
locations_out = extracted_features['local_features']['locations'] locations_out = extracted_features['local_features']['locations']
...@@ -132,10 +120,6 @@ def main(unused_argv): ...@@ -132,10 +120,6 @@ def main(unused_argv):
feature_scales_out, descriptors_out, feature_scales_out, descriptors_out,
attention_out) attention_out)
# Finalize enqueue threads.
coord.request_stop()
coord.join(threads)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
...@@ -37,7 +37,6 @@ from scipy import spatial ...@@ -37,7 +37,6 @@ from scipy import spatial
from skimage import feature from skimage import feature
from skimage import measure from skimage import measure
from skimage import transform from skimage import transform
import tensorflow as tf
from tensorflow.python.platform import app from tensorflow.python.platform import app
from delf import feature_io from delf import feature_io
...@@ -48,17 +47,15 @@ _DISTANCE_THRESHOLD = 0.8 ...@@ -48,17 +47,15 @@ _DISTANCE_THRESHOLD = 0.8
def main(unused_argv): def main(unused_argv):
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
# Read features. # Read features.
locations_1, _, descriptors_1, _, _ = feature_io.ReadFromFile( locations_1, _, descriptors_1, _, _ = feature_io.ReadFromFile(
cmd_args.features_1_path) cmd_args.features_1_path)
num_features_1 = locations_1.shape[0] num_features_1 = locations_1.shape[0]
tf.compat.v1.logging.info("Loaded image 1's %d features" % num_features_1) print(f"Loaded image 1's {num_features_1} features")
locations_2, _, descriptors_2, _, _ = feature_io.ReadFromFile( locations_2, _, descriptors_2, _, _ = feature_io.ReadFromFile(
cmd_args.features_2_path) cmd_args.features_2_path)
num_features_2 = locations_2.shape[0] num_features_2 = locations_2.shape[0]
tf.compat.v1.logging.info("Loaded image 2's %d features" % num_features_2) print(f"Loaded image 2's {num_features_2} features")
# Find nearest-neighbor matches using a KD tree. # Find nearest-neighbor matches using a KD tree.
d1_tree = spatial.cKDTree(descriptors_1) d1_tree = spatial.cKDTree(descriptors_1)
...@@ -84,7 +81,7 @@ def main(unused_argv): ...@@ -84,7 +81,7 @@ def main(unused_argv):
residual_threshold=20, residual_threshold=20,
max_trials=1000) max_trials=1000)
tf.compat.v1.logging.info('Found %d inliers' % sum(inliers)) print(f'Found {sum(inliers)} inliers')
# Visualize correspondences, and save to file. # Visualize correspondences, and save to file.
_, ax = plt.subplots() _, ax = plt.subplots()
......
...@@ -20,16 +20,19 @@ from __future__ import print_function ...@@ -20,16 +20,19 @@ from __future__ import print_function
import os import os
from absl import flags
import tensorflow as tf import tensorflow as tf
from delf.python.google_landmarks_dataset import dataset_file_io from delf.python.google_landmarks_dataset import dataset_file_io
FLAGS = flags.FLAGS
class DatasetFileIoTest(tf.test.TestCase): class DatasetFileIoTest(tf.test.TestCase):
def testReadRecognitionSolutionWorks(self): def testReadRecognitionSolutionWorks(self):
# Define inputs. # Define inputs.
file_path = os.path.join(tf.compat.v1.test.get_temp_dir(), file_path = os.path.join(FLAGS.test_tmpdir,
'recognition_solution.csv') 'recognition_solution.csv')
with tf.io.gfile.GFile(file_path, 'w') as f: with tf.io.gfile.GFile(file_path, 'w') as f:
f.write('id,landmarks,Usage\n') f.write('id,landmarks,Usage\n')
...@@ -61,7 +64,7 @@ class DatasetFileIoTest(tf.test.TestCase): ...@@ -61,7 +64,7 @@ class DatasetFileIoTest(tf.test.TestCase):
def testReadRetrievalSolutionWorks(self): def testReadRetrievalSolutionWorks(self):
# Define inputs. # Define inputs.
file_path = os.path.join(tf.compat.v1.test.get_temp_dir(), file_path = os.path.join(FLAGS.test_tmpdir,
'retrieval_solution.csv') 'retrieval_solution.csv')
with tf.io.gfile.GFile(file_path, 'w') as f: with tf.io.gfile.GFile(file_path, 'w') as f:
f.write('id,images,Usage\n') f.write('id,images,Usage\n')
...@@ -93,7 +96,7 @@ class DatasetFileIoTest(tf.test.TestCase): ...@@ -93,7 +96,7 @@ class DatasetFileIoTest(tf.test.TestCase):
def testReadRecognitionPredictionsWorks(self): def testReadRecognitionPredictionsWorks(self):
# Define inputs. # Define inputs.
file_path = os.path.join(tf.compat.v1.test.get_temp_dir(), file_path = os.path.join(FLAGS.test_tmpdir,
'recognition_predictions.csv') 'recognition_predictions.csv')
with tf.io.gfile.GFile(file_path, 'w') as f: with tf.io.gfile.GFile(file_path, 'w') as f:
f.write('id,landmarks\n') f.write('id,landmarks\n')
...@@ -131,7 +134,7 @@ class DatasetFileIoTest(tf.test.TestCase): ...@@ -131,7 +134,7 @@ class DatasetFileIoTest(tf.test.TestCase):
def testReadRetrievalPredictionsWorks(self): def testReadRetrievalPredictionsWorks(self):
# Define inputs. # Define inputs.
file_path = os.path.join(tf.compat.v1.test.get_temp_dir(), file_path = os.path.join(FLAGS.test_tmpdir,
'retrieval_predictions.csv') 'retrieval_predictions.csv')
with tf.io.gfile.GFile(file_path, 'w') as f: with tf.io.gfile.GFile(file_path, 'w') as f:
f.write('id,images\n') f.write('id,images\n')
......
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for DELF."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from PIL import Image
from PIL import ImageFile
import tensorflow as tf
# To avoid PIL crashing for truncated (corrupted) images.
ImageFile.LOAD_TRUNCATED_IMAGES = True
def RgbLoader(path):
"""Helper function to read image with PIL.
Args:
path: Path to image to be loaded.
Returns:
PIL image in RGB format.
"""
with tf.io.gfile.GFile(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment