Unverified Commit b3ef7ae9 authored by Dan Anghel's avatar Dan Anghel Committed by GitHub

Push to GitHub of changes to DELF package (#8670)

* First version of working script to download the GLDv2 dataset

* First version of the DELF package installation script

* First working version of the DELF package installation script

* Fixed feedback from PR review

* Push to GitHub of changes to the TFRecord data generation script for DELF.

* Merged commit includes the following changes:
315363544  by Andre Araujo:

    Added the generation of TRAIN and VALIDATE splits from the train dataset.

--
314676530  by Andre Araujo:

    Updated script to download GLDv2 images for DELF training.

--
314101235  by Andre Araujo:

    Added newly created module 'utils' to the copybara script.

--
313677085  by Andre Araujo:

    Code migration from TF1 to TF2 for:
    - logging (replaced usage of tf.compat.v1.logging.info)
    - testing directories (replaced usage of tf.compat.v1.test.get_temp_dir())
    - feature/object extraction scripts (replaced usage of tf.compat.v1.train.stri...
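A minimal sketch of the TF1-to-TF2 replacements described above (illustrative call sites only, not the actual DELF sources):

from absl import logging
import tempfile

# Before (TF1): tf.compat.v1.logging.info('Extracting features...')
logging.info('Extracting features...')  # absl logging works under TF1 and TF2.

# Before (TF1): tmp_dir = tf.compat.v1.test.get_temp_dir()
# Outside a tf.test.TestCase, a plain temp dir is the usual substitute;
# inside a TestCase, self.get_temp_dir() is available.
tmp_dir = tempfile.mkdtemp()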
parent 11eeb9cb
@@ -98,12 +98,7 @@ def main(argv):
   if not tf.io.gfile.exists(FLAGS.output_features_dir):
     tf.io.gfile.makedirs(FLAGS.output_features_dir)

-  with tf.Graph().as_default():
-    with tf.compat.v1.Session() as sess:
-      # Initialize variables, construct DELG extractor.
-      init_op = tf.compat.v1.global_variables_initializer()
-      sess.run(init_op)
-      extractor_fn = extractor.MakeExtractor(sess, config)
+  extractor_fn = extractor.MakeExtractor(config)

   start = time.time()
   for i in range(num_images):
@@ -153,8 +148,7 @@ def main(argv):
     extracted_features = extractor_fn(im, resize_factor)
     if config.use_global_features:
       global_descriptor = extracted_features['global_descriptor']
-      datum_io.WriteToFile(global_descriptor,
-                           output_global_feature_filename)
+      datum_io.WriteToFile(global_descriptor, output_global_feature_filename)
     if config.use_local_features:
       locations = extracted_features['local_features']['locations']
       descriptors = extracted_features['local_features']['descriptors']
......
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Times DELF/G extraction."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl import app
from absl import flags
import numpy as np
from six.moves import range
import tensorflow as tf
from google.protobuf import text_format
from delf import delf_config_pb2
from delf import utils
from delf import extractor
FLAGS = flags.FLAGS
flags.DEFINE_string(
'delf_config_path', '/tmp/delf_config_example.pbtxt',
'Path to DelfConfig proto text file with configuration to be used for DELG '
'extraction. Local features are extracted if use_local_features is True; '
'global features are extracted if use_global_features is True.')
flags.DEFINE_string('list_images_path', '/tmp/list_images.txt',
'Path to list of images whose features will be extracted.')
flags.DEFINE_integer('repeat_per_image', 10,
'Number of times to repeat extraction per image.')
# Pace to report extraction log.
_STATUS_CHECK_ITERATIONS = 100
def _ReadImageList(list_path):
"""Helper function to read image paths.
Args:
list_path: Path to list of images, one image path per line.
Returns:
image_paths: List of image paths.
"""
with tf.io.gfile.GFile(list_path, 'r') as f:
image_paths = f.readlines()
image_paths = [entry.rstrip() for entry in image_paths]
return image_paths
def main(argv):
if len(argv) > 1:
raise RuntimeError('Too many command-line arguments.')
# Read list of images.
print('Reading list of images...')
image_paths = _ReadImageList(FLAGS.list_images_path)
num_images = len(image_paths)
print(f'done! Found {num_images} images')
# Load images in memory.
print('Loading images, %d times per image...' % FLAGS.repeat_per_image)
im_array = []
for filename in image_paths:
im = np.array(utils.RgbLoader(filename))
for _ in range(FLAGS.repeat_per_image):
im_array.append(im)
np.random.shuffle(im_array)
print('done!')
# Parse DelfConfig proto.
config = delf_config_pb2.DelfConfig()
with tf.io.gfile.GFile(FLAGS.delf_config_path, 'r') as f:
text_format.Parse(f.read(), config)
extractor_fn = extractor.MakeExtractor(config)
start = time.time()
for i, im in enumerate(im_array):
if i == 0:
print('Starting to extract DELF features from images...')
elif i % _STATUS_CHECK_ITERATIONS == 0:
elapsed = (time.time() - start)
print(f'Processing image {i} out of {len(im_array)}, last '
f'{_STATUS_CHECK_ITERATIONS} images took {elapsed} seconds,'
f'ie {elapsed/_STATUS_CHECK_ITERATIONS} secs/image.')
start = time.time()
# Extract and save features.
extracted_features = extractor_fn(im)
if __name__ == '__main__':
app.run(main)
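A hypothetical invocation of the timing script above; the actual file name is not visible in this view, so it is a placeholder:

# python3 extract_features_timing.py \
#   --delf_config_path=/tmp/delf_config_example.pbtxt \
#   --list_images_path=/tmp/list_images.txt \
#   --repeat_per_image=10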
@@ -112,26 +112,19 @@ def ExtractBoxesAndFeaturesToFiles(image_names, image_paths, delf_config_path,
     tf.io.gfile.makedirs(os.path.dirname(output_mapping))

   names_ids_and_boxes = []
-  with tf.Graph().as_default():
-    with tf.compat.v1.Session() as sess:
-      # Initialize variables, construct detector and DELF extractor.
-      init_op = tf.compat.v1.global_variables_initializer()
-      sess.run(init_op)
-      detector_fn = detector.MakeDetector(
-          sess, detector_model_dir, import_scope='detector')
-      delf_extractor_fn = extractor.MakeExtractor(
-          sess, config, import_scope='extractor_delf')
-      start = time.clock()
+  detector_fn = detector.MakeDetector(detector_model_dir)
+  delf_extractor_fn = extractor.MakeExtractor(config)

+  start = time.time()
   for i in range(num_images):
     if i == 0:
       print('Starting to extract features/boxes...')
     elif i % _STATUS_CHECK_ITERATIONS == 0:
-      elapsed = (time.clock() - start)
+      elapsed = (time.time() - start)
       print('Processing image %d out of %d, last %d '
             'images took %f seconds' %
             (i, num_images, _STATUS_CHECK_ITERATIONS, elapsed))
-      start = time.clock()
+      start = time.time()

     image_name = image_names[i]
     output_feature_filename_whole_image = os.path.join(
@@ -203,8 +196,7 @@ def ExtractBoxesAndFeaturesToFiles(image_names, image_paths, delf_config_path,
       attention_out = extracted_features['local_features']['attention']
       feature_io.WriteToFile(output_feature_filename, locations_out,
-                             feature_scales_out, descriptors_out,
-                             attention_out)
+                             feature_scales_out, descriptors_out, attention_out)

   # Save mapping from output DELF name to image id and box id.
   _WriteMappingBasenameToIds(names_ids_and_boxes, output_mapping)
@@ -68,14 +68,9 @@ def main(argv):
   if not tf.io.gfile.exists(cmd_args.output_features_dir):
     tf.io.gfile.makedirs(cmd_args.output_features_dir)

-  with tf.Graph().as_default():
-    with tf.compat.v1.Session() as sess:
-      # Initialize variables, construct DELF extractor.
-      init_op = tf.compat.v1.global_variables_initializer()
-      sess.run(init_op)
-      extractor_fn = extractor.MakeExtractor(sess, config)
-      start = time.clock()
+  extractor_fn = extractor.MakeExtractor(config)

+  start = time.time()
   for i in range(num_images):
     query_image_name = query_list[i]
     input_image_filename = os.path.join(cmd_args.images_dir,
@@ -101,7 +96,7 @@ def main(argv):
                            feature_scales_out, descriptors_out,
                            attention_out)

-  elapsed = (time.clock() - start)
+  elapsed = (time.time() - start)
   print('Processed %d query images in %f seconds' % (num_images, elapsed))
......
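A note on the recurring time.clock() to time.time() swaps in these hunks: time.clock() was deprecated in Python 3.3 and removed in Python 3.8, and on Unix it measured CPU time rather than wall-clock time, so time.time() is the right call for reporting per-image throughput. The pattern in isolation:

import time

start = time.time()  # wall-clock seconds; time.clock() is gone in Python 3.8+
# ... extraction work ...
elapsed = time.time() - start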
@@ -21,30 +21,22 @@ from __future__ import print_function
 import tensorflow as tf


-def MakeDetector(sess, model_dir, import_scope=None):
+def MakeDetector(model_dir):
   """Creates a function to detect objects in an image.

   Args:
-    sess: TensorFlow session to use.
     model_dir: Directory where SavedModel is located.
-    import_scope: Optional scope to use for model.

   Returns:
     Function that receives an image and returns detection results.
   """
-  tf.compat.v1.saved_model.loader.load(
-      sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
-      model_dir,
-      import_scope=import_scope)
-  import_scope_prefix = import_scope + '/' if import_scope is not None else ''
-  input_images = sess.graph.get_tensor_by_name('%sinput_images:0' %
-                                               import_scope_prefix)
-  boxes = sess.graph.get_tensor_by_name('%sdetection_boxes:0' %
-                                        import_scope_prefix)
-  scores = sess.graph.get_tensor_by_name('%sdetection_scores:0' %
-                                         import_scope_prefix)
-  class_indices = sess.graph.get_tensor_by_name('%sdetection_classes:0' %
-                                                import_scope_prefix)
+  model = tf.saved_model.load(model_dir)
+
+  # Input and output tensors.
+  feeds = ['input_images:0']
+  fetches = ['detection_boxes:0', 'detection_scores:0', 'detection_classes:0']
+
+  model = model.prune(feeds=feeds, fetches=fetches)

   def DetectorFn(images):
     """Receives an image and returns detected boxes.
@@ -56,7 +48,8 @@ def MakeDetector(sess, model_dir, import_scope=None):
     Returns:
       Tuple (boxes, scores, class_indices).
     """
-    return sess.run([boxes, scores, class_indices],
-                    feed_dict={input_images: images})
+    boxes, scores, class_indices = model(tf.convert_to_tensor(images))
+    return boxes.numpy(), scores.numpy(), class_indices.numpy()

   return DetectorFn
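A minimal usage sketch of the reworked detector wrapper, assuming the module is importable as delf.detector and a detection SavedModel exists at the placeholder path:

import numpy as np

from delf import detector

detector_fn = detector.MakeDetector('/tmp/detector_savedmodel')  # placeholder
# The 'input_images:0' feed expects a batched image array; dtype and shape
# must match the SavedModel signature (uint8 RGB assumed here).
images = np.expand_dims(np.zeros((480, 640, 3), dtype=np.uint8), axis=0)
boxes, scores, class_indices = detector_fn(images)
print(boxes.shape, scores.shape, class_indices.shape)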
@@ -144,26 +144,20 @@ def main(argv):
                       cmd_args.output_viz_dir):
     tf.io.gfile.makedirs(cmd_args.output_viz_dir)

-  # Tell TensorFlow that the model will be built into the default Graph.
-  with tf.Graph().as_default():
-    with tf.compat.v1.Session() as sess:
-      init_op = tf.compat.v1.global_variables_initializer()
-      sess.run(init_op)
-
-      detector_fn = detector.MakeDetector(sess, cmd_args.detector_path)
-      start = time.clock()
+  detector_fn = detector.MakeDetector(cmd_args.detector_path)

+  start = time.time()
   for i, image_path in enumerate(image_paths):
-    # Write to log-info once in a while.
+    # Report progress once in a while.
     if i == 0:
       print('Starting to detect objects in images...')
     elif i % _STATUS_CHECK_ITERATIONS == 0:
-      elapsed = (time.clock() - start)
+      elapsed = (time.time() - start)
       print(
           f'Processing image {i} out of {num_images}, last '
           f'{_STATUS_CHECK_ITERATIONS} images took {elapsed} seconds'
       )
-      start = time.clock()
+      start = time.time()

     # If descriptor already exists, skip its computation.
     base_boxes_filename, _ = os.path.splitext(os.path.basename(image_path))
......
@@ -78,26 +78,20 @@ def main(unused_argv):
   if not tf.io.gfile.exists(cmd_args.output_dir):
     tf.io.gfile.makedirs(cmd_args.output_dir)

-  # Tell TensorFlow that the model will be built into the default Graph.
-  with tf.Graph().as_default():
-    with tf.compat.v1.Session() as sess:
-      init_op = tf.compat.v1.global_variables_initializer()
-      sess.run(init_op)
-
-      extractor_fn = extractor.MakeExtractor(sess, config)
-      start = time.clock()
+  extractor_fn = extractor.MakeExtractor(config)

+  start = time.time()
   for i in range(num_images):
-    # Write to log-info once in a while.
+    # Report progress once in a while.
     if i == 0:
       print('Starting to extract DELF features from images...')
     elif i % _STATUS_CHECK_ITERATIONS == 0:
-      elapsed = (time.clock() - start)
+      elapsed = (time.time() - start)
       print(
           f'Processing image {i} out of {num_images}, last '
           f'{_STATUS_CHECK_ITERATIONS} images took {elapsed} seconds'
       )
-      start = time.clock()
+      start = time.time()

     # If descriptor already exists, skip its computation.
     out_desc_filename = os.path.splitext(os.path.basename(
@@ -116,9 +110,8 @@ def main(unused_argv):
     feature_scales_out = extracted_features['local_features']['scales']
     attention_out = extracted_features['local_features']['attention']

-    feature_io.WriteToFile(out_desc_fullpath, locations_out,
-                           feature_scales_out, descriptors_out,
-                           attention_out)
+    feature_io.WriteToFile(out_desc_fullpath, locations_out, feature_scales_out,
+                           descriptors_out, attention_out)

 if __name__ == '__main__':
......
@@ -22,6 +22,7 @@ import numpy as np
 from PIL import Image
 import tensorflow as tf

 from delf import datum_io
+from delf import feature_extractor

 # Minimum dimensions below which DELF features are not extracted (empty
@@ -93,70 +94,91 @@ def ResizeImage(image, config, resize_factor=1.0):
   return resized_image, scale_factors

-def MakeExtractor(sess, config, import_scope=None):
+def MakeExtractor(config):
   """Creates a function to extract global and/or local features from an image.

   Args:
-    sess: TensorFlow session to use.
     config: DelfConfig proto containing the model configuration.
-    import_scope: Optional scope to use for model.

   Returns:
     Function that receives an image and returns features.
   """
   # Load model.
-  tf.compat.v1.saved_model.loader.load(
-      sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
-      config.model_path,
-      import_scope=import_scope)
-  import_scope_prefix = import_scope + '/' if import_scope is not None else ''
-
-  # Input tensors.
-  input_image = sess.graph.get_tensor_by_name('%sinput_image:0' %
-                                              import_scope_prefix)
-  input_image_scales = sess.graph.get_tensor_by_name('%sinput_scales:0' %
-                                                     import_scope_prefix)
-  if config.use_local_features:
-    input_score_threshold = sess.graph.get_tensor_by_name(
-        '%sinput_abs_thres:0' % import_scope_prefix)
-    input_max_feature_num = sess.graph.get_tensor_by_name(
-        '%sinput_max_feature_num:0' % import_scope_prefix)
-
-  # Output tensors.
-  if config.use_global_features:
-    raw_global_descriptors = sess.graph.get_tensor_by_name(
-        '%sglobal_descriptors:0' % import_scope_prefix)
-  if config.use_local_features:
-    boxes = sess.graph.get_tensor_by_name('%sboxes:0' % import_scope_prefix)
-    raw_local_descriptors = sess.graph.get_tensor_by_name('%sfeatures:0' %
-                                                          import_scope_prefix)
-    feature_scales = sess.graph.get_tensor_by_name('%sscales:0' %
-                                                   import_scope_prefix)
-    attention_with_extra_dim = sess.graph.get_tensor_by_name(
-        '%sscores:0' % import_scope_prefix)
-
-  # Post-process extracted features: normalize, PCA (optional), pooling.
-  if config.use_global_features:
-    if config.delf_global_config.image_scales_ind:
-      raw_global_descriptors_selected_scales = tf.gather(
-          raw_global_descriptors,
-          list(config.delf_global_config.image_scales_ind))
-    else:
-      raw_global_descriptors_selected_scales = raw_global_descriptors
-    global_descriptors_per_scale = feature_extractor.PostProcessDescriptors(
-        raw_global_descriptors_selected_scales,
-        config.delf_global_config.use_pca,
-        config.delf_global_config.pca_parameters)
-    unnormalized_global_descriptor = tf.reduce_sum(
-        global_descriptors_per_scale, axis=0, name='sum_pooling')
-    global_descriptor = tf.nn.l2_normalize(
-        unnormalized_global_descriptor, axis=0, name='final_l2_normalization')
-  if config.use_local_features:
-    attention = tf.reshape(attention_with_extra_dim,
-                           [tf.shape(attention_with_extra_dim)[0]])
-    locations, local_descriptors = feature_extractor.DelfFeaturePostProcessing(
-        boxes, raw_local_descriptors, config)
+  model = tf.saved_model.load(config.model_path)
+
+  # Input/output end-points/tensors.
+  feeds = ['input_image:0', 'input_scales:0']
+  fetches = []
+  image_scales_tensor = tf.convert_to_tensor(list(config.image_scales))
+
+  # Custom configuration needed when local features are used.
+  if config.use_local_features:
+    # Extra input/output end-points/tensors.
+    feeds.append('input_abs_thres:0')
+    feeds.append('input_max_feature_num:0')
+    fetches.append('boxes:0')
+    fetches.append('features:0')
+    fetches.append('scales:0')
+    fetches.append('scores:0')
+    score_threshold_tensor = tf.constant(
+        config.delf_local_config.score_threshold)
+    max_feature_num_tensor = tf.constant(
+        config.delf_local_config.max_feature_num)
+
+    # If using PCA, pre-load required parameters.
+    local_pca_parameters = {}
+    if config.delf_local_config.use_pca:
+      local_pca_parameters['mean'] = tf.constant(
+          datum_io.ReadFromFile(
+              config.delf_local_config.pca_parameters.mean_path),
+          dtype=tf.float32)
+      local_pca_parameters['matrix'] = tf.constant(
+          datum_io.ReadFromFile(
+              config.delf_local_config.pca_parameters.projection_matrix_path),
+          dtype=tf.float32)
+      local_pca_parameters[
+          'dim'] = config.delf_local_config.pca_parameters.pca_dim
+      local_pca_parameters['use_whitening'] = (
+          config.delf_local_config.pca_parameters.use_whitening)
+      if config.delf_local_config.pca_parameters.use_whitening:
+        local_pca_parameters['variances'] = tf.squeeze(
+            tf.constant(
+                datum_io.ReadFromFile(
+                    config.delf_local_config.pca_parameters.pca_variances_path),
+                dtype=tf.float32))
+      else:
+        local_pca_parameters['variances'] = None
+
+  # Custom configuration needed when global features are used.
+  if config.use_global_features:
+    # Extra output end-point.
+    fetches.append('global_descriptors:0')
+
+    # If using PCA, pre-load required parameters.
+    global_pca_parameters = {}
+    if config.delf_global_config.use_pca:
+      global_pca_parameters['mean'] = tf.constant(
+          datum_io.ReadFromFile(
+              config.delf_global_config.pca_parameters.mean_path),
+          dtype=tf.float32)
+      global_pca_parameters['matrix'] = tf.constant(
+          datum_io.ReadFromFile(
+              config.delf_global_config.pca_parameters.projection_matrix_path),
+          dtype=tf.float32)
+      global_pca_parameters[
+          'dim'] = config.delf_global_config.pca_parameters.pca_dim
+      global_pca_parameters['use_whitening'] = (
+          config.delf_global_config.pca_parameters.use_whitening)
+      if config.delf_global_config.pca_parameters.use_whitening:
+        global_pca_parameters['variances'] = tf.squeeze(
+            tf.constant(
+                datum_io.ReadFromFile(config.delf_global_config.pca_parameters
                                      .pca_variances_path),
+                dtype=tf.float32))
+      else:
+        global_pca_parameters['variances'] = None
+
+  model = model.prune(feeds=feeds, fetches=fetches)

   def ExtractorFn(image, resize_factor=1.0):
     """Receives an image and returns DELF global and/or local features.
@@ -194,35 +216,62 @@ def MakeExtractor(sess, config, import_scope=None):
         })
       return extracted_features

-    feed_dict = {
-        input_image: resized_image,
-        input_image_scales: list(config.image_scales),
-    }
-    fetches = {}
+    # Input tensors.
+    image_tensor = tf.convert_to_tensor(resized_image)
+
+    # Extracted features.
+    extracted_features = {}
+    output = None
+
+    if config.use_local_features:
+      output = model(image_tensor, image_scales_tensor, score_threshold_tensor,
+                     max_feature_num_tensor)
+    else:
+      output = model(image_tensor, image_scales_tensor)
+
+    # Post-process extracted features: normalize, PCA (optional), pooling.
     if config.use_global_features:
-      fetches.update({
-          'global_descriptor': global_descriptor,
+      raw_global_descriptors = output[-1]
+      if config.delf_global_config.image_scales_ind:
+        raw_global_descriptors_selected_scales = tf.gather(
+            raw_global_descriptors,
+            list(config.delf_global_config.image_scales_ind))
+      else:
+        raw_global_descriptors_selected_scales = raw_global_descriptors
+      global_descriptors_per_scale = feature_extractor.PostProcessDescriptors(
+          raw_global_descriptors_selected_scales,
+          config.delf_global_config.use_pca, global_pca_parameters)
+      unnormalized_global_descriptor = tf.reduce_sum(
+          global_descriptors_per_scale, axis=0, name='sum_pooling')
+      global_descriptor = tf.nn.l2_normalize(
+          unnormalized_global_descriptor, axis=0, name='final_l2_normalization')
+      extracted_features.update({
+          'global_descriptor': global_descriptor.numpy(),
       })
+
     if config.use_local_features:
-      feed_dict.update({
-          input_score_threshold: config.delf_local_config.score_threshold,
-          input_max_feature_num: config.delf_local_config.max_feature_num,
-      })
-      fetches.update({
+      boxes = output[0]
+      raw_local_descriptors = output[1]
+      feature_scales = output[2]
+      attention_with_extra_dim = output[3]
+
+      attention = tf.reshape(attention_with_extra_dim,
+                             [tf.shape(attention_with_extra_dim)[0]])
+      locations, local_descriptors = (
+          feature_extractor.DelfFeaturePostProcessing(
+              boxes, raw_local_descriptors, config.delf_local_config.use_pca,
+              local_pca_parameters))
+      locations /= scale_factors
+
+      extracted_features.update({
           'local_features': {
-              'locations': locations,
-              'descriptors': local_descriptors,
-              'scales': feature_scales,
-              'attention': attention,
+              'locations': locations.numpy(),
+              'descriptors': local_descriptors.numpy(),
+              'scales': feature_scales.numpy(),
+              'attention': attention.numpy(),
           }
       })

-    extracted_features = sess.run(fetches, feed_dict=feed_dict)
-
-    # Adjust local feature positions due to rescaling.
-    if config.use_local_features:
-      extracted_features['local_features']['locations'] /= scale_factors
-
     return extracted_features

   return ExtractorFn
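For reference, a sketch of how callers drive the reworked extractor; paths are placeholders, and the flow mirrors the updated extraction scripts earlier in this diff:

import numpy as np
import tensorflow as tf

from google.protobuf import text_format
from delf import delf_config_pb2
from delf import extractor
from delf import utils

config = delf_config_pb2.DelfConfig()
with tf.io.gfile.GFile('/tmp/delf_config_example.pbtxt', 'r') as f:
  text_format.Parse(f.read(), config)

# No tf.Graph, tf.Session or variable initialization is needed anymore.
extractor_fn = extractor.MakeExtractor(config)

image = np.array(utils.RgbLoader('/tmp/query.jpg'))  # placeholder image path
features = extractor_fn(image)
if config.use_global_features:
  print(features['global_descriptor'].shape)
if config.use_local_features:
  print(features['local_features']['locations'].shape)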
@@ -19,7 +19,6 @@ from __future__ import print_function
 import tensorflow as tf

-from delf import datum_io
 from delf import delf_v1
 from object_detection.core import box_list
 from object_detection.core import box_list_ops
@@ -331,13 +330,15 @@ def ApplyPcaAndWhitening(data,
   return output


-def PostProcessDescriptors(descriptors, use_pca, pca_parameters):
+def PostProcessDescriptors(descriptors, use_pca, pca_parameters=None):
   """Post-process descriptors.

   Args:
     descriptors: [N, input_dim] float tensor.
     use_pca: Whether to use PCA.
-    pca_parameters: DelfPcaParameters proto.
+    pca_parameters: Only used if `use_pca` is True. Dict containing PCA
+      parameter tensors, with keys 'mean', 'matrix', 'dim', 'use_whitening',
+      'variances'.

   Returns:
     final_descriptors: [N, output_dim] float tensor with descriptors after
@@ -349,25 +350,13 @@ def PostProcessDescriptors(descriptors, use_pca, pca_parameters):
       descriptors, axis=1, name='l2_normalization')

   if use_pca:
-    # Load PCA parameters.
-    pca_mean = tf.constant(
-        datum_io.ReadFromFile(pca_parameters.mean_path), dtype=tf.float32)
-    pca_matrix = tf.constant(
-        datum_io.ReadFromFile(pca_parameters.projection_matrix_path),
-        dtype=tf.float32)
-    pca_dim = pca_parameters.pca_dim
-    pca_variances = None
-    if pca_parameters.use_whitening:
-      pca_variances = tf.squeeze(
-          tf.constant(
-              datum_io.ReadFromFile(pca_parameters.pca_variances_path),
-              dtype=tf.float32))
-
     # Apply PCA, and whitening if desired.
-    final_descriptors = ApplyPcaAndWhitening(final_descriptors, pca_matrix,
-                                             pca_mean, pca_dim,
-                                             pca_parameters.use_whitening,
-                                             pca_variances)
+    final_descriptors = ApplyPcaAndWhitening(final_descriptors,
+                                             pca_parameters['matrix'],
+                                             pca_parameters['mean'],
+                                             pca_parameters['dim'],
+                                             pca_parameters['use_whitening'],
+                                             pca_parameters['variances'])

   # Re-normalize.
   final_descriptors = tf.nn.l2_normalize(
@@ -376,7 +365,7 @@ def PostProcessDescriptors(descriptors, use_pca, pca_parameters):
   return final_descriptors


-def DelfFeaturePostProcessing(boxes, descriptors, config):
+def DelfFeaturePostProcessing(boxes, descriptors, use_pca, pca_parameters=None):
   """Extract DELF features from input image.

   Args:
@@ -384,7 +373,10 @@ def DelfFeaturePostProcessing(boxes, descriptors, config):
     the number of final feature points which pass through keypoint selection
     and NMS steps.
     descriptors: [N, input_dim] float tensor.
-    config: DelfConfig proto with DELF extraction options.
+    use_pca: Whether to use PCA.
+    pca_parameters: Only used if `use_pca` is True. Dict containing PCA
+      parameter tensors, with keys 'mean', 'matrix', 'dim', 'use_whitening',
+      'variances'.

   Returns:
     locations: [N, 2] float tensor which denotes the selected keypoint
@@ -395,8 +387,7 @@ def DelfFeaturePostProcessing(boxes, descriptors, config):
   # Get center of descriptor boxes, corresponding to feature locations.
   locations = CalculateKeypointCenters(boxes)

-  final_descriptors = PostProcessDescriptors(
-      descriptors, config.delf_local_config.use_pca,
-      config.delf_local_config.pca_parameters)
+  final_descriptors = PostProcessDescriptors(descriptors, use_pca,
+                                             pca_parameters)

   return locations, final_descriptors
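With the new signatures, callers pre-load PCA parameters into a dict of tensors, exactly as MakeExtractor now does. An illustrative dict, with placeholder paths and an example dimensionality:

import tensorflow as tf

from delf import datum_io

pca_parameters = {
    'mean': tf.constant(
        datum_io.ReadFromFile('/tmp/pca_mean.datum'), dtype=tf.float32),
    'matrix': tf.constant(
        datum_io.ReadFromFile('/tmp/pca_proj_matrix.datum'), dtype=tf.float32),
    'dim': 40,  # example output dimensionality
    'use_whitening': False,
    'variances': None,  # required key; a tensor when use_whitening is True
}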
@@ -345,7 +345,10 @@ def _build_train_and_validation_splits(image_paths, file_ids, labels,
   # Create subsets of image attributes by label, shuffle them separately and
   # split each subset into TRAIN and VALIDATION splits based on the size of the
   # validation split.
-  splits = {}
+  splits = {
+      _VALIDATION_SPLIT: [],
+      _TRAIN_SPLIT: []
+  }
   rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(seed)))
   for label, indexes in image_attrs_idx_by_label.items():
     # Create the subset for the current label.
@@ -355,23 +358,18 @@ def _build_train_and_validation_splits(image_paths, file_ids, labels,
     columns_indices = np.arange(images_per_label)
     rs.shuffle(columns_indices)
     image_attrs_label = image_attrs_label[:, columns_indices]

-    # Split the current label subset into TRAIN and VALIDATION splits.
+    # Split the current label subset into TRAIN and VALIDATION splits and add
+    # each split to the list of all splits.
     cutoff_idx = max(1, int(validation_split_size * images_per_label))
-    validation_split = image_attrs_label[:, 0 : cutoff_idx]
-    train_split = image_attrs_label[:, cutoff_idx : ]
-    # Merge the splits of the current subset with the splits of other labels.
-    splits[_VALIDATION_SPLIT] = (
-        np.concatenate((splits[_VALIDATION_SPLIT], validation_split), axis=1)
-        if _VALIDATION_SPLIT in splits else validation_split)
-    splits[_TRAIN_SPLIT] = (
-        np.concatenate((splits[_TRAIN_SPLIT], train_split), axis=1)
-        if _TRAIN_SPLIT in splits else train_split)
+    splits[_VALIDATION_SPLIT].append(image_attrs_label[:, 0 : cutoff_idx])
+    splits[_TRAIN_SPLIT].append(image_attrs_label[:, cutoff_idx : ])

+  validation_split = np.concatenate(splits[_VALIDATION_SPLIT], axis=1)
+  train_split = np.concatenate(splits[_TRAIN_SPLIT], axis=1)

   # Unstack the image attribute arrays in the TRAIN and VALIDATION splits and
   # convert them back to lists. Convert labels back to 'int' from 'str'
   # following the explicit type change from 'str' to 'int' for stacking.
-  validation_split = splits[_VALIDATION_SPLIT]
-  train_split = splits[_TRAIN_SPLIT]

   return (
       {
           _IMAGE_PATHS_KEY: validation_split[0, :].tolist(),
......
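The split-building rewrite above replaces repeated np.concatenate inside the per-label loop, which re-copies the accumulated array on every iteration, with list appends and a single concatenate at the end. The pattern, with toy arrays:

import numpy as np

chunks = []
for label in range(3):
  # Each per-label block: 3 attribute rows by a variable number of columns.
  chunks.append(np.full((3, label + 2), label))

# One concatenate is O(total size) instead of quadratic accumulated copying.
merged = np.concatenate(chunks, axis=1)
print(merged.shape)  # (3, 9)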