ModelZoo / ResNet50_tensorflow / Commits / 78c43ef1

Commit 78c43ef1, authored Jul 26, 2021 by Gunho Park

    Merge branch 'master' of https://github.com/tensorflow/models

Parents: 67cfc95b, e3c7e300
Changes: 227

Showing 20 changed files with 1428 additions and 40 deletions (+1428, -40)
Changed files (20):

  official/vision/beta/serving/video_classification_test.py  (+114, -0)
  official/vision/beta/tasks/image_classification.py  (+2, -6)
  official/vision/beta/tasks/retinanet.py  (+2, -6)
  official/vision/beta/tasks/semantic_segmentation.py  (+2, -6)
  official/vision/beta/train.py  (+1, -0)
  official/vision/beta/train_spatial_partitioning.py  (+22, -6)
  official/vision/image_classification/README.md  (+3, -0)
  official/vision/keras_cv/ops/iou_similarity.py  (+3, -0)
  orbit/actions/export_saved_model.py  (+18, -16)
  orbit/actions/export_saved_model_test.py  (+17, -0)
  research/delf/delf/python/datasets/generic_dataset.py  (+81, -0)
  research/delf/delf/python/datasets/generic_dataset_test.py  (+60, -0)
  research/delf/delf/python/datasets/sfm120k/__init__.py  (+23, -0)
  research/delf/delf/python/datasets/sfm120k/dataset_download.py  (+103, -0)
  research/delf/delf/python/datasets/sfm120k/sfm120k.py  (+143, -0)
  research/delf/delf/python/datasets/sfm120k/sfm120k_test.py  (+37, -0)
  research/delf/delf/python/datasets/tuples_dataset.py  (+328, -0)
  research/delf/delf/python/datasets/tuples_dataset_test.py  (+88, -0)
  research/delf/delf/python/training/global_features/__init__.py  (+19, -0)
  research/delf/delf/python/training/global_features/train.py  (+362, -0)

official/vision/beta/serving/video_classification_test.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# import io
import os
import random

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.serving import video_classification


class VideoClassificationTest(tf.test.TestCase, parameterized.TestCase):

  def _get_classification_module(self):
    params = exp_factory.get_exp_config('video_classification_ucf101')
    params.task.train_data.feature_shape = (8, 64, 64, 3)
    params.task.validation_data.feature_shape = (8, 64, 64, 3)
    params.task.model.backbone.resnet_3d.model_id = 50
    classification_module = video_classification.VideoClassificationModule(
        params, batch_size=1, input_image_size=[8, 64, 64])
    return classification_module

  def _export_from_module(self, module, input_type, save_directory):
    signatures = module.get_inference_signatures(
        {input_type: 'serving_default'})
    tf.saved_model.save(module, save_directory, signatures=signatures)

  def _get_dummy_input(self, input_type, module=None):
    """Get dummy input for the given input type."""
    if input_type == 'image_tensor':
      images = np.random.randint(
          low=0, high=255, size=(1, 8, 64, 64, 3), dtype=np.uint8)
      # images = np.zeros((1, 8, 64, 64, 3), dtype=np.uint8)
      return images, images
    elif input_type == 'tf_example':
      example = tfexample_utils.make_video_test_example(
          image_shape=(64, 64, 3),
          audio_shape=(20, 128),
          label=random.randint(0, 100)).SerializeToString()
      images = tf.nest.map_structure(
          tf.stop_gradient,
          tf.map_fn(
              module._decode_tf_example,
              elems=tf.constant([example]),
              fn_output_signature={
                  video_classification.video_input.IMAGE_KEY: tf.string,
              }))
      images = images[video_classification.video_input.IMAGE_KEY]
      return [example], images
    else:
      raise ValueError(f'{input_type}')

  @parameterized.parameters(
      {'input_type': 'image_tensor'},
      {'input_type': 'tf_example'},
  )
  def test_export(self, input_type):
    tmp_dir = self.get_temp_dir()
    module = self._get_classification_module()

    self._export_from_module(module, input_type, tmp_dir)

    self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
    self.assertTrue(
        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
    self.assertTrue(
        os.path.exists(
            os.path.join(tmp_dir, 'variables',
                         'variables.data-00000-of-00001')))

    imported = tf.saved_model.load(tmp_dir)
    classification_fn = imported.signatures['serving_default']

    images, images_tensor = self._get_dummy_input(input_type, module)
    processed_images = tf.nest.map_structure(
        tf.stop_gradient,
        tf.map_fn(
            module._preprocess_image,
            elems=images_tensor,
            fn_output_signature={
                'image': tf.float32,
            }))
    expected_logits = module.model(processed_images, training=False)
    expected_prob = tf.nn.softmax(expected_logits)
    out = classification_fn(tf.constant(images))

    # The imported model should contain any trackable attrs that the original
    # model had.
    self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
    self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())


if __name__ == '__main__':
  tf.test.main()
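
For reference, a minimal serving-side sketch of how a model exported by this module could be loaded and queried with the 'image_tensor' signature; the export directory is a hypothetical path and the clip shape mirrors the test configuration above:

import numpy as np
import tensorflow as tf

export_dir = '/tmp/video_classification_export'  # hypothetical path
imported = tf.saved_model.load(export_dir)
classification_fn = imported.signatures['serving_default']

# One clip of 8 frames at 64x64x3, uint8, as in the test above.
clip = np.random.randint(0, 255, size=(1, 8, 64, 64, 3), dtype=np.uint8)
outputs = classification_fn(tf.constant(clip))
print(outputs['probs'].shape)  # per-class probabilities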

official/vision/beta/tasks/image_classification.py

...
@@ -24,7 +24,7 @@ from official.modeling import tf_utils
 from official.vision.beta.configs import image_classification as exp_cfg
 from official.vision.beta.dataloaders import classification_input
 from official.vision.beta.dataloaders import input_reader_factory
-from official.vision.beta.dataloaders import tfds_classification_decoders
+from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.modeling import factory
...
@@ -89,11 +89,7 @@ class ImageClassificationTask(base_task.Task):
     is_multilabel = self.task_config.train_data.is_multilabel

     if params.tfds_name:
-      if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP:
-        decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[
-            params.tfds_name]()
-      else:
-        raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
+      decoder = tfds_factory.get_classification_decoder(params.tfds_name)
     else:
       decoder = classification_input.Decoder(
           image_field_key=image_field_key, label_field_key=label_field_key,
...

official/vision/beta/tasks/retinanet.py

...
@@ -25,7 +25,7 @@ from official.vision.beta.configs import retinanet as exp_cfg
 from official.vision.beta.dataloaders import input_reader_factory
 from official.vision.beta.dataloaders import retinanet_input
 from official.vision.beta.dataloaders import tf_example_decoder
-from official.vision.beta.dataloaders import tfds_detection_decoders
+from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.dataloaders import tf_example_label_map_decoder
 from official.vision.beta.evaluation import coco_evaluator
 from official.vision.beta.modeling import factory
...
@@ -90,11 +90,7 @@ class RetinaNetTask(base_task.Task):
     """Build input dataset."""

     if params.tfds_name:
-      if params.tfds_name in tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP:
-        decoder = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP[
-            params.tfds_name]()
-      else:
-        raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
+      decoder = tfds_factory.get_detection_decoder(params.tfds_name)
     else:
       decoder_cfg = params.decoder.get()
       if params.decoder.type == 'simple_decoder':
...

official/vision/beta/tasks/semantic_segmentation.py

...
@@ -23,7 +23,7 @@ from official.core import task_factory
 from official.vision.beta.configs import semantic_segmentation as exp_cfg
 from official.vision.beta.dataloaders import input_reader_factory
 from official.vision.beta.dataloaders import segmentation_input
-from official.vision.beta.dataloaders import tfds_segmentation_decoders
+from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.evaluation import segmentation_metrics
 from official.vision.beta.losses import segmentation_losses
 from official.vision.beta.modeling import factory
...
@@ -87,11 +87,7 @@ class SemanticSegmentationTask(base_task.Task):
     ignore_label = self.task_config.losses.ignore_label

     if params.tfds_name:
-      if params.tfds_name in tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP:
-        decoder = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP[
-            params.tfds_name]()
-      else:
-        raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
+      decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
     else:
       decoder = segmentation_input.Decoder()
...

official/vision/beta/train.py

...
@@ -66,4 +66,5 @@ def main(_):

 if __name__ == '__main__':
   tfm_flags.define_flags()
+  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
   app.run(main)

official/vision/beta/train_spatial_partitioning.py

...
@@ -14,6 +14,7 @@
 # Lint as: python3
 """TensorFlow Model Garden Vision training driver with spatial partitioning."""
+from typing import Sequence

 from absl import app
 from absl import flags
...
@@ -33,19 +34,34 @@ from official.modeling import performance
 FLAGS = flags.FLAGS


-def get_computation_shape_for_model_parallelism(input_partition_dims):
-  """Return computation shape to be used for TPUStrategy spatial partition."""
+def get_computation_shape_for_model_parallelism(
+    input_partition_dims: Sequence[int]) -> Sequence[int]:
+  """Returns computation shape to be used for TPUStrategy spatial partition.
+
+  Args:
+    input_partition_dims: The number of partitions along each dimension.
+
+  Returns:
+    A list of integers specifying the computation shape.
+
+  Raises:
+    ValueError: If the number of logical devices is not supported.
+  """
   num_logical_devices = np.prod(input_partition_dims)

   if num_logical_devices == 1:
     return [1, 1, 1, 1]
-  if num_logical_devices == 2:
+  elif num_logical_devices == 2:
     return [1, 1, 1, 2]
-  if num_logical_devices == 4:
+  elif num_logical_devices == 4:
     return [1, 2, 1, 2]
-  if num_logical_devices == 8:
+  elif num_logical_devices == 8:
     return [2, 2, 1, 2]
-  if num_logical_devices == 16:
+  elif num_logical_devices == 16:
     return [4, 2, 1, 2]
+  else:
+    raise ValueError(
+        'The number of logical devices %d is not supported. Supported numbers '
+        'are 1, 2, 4, 8, 16' % num_logical_devices)


 def create_distribution_strategy(distribution_strategy,
...
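
As a quick illustration of the mapping above: the helper takes the product of the partition dims and looks up a fixed 4-D TPU computation shape. A standalone sketch of the same lookup (the numbers are copied from the diff; this mirrors, rather than imports, the Model Garden function):

import numpy as np

def computation_shape(input_partition_dims):
  # Mirrors get_computation_shape_for_model_parallelism in the diff above.
  shapes = {1: [1, 1, 1, 1], 2: [1, 1, 1, 2], 4: [1, 2, 1, 2],
            8: [2, 2, 1, 2], 16: [4, 2, 1, 2]}
  num_logical_devices = int(np.prod(input_partition_dims))
  if num_logical_devices not in shapes:
    raise ValueError(
        'Unsupported number of logical devices: %d' % num_logical_devices)
  return shapes[num_logical_devices]

print(computation_shape([1, 2, 2, 1]))  # [1, 2, 1, 2] for 4 logical devices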

official/vision/image_classification/README.md

 # Image Classification

+**Warning:** the features in the `image_classification/` folder have been fully
+integrated into vision/beta. Please use the [new code base](../beta/README.md).
+
 This folder contains TF 2.0 model examples for image classification:

 * [MNIST](#mnist)
...

official/vision/keras_cv/ops/iou_similarity.py

...
@@ -132,6 +132,9 @@ class IouSimilarity:
     Output shape:
       [M, N], or [B, M, N]
     """
+    boxes_1 = tf.cast(boxes_1, tf.float32)
+    boxes_2 = tf.cast(boxes_2, tf.float32)
+
     boxes_1_rank = len(boxes_1.shape)
     boxes_2_rank = len(boxes_2.shape)
     if boxes_1_rank < 2 or boxes_1_rank > 3:
...

orbit/actions/export_saved_model.py

...
@@ -14,24 +14,32 @@
 """Provides the `ExportSavedModel` action and associated helper classes."""

+import re
+
 from typing import Callable, Optional

 import tensorflow as tf


+def _id_key(filename):
+  _, id_num = filename.rsplit('-', maxsplit=1)
+  return int(id_num)
+
+
+def _find_managed_files(base_name):
+  r"""Returns all files matching '{base_name}-\d+', in sorted order."""
+  managed_file_regex = re.compile(rf'{re.escape(base_name)}-\d+$')
+  filenames = tf.io.gfile.glob(f'{base_name}-*')
+  filenames = filter(managed_file_regex.match, filenames)
+  return sorted(filenames, key=_id_key)
+
+
 class _CounterIdFn:
   """Implements a counter-based ID function for `ExportFileManager`."""

   def __init__(self, base_name: str):
-    filenames = tf.io.gfile.glob(f'{base_name}-*')
-    max_counter = -1
-    for filename in filenames:
-      try:
-        _, file_number = filename.rsplit('-', maxsplit=1)
-        max_counter = max(max_counter, int(file_number))
-      except ValueError:
-        continue
-    self.value = max_counter + 1
+    managed_files = _find_managed_files(base_name)
+    self.value = _id_key(managed_files[-1]) + 1 if managed_files else 0

   def __call__(self):
     output = self.value
...
@@ -82,13 +90,7 @@ class ExportFileManager:
       `ExportFileManager` instance, sorted in increasing integer order of the
       IDs returned by `next_id_fn`.
     """
-
-    def id_key(name):
-      _, id_num = name.rsplit('-', maxsplit=1)
-      return int(id_num)
-
-    filenames = tf.io.gfile.glob(f'{self._base_name}-*')
-    return sorted(filenames, key=id_key)
+    return _find_managed_files(self._base_name)

   def clean_up(self):
     """Cleans up old files matching `{base_name}-*`.
...
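
A small standalone sketch of what the new `_find_managed_files` filtering does; the directory and file names here are made up for illustration, and the filesystem calls are omitted so the snippet runs on plain lists:

import re

base_name = '/tmp/export/basename'  # hypothetical base path
managed_file_regex = re.compile(rf'{re.escape(base_name)}-\d+$')

candidates = [f'{base_name}-5', f'{base_name}-10-suffix', f'{base_name}-1000']
managed = sorted((f for f in candidates if managed_file_regex.match(f)),
                 key=lambda name: int(name.rsplit('-', maxsplit=1)[1]))
# ['/tmp/export/basename-5', '/tmp/export/basename-1000'];
# 'basename-10-suffix' is excluded because its suffix is not purely numeric.
print(managed)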

orbit/actions/export_saved_model_test.py

...
@@ -105,6 +105,23 @@ class ExportSavedModelTest(tf.test.TestCase):
         _id_sorted_file_base_names(directory.full_path),
         ['basename-200', 'basename-1000'])

+  def test_export_file_manager_managed_files(self):
+    directory = self.create_tempdir()
+    directory.create_file('basename-5')
+    directory.create_file('basename-10')
+    directory.create_file('basename-50')
+    directory.create_file('basename-1000')
+    directory.create_file('basename-9')
+    directory.create_file('basename-10-suffix')
+    base_name = os.path.join(directory.full_path, 'basename')
+    manager = actions.ExportFileManager(base_name, max_to_keep=3)
+    self.assertLen(manager.managed_files, 5)
+    self.assertEqual(manager.next_name(), f'{base_name}-1001')
+    manager.clean_up()
+    self.assertEqual(
+        manager.managed_files,
+        [f'{base_name}-10', f'{base_name}-50', f'{base_name}-1000'])
+
   def test_export_saved_model(self):
     directory = self.create_tempdir()
     base_name = os.path.join(directory.full_path, 'basename')
...

research/delf/delf/python/datasets/generic_dataset.py  (new file, mode 100644)

# Lint as: python3
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for generic image dataset creation."""
import os

from delf.python.datasets import utils


class ImagesFromList():
  """A generic data loader that loads images from a list.

  Supports images of different sizes.
  """

  def __init__(self, root, image_paths, imsize=None, bounding_boxes=None,
               loader=utils.default_loader):
    """ImagesFromList object initialization.

    Args:
      root: String, root directory path.
      image_paths: List, relative image paths as strings.
      imsize: Integer, defines the maximum size of longer image side.
      bounding_boxes: List of (x1,y1,x2,y2) tuples to crop the query images.
      loader: Callable, a function to load an image given its path.

    Raises:
      ValueError: Raised if `image_paths` list is empty.
    """
    # List of the full image filenames.
    images_filenames = [os.path.join(root, image_path) for image_path in
                        image_paths]

    if not images_filenames:
      raise ValueError("Dataset contains 0 images.")

    self.root = root
    self.images = image_paths
    self.imsize = imsize
    self.images_filenames = images_filenames
    self.bounding_boxes = bounding_boxes
    self.loader = loader

  def __getitem__(self, index):
    """Called to load an image at the given `index`.

    Args:
      index: Integer, image index.

    Returns:
      image: Tensor, loaded image.
    """
    path = self.images_filenames[index]
    if self.bounding_boxes is not None:
      img = self.loader(path, self.imsize, self.bounding_boxes[index])
    else:
      img = self.loader(path, self.imsize)
    return img

  def __len__(self):
    """Implements the built-in function len().

    Returns:
      len: Number of images in the dataset.
    """
    return len(self.images_filenames)

research/delf/delf/python/datasets/generic_dataset_test.py  (new file, mode 100644)

# Lint as: python3
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for generic dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import flags
import numpy as np
from PIL import Image
import tensorflow as tf

from delf.python.datasets import generic_dataset

FLAGS = flags.FLAGS


class GenericDatasetTest(tf.test.TestCase):
  """Test functions for generic dataset."""

  def testGenericDataset(self):
    """Tests loading dummy images from list."""
    # Number of images to be created.
    n = 2
    image_names = []

    # Create and save `n` dummy images.
    for i in range(n):
      dummy_image = np.random.rand(1024, 750, 3) * 255
      img_out = Image.fromarray(dummy_image.astype('uint8')).convert('RGB')
      filename = os.path.join(FLAGS.test_tmpdir,
                              'test_image_{}.jpg'.format(i))
      img_out.save(filename)
      image_names.append('test_image_{}.jpg'.format(i))

    data = generic_dataset.ImagesFromList(root=FLAGS.test_tmpdir,
                                          image_paths=image_names,
                                          imsize=1024)
    self.assertLen(data, n)


if __name__ == '__main__':
  tf.test.main()

research/delf/delf/python/datasets/sfm120k/__init__.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module exposing Sfm120k dataset for training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
from delf.python.datasets.sfm120k import sfm120k
# pylint: enable=unused-import

research/delf/delf/python/datasets/sfm120k/dataset_download.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Structure-from-Motion dataset (Sfm120k) download function."""
import os

import tensorflow as tf


def download_train(data_dir):
  """Checks, and, if required, downloads the necessary files for the training.

  Checks if the data necessary for running the example training script exist.
  If not, it downloads it in the following folder structure:
    DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db
    files.
    DATA_ROOT/train/retrieval-SfM-30k/ : folder with rsfm30k images and db
    files.
  """
  # Create data folder if does not exist.
  if not tf.io.gfile.exists(data_dir):
    tf.io.gfile.mkdir(data_dir)

  # Create datasets folder if does not exist.
  datasets_dir = os.path.join(data_dir, 'train')
  if not tf.io.gfile.exists(datasets_dir):
    tf.io.gfile.mkdir(datasets_dir)

  # Download folder train/retrieval-SfM-120k/.
  src_dir = 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/train/ims'
  dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
  download_file = 'ims.tar.gz'
  if not tf.io.gfile.exists(dst_dir):
    src_file = os.path.join(src_dir, download_file)
    dst_file = os.path.join(dst_dir, download_file)
    print('>> Image directory does not exist. Creating: {}'.format(dst_dir))
    tf.io.gfile.makedirs(dst_dir)
    print('>> Downloading ims.tar.gz...')
    os.system('wget {} -O {}'.format(src_file, dst_file))
    print('>> Extracting {}...'.format(dst_file))
    os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir))
    print('>> Extracted, deleting {}...'.format(dst_file))
    os.system('rm {}'.format(dst_file))

  # Create symlink for train/retrieval-SfM-30k/.
  dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
  dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims')
  if not (tf.io.gfile.exists(dst_dir) or os.path.islink(dst_dir)):
    tf.io.gfile.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k'))
    os.system('ln -s {} {}'.format(dst_dir_old, dst_dir))
    print('>> Created symbolic link from retrieval-SfM-120k/ims to '
          'retrieval-SfM-30k/ims')

  # Download db files.
  src_dir = 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/train/dbs'
  datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k']
  for dataset in datasets:
    dst_dir = os.path.join(datasets_dir, dataset)
    if dataset == 'retrieval-SfM-120k':
      download_files = ['{}.pkl'.format(dataset),
                        '{}-whiten.pkl'.format(dataset)]
      download_eccv2020 = '{}-val-eccv2020.pkl'.format(dataset)
    elif dataset == 'retrieval-SfM-30k':
      download_files = ['{}-whiten.pkl'.format(dataset)]
      download_eccv2020 = None

    if not tf.io.gfile.exists(dst_dir):
      print('>> Dataset directory does not exist. Creating: {}'.format(
          dst_dir))
      tf.io.gfile.mkdir(dst_dir)

    for i in range(len(download_files)):
      src_file = os.path.join(src_dir, download_files[i])
      dst_file = os.path.join(dst_dir, download_files[i])
      if not os.path.isfile(dst_file):
        print('>> DB file {} does not exist. Downloading...'.format(
            download_files[i]))
        os.system('wget {} -O {}'.format(src_file, dst_file))

    if download_eccv2020:
      eccv2020_dst_file = os.path.join(dst_dir, download_eccv2020)
      if not os.path.isfile(eccv2020_dst_file):
        eccv2020_src_dir = \
          "http://ptak.felk.cvut.cz/personal/toliageo/share/how/dataset/"
        eccv2020_dst_file = os.path.join(dst_dir, download_eccv2020)
        eccv2020_src_file = os.path.join(eccv2020_src_dir, download_eccv2020)
        os.system('wget {} -O {}'.format(eccv2020_src_file,
                                         eccv2020_dst_file))
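
A hedged usage sketch for the download helper; the 'data' directory is an assumption, and the function shells out to wget/tar, so it needs network access and several GB of free space:

from delf.python.datasets.sfm120k import dataset_download

# Creates data/train/retrieval-SfM-120k/ (images + .pkl db files) and a
# retrieval-SfM-30k/ims symlink, as described in the docstring above.
dataset_download.download_train('data')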

research/delf/delf/python/datasets/sfm120k/sfm120k.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Structure-from-Motion dataset (Sfm120k) module.
[1] From Single Image Query to Detailed 3D Reconstruction.
Johannes L. Schonberger, Filip Radenovic, Ondrej Chum, Jan-Michael Frahm.
The related paper can be found at: https://ieeexplore.ieee.org/document/7299148.
"""
import os
import pickle

import tensorflow as tf

from delf.python.datasets import tuples_dataset
from delf.python.datasets import utils


def id2filename(image_id, prefix):
  """Creates a training image path out of its id name.

  Used for the image mapping in the Sfm120k dataset.

  Args:
    image_id: String, image id.
    prefix: String, root directory where images are saved.

  Returns:
    filename: String, full image filename.
  """
  if prefix:
    return os.path.join(prefix, image_id[-2:], image_id[-4:-2],
                        image_id[-6:-4], image_id)
  else:
    return os.path.join(image_id[-2:], image_id[-4:-2], image_id[-6:-4],
                        image_id)


class _Sfm120k(tuples_dataset.TuplesDataset):
  """Structure-from-Motion (Sfm120k) dataset instance.

  The dataset contains the image names lists for training and validation,
  the cluster ID (3D model ID) for each image and indices forming
  query-positive pairs of images. The images are loaded per epoch and resized
  on the fly to the desired dimensionality.
  """

  def __init__(self, mode, data_root, imsize=None, num_negatives=5,
               num_queries=2000, pool_size=20000,
               loader=utils.default_loader, eccv2020=False):
    """Structure-from-Motion (Sfm120k) dataset initialization.

    Args:
      mode: Either 'train' or 'val'.
      data_root: Path to the root directory of the dataset.
      imsize: Integer, defines the maximum size of longer image side.
      num_negatives: Integer, number of negative images per one query.
      num_queries: Integer, number of query images.
      pool_size: Integer, size of the negative image pool, from where the
        hard-negative images are chosen.
      loader: Callable, a function to load an image given its path.
      eccv2020: Bool, whether to use a new validation dataset used with ECCV
        2020 paper (https://arxiv.org/abs/2007.13172).

    Raises:
      ValueError: Raised if `mode` is not one of 'train' or 'val'.
    """
    if mode not in ['train', 'val']:
      raise ValueError(
          "`mode` argument should be either 'train' or 'val', passed as a "
          "String.")

    # Setting up the paths for the dataset.
    if eccv2020:
      name = "retrieval-SfM-120k-val-eccv2020"
    else:
      name = "retrieval-SfM-120k"
    db_root = os.path.join(data_root, 'train/retrieval-SfM-120k')
    ims_root = os.path.join(db_root, 'ims/')

    # Loading the dataset db file.
    db_filename = os.path.join(db_root, '{}.pkl'.format(name))
    with tf.io.gfile.GFile(db_filename, 'rb') as f:
      db = pickle.load(f)[mode]

    # Setting full paths for the dataset images.
    self.images = [id2filename(img_name, None) for img_name in db['cids']]

    # Initializing tuples dataset.
    super().__init__(name, mode, db_root, imsize, num_negatives, num_queries,
                     pool_size, loader, ims_root)

  def Sfm120kInfo(self):
    """Metadata for the Sfm120k dataset.

    The dataset contains the image names lists for training and
    validation, the cluster ID (3D model ID) for each image and indices
    forming query-positive pairs of images. The images are loaded per epoch
    and resized on the fly to the desired dimensionality.

    Returns:
      info: dictionary with the dataset parameters.
    """
    info = {'train': {'clusters': 91642, 'pidxs': 181697, 'qidxs': 181697},
            'val': {'clusters': 6403, 'pidxs': 1691, 'qidxs': 1691}}
    return info


def CreateDataset(mode, data_root, imsize=None, num_negatives=5,
                  num_queries=2000, pool_size=20000,
                  loader=utils.default_loader, eccv2020=False):
  '''Creates Structure-from-Motion (Sfm120k) dataset.

  Args:
    mode: String, either 'train' or 'val'.
    data_root: Path to the root directory of the dataset.
    imsize: Integer, defines the maximum size of longer image side.
    num_negatives: Integer, number of negative images per one query.
    num_queries: Integer, number of query images.
    pool_size: Integer, size of the negative image pool, from where the
      hard-negative images are chosen.
    loader: Callable, a function to load an image given its path.
    eccv2020: Bool, whether to use a new validation dataset used with ECCV
      2020 paper (https://arxiv.org/abs/2007.13172).

  Returns:
    sfm120k: Sfm120k dataset instance.
  '''
  return _Sfm120k(mode, data_root, imsize, num_negatives, num_queries,
                  pool_size, loader, eccv2020)

research/delf/delf/python/datasets/sfm120k/sfm120k_test.py  (new file, mode 100644)

# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Sfm120k dataset module."""
import tensorflow as tf

from delf.python.datasets.sfm120k import sfm120k


class Sfm120kTest(tf.test.TestCase):
  """Tests for Sfm120k dataset module."""

  def testId2Filename(self):
    """Tests conversion of image id to full path mapping."""
    image_id = "29fdc243aeb939388cfdf2d081dc080e"
    prefix = "train/retrieval-SfM-120k/ims/"
    path = sfm120k.id2filename(image_id, prefix)
    expected_path = "train/retrieval-SfM-120k/ims/0e/08/dc" \
                    "/29fdc243aeb939388cfdf2d081dc080e"
    self.assertEqual(path, expected_path)


if __name__ == '__main__':
  tf.test.main()

research/delf/delf/python/datasets/tuples_dataset.py  (new file, mode 100644)

# Lint as: python3
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tuple dataset module.
Based on the Radenovic et al. ECCV16: CNN image retrieval learns from BoW.
For more information refer to https://arxiv.org/abs/1604.02426.
"""
import os
import pickle

import numpy as np
import tensorflow as tf

from delf.python.datasets import utils as image_loading_utils
from delf.python.training import global_features_utils
from delf.python.training.model import global_model


class TuplesDataset():
  """Data loader that loads training and validation tuples.

  After initialization, the function create_epoch_tuples() should be called to
  create the dataset tuples. After that, the dataset can be iterated through
  using next() function.
  Tuples are based on Radenovic et al. ECCV16 work: CNN image retrieval
  learns from BoW. For more information refer to
  https://arxiv.org/abs/1604.02426.
  """

  def __init__(self, name, mode, data_root, imsize=None, num_negatives=5,
               num_queries=2000, pool_size=20000,
               loader=image_loading_utils.default_loader, ims_root=None):
    """TuplesDataset object initialization.

    Args:
      name: String, dataset name. I.e. 'retrieval-sfm-120k'.
      mode: 'train' or 'val' for training and validation parts of dataset.
      data_root: Path to the root directory of the dataset.
      imsize: Integer, defines the maximum size of longer image side transform.
      num_negatives: Integer, number of negative images for a query image in a
        training tuple.
      num_queries: Integer, number of query images to be processed in one
        epoch.
      pool_size: Integer, size of the negative image pool, from where the
        hard-negative images are re-mined.
      loader: Callable, a function to load an image given its path.
      ims_root: String, image root directory.

    Raises:
      ValueError: If mode is not either 'train' or 'val'.
    """
    if mode not in ['train', 'val']:
      raise ValueError(
          "`mode` argument should be either 'train' or 'val', passed as a "
          "String.")

    # Loading db.
    db_filename = os.path.join(data_root, '{}.pkl'.format(name))
    with tf.io.gfile.GFile(db_filename, 'rb') as f:
      db = pickle.load(f)[mode]

    # Initializing tuples dataset.
    self._ims_root = data_root if ims_root is None else ims_root
    self._name = name
    self._mode = mode
    self._imsize = imsize
    self._clusters = db['cluster']
    self._query_pool = db['qidxs']
    self._positive_pool = db['pidxs']

    if not hasattr(self, 'images'):
      self.images = db['ids']

    # Size of training subset for an epoch.
    self._num_negatives = num_negatives
    self._num_queries = min(num_queries, len(self._query_pool))
    self._pool_size = min(pool_size, len(self.images))
    self._qidxs = None
    self._pidxs = None
    self._nidxs = None

    self._loader = loader
    self._print_freq = 10
    # Indexer for the iterator.
    self._n = 0

  def __iter__(self):
    """Function for making TupleDataset an iterator.

    Returns:
      iter: The iterator object itself (TupleDataset).
    """
    return self

  def __next__(self):
    """Function for making TupleDataset an iterator.

    Returns:
      next: The next item in the sequence (next dataset image tuple).
    """
    if self._n < len(self._qidxs):
      result = self.__getitem__(self._n)
      self._n += 1
      return result
    else:
      raise StopIteration

  def _img_names_to_full_path(self, image_list):
    """Converts list of image names to the list of full paths to the images.

    Args:
      image_list: Image names, either a list or a single image path.

    Returns:
      image_full_paths: List of full paths to the images.
    """
    if not isinstance(image_list, list):
      return os.path.join(self._ims_root, image_list)
    return [os.path.join(self._ims_root, img_name) for img_name in image_list]

  def __getitem__(self, index):
    """Called to load an image tuple at the given `index`.

    Args:
      index: Integer, index.

    Returns:
      output: Tuple [q,p,n1,...,nN, target], loaded 'train'/'val' tuple at
        index of qidxs. `q` is the query image tensor, `p` is the
        corresponding positive image tensor, `n1`,...,`nN` are the negatives
        associated with the query. `target` is a tensor (with the shape [2+N])
        of integer labels corresponding to the tuple list: query (-1),
        positive (1), negative (0).

    Raises:
      ValueError: Raised if the query indexes list `qidxs` is empty.
    """
    if self.__len__() == 0:
      raise ValueError(
          "List `qidxs` is empty. Run `dataset.create_epoch_tuples(net)` "
          "method to create subset for `train`/`val`.")

    output = []
    # Query image.
    output.append(self._loader(
        self._img_names_to_full_path(self.images[self._qidxs[index]]),
        self._imsize))
    # Positive image.
    output.append(self._loader(
        self._img_names_to_full_path(self.images[self._pidxs[index]]),
        self._imsize))
    # Negative images.
    for nidx in self._nidxs[index]:
      output.append(self._loader(
          self._img_names_to_full_path(self.images[nidx]),
          self._imsize))
    # Labels for the query (-1), positive (1), negative (0) images in the
    # tuple.
    target = tf.convert_to_tensor([-1, 1] + [0] * self._num_negatives)
    output.append(target)

    return tuple(output)

  def __len__(self):
    """Called to implement the built-in function len().

    Returns:
      len: Integer, number of query images.
    """
    if self._qidxs is None:
      return 0
    return len(self._qidxs)

  def __repr__(self):
    """Metadata for the TupleDataset.

    Returns:
      meta: String, containing TupleDataset meta.
    """
    fmt_str = self.__class__.__name__ + '\n'
    fmt_str += '\tName and mode: {} {}\n'.format(self._name, self._mode)
    fmt_str += '\tNumber of images: {}\n'.format(len(self.images))
    fmt_str += '\tNumber of training tuples: {}\n'.format(
        len(self._query_pool))
    fmt_str += '\tNumber of negatives per tuple: {}\n'.format(
        self._num_negatives)
    fmt_str += '\tNumber of tuples processed in an epoch: {}\n'.format(
        self._num_queries)
    fmt_str += '\tPool size for negative remining: {}\n'.format(
        self._pool_size)
    return fmt_str

  def create_epoch_tuples(self, net):
    """Creates epoch tuples with the hard-negative re-mining.

    Negative examples are selected from clusters different than the cluster
    of the query image, as the clusters are ideally non-overlaping. For
    every query image we choose hard-negatives, that is, non-matching images
    with the most similar descriptor. Hard-negatives depend on the current
    CNN parameters. K-nearest neighbors from all non-matching images are
    selected. Query images are selected randomly. Positives examples are
    fixed for the related query image during the whole training process.

    Args:
      net: Model, network to be used for negative re-mining.

    Raises:
      ValueError: If the pool_size is smaller than the number of negative
        images per tuple.

    Returns:
      avg_l2: Float, average negative L2-distance.
    """
    self._n = 0

    if self._num_negatives < self._pool_size:
      raise ValueError("Unable to create epoch tuples. Negative pool_size "
                       "should be larger than the number of negative images "
                       "per tuple.")

    global_features_utils.debug_and_log(
        '>> Creating tuples for an epoch of {}-{}...'.format(self._name,
                                                             self._mode),
        True)
    global_features_utils.debug_and_log(">> Used network: ", True)
    global_features_utils.debug_and_log(net.meta_repr(), True)

    ## Selecting queries.
    # Draw `num_queries` random queries for the tuples.
    idx_list = np.arange(len(self._query_pool))
    np.random.shuffle(idx_list)
    idxs2query_pool = idx_list[:self._num_queries]
    self._qidxs = [self._query_pool[i] for i in idxs2query_pool]

    ## Selecting positive pairs.
    # Positives examples are fixed for each query during the whole training
    # process.
    self._pidxs = [self._positive_pool[i] for i in idxs2query_pool]

    ## Selecting negative pairs.
    # If `num_negatives` = 0 create dummy nidxs.
    # Useful when only positives used for training.
    if self._num_negatives == 0:
      self._nidxs = [[] for _ in range(len(self._qidxs))]
      return 0

    # Draw pool_size random images for pool of negatives images.
    neg_idx_list = np.arange(len(self.images))
    np.random.shuffle(neg_idx_list)
    neg_images_idxs = neg_idx_list[:self._pool_size]

    global_features_utils.debug_and_log(
        '>> Extracting descriptors for query images...', debug=True)
    img_list = self._img_names_to_full_path(
        [self.images[i] for i in self._qidxs])
    qvecs = global_model.extract_global_descriptors_from_list(
        net,
        images=img_list,
        image_size=self._imsize,
        print_freq=self._print_freq)

    global_features_utils.debug_and_log(
        '>> Extracting descriptors for negative pool...', debug=True)
    poolvecs = global_model.extract_global_descriptors_from_list(
        net,
        images=self._img_names_to_full_path(
            [self.images[i] for i in neg_images_idxs]),
        image_size=self._imsize,
        print_freq=self._print_freq)

    global_features_utils.debug_and_log('>> Searching for hard negatives...',
                                        debug=True)
    # Compute dot product scores and ranks.
    scores = tf.linalg.matmul(poolvecs, qvecs, transpose_a=True)
    ranks = tf.argsort(scores, axis=0, direction='DESCENDING')

    sum_ndist = 0.
    n_ndist = 0.

    # Selection of negative examples.
    self._nidxs = []

    for q, qidx in enumerate(self._qidxs):
      # We are not using the query cluster, those images are potentially
      # positive.
      qcluster = self._clusters[qidx]
      clusters = [qcluster]
      nidxs = []
      rank = 0

      while len(nidxs) < self._num_negatives:
        if rank >= tf.shape(ranks)[0]:
          raise ValueError("Unable to create epoch tuples. Number of required "
                           "negative images is larger than the number of "
                           "clusters in the dataset.")
        potential = neg_images_idxs[ranks[rank, q]]
        # Take at most one image from the same cluster.
        if not self._clusters[potential] in clusters:
          nidxs.append(potential)
          clusters.append(self._clusters[potential])
          dist = tf.norm(qvecs[:, q] - poolvecs[:, ranks[rank, q]],
                         axis=0).numpy()
          sum_ndist += dist
          n_ndist += 1
        rank += 1

      self._nidxs.append(nidxs)

    global_features_utils.debug_and_log(
        '>> Average negative l2-distance: {:.2f}'.format(
            sum_ndist / n_ndist))

    # Return average negative L2-distance.
    return sum_ndist / n_ndist
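
A minimal epoch-loop sketch showing how TuplesDataset is meant to be driven; the dataset construction via sfm120k.CreateDataset and the GlobalFeatureNet arguments mirror other files in this commit, while the 'data' path and the local availability of the Sfm120k db files are assumptions:

from delf.python.datasets.sfm120k import sfm120k
from delf.python.training.model import global_model

dataset = sfm120k.CreateDataset(mode='train', data_root='data', imsize=1024,
                                num_negatives=5, num_queries=2000,
                                pool_size=20000)
net = global_model.GlobalFeatureNet(architecture='ResNet101', pooling='gem',
                                    whitening=False, pretrained=True)

avg_neg_l2 = dataset.create_epoch_tuples(net)  # hard-negative re-mining
for example in dataset:
  # example is (query, positive, n1, ..., nN, target), target = [-1, 1, 0, ...].
  query, positive = example[0], example[1]
  negatives, target = example[2:-1], example[-1]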

research/delf/delf/python/datasets/tuples_dataset_test.py  (new file, mode 100644)

# Lint as: python3
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"Tests for the tuples dataset module."
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import flags
import numpy as np
from PIL import Image
import tensorflow as tf
import pickle

from delf.python.datasets import tuples_dataset
from delf.python.training.model import global_model

FLAGS = flags.FLAGS


class TuplesDatasetTest(tf.test.TestCase):
  """Tests for tuples dataset module."""

  def testCreateEpochTuples(self):
    """Tests epoch tuple creation."""
    # Create a tuples dataset instance.
    name = 'test_dataset'
    num_queries = 1
    pool_size = 5
    num_negatives = 2

    # Create a ground truth .pkl file.
    gnd = {
        'train': {'ids': [str(i) + '.png' for i in range(2 * num_queries +
                                                         pool_size)],
                  'cluster': [0, 0, 1, 2, 3, 4, 5],
                  'qidxs': [0], 'pidxs': [1]}}
    gnd_name = name + '.pkl'
    with tf.io.gfile.GFile(os.path.join(FLAGS.test_tmpdir, gnd_name),
                           'wb') as gnd_file:
      pickle.dump(gnd, gnd_file)

    # Create random images for the dataset.
    for i in range(2 * num_queries + pool_size):
      dummy_image = np.random.rand(1024, 750, 3) * 255
      img_out = Image.fromarray(dummy_image.astype('uint8')).convert('RGB')
      filename = os.path.join(FLAGS.test_tmpdir, '{}.png'.format(i))
      img_out.save(filename)

    dataset = tuples_dataset.TuplesDataset(
        name=name,
        data_root=FLAGS.test_tmpdir,
        mode='train',
        imsize=1024,
        num_negatives=num_negatives,
        num_queries=num_queries,
        pool_size=pool_size)

    # Assert that initially no negative images are set.
    self.assertIsNone(dataset._nidxs)

    # Initialize a network for negative re-mining.
    model_params = {'architecture': 'ResNet101', 'pooling': 'gem',
                    'whitening': False, 'pretrained': True}
    model = global_model.GlobalFeatureNet(**model_params)

    avg_neg_distance = dataset.create_epoch_tuples(model)

    # Check that an appropriate number of negative images has been chosen per
    # query.
    self.assertAllEqual(tf.shape(dataset._nidxs),
                        [num_queries, num_negatives])


if __name__ == '__main__':
  tf.test.main()

research/delf/delf/python/training/global_features/__init__.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Global model training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

research/delf/delf/python/training/global_features/train.py  (new file, mode 100644)

# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training script for Global Features model."""
import
math
import
os
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow_addons
as
tfa
from
delf.python.datasets.sfm120k
import
dataset_download
from
delf.python.datasets.sfm120k
import
sfm120k
from
delf.python.training
import
global_features_utils
from
delf.python.training
import
tensorboard_utils
from
delf.python.training.global_features
import
train_utils
from
delf.python.training.losses
import
ranking_losses
from
delf.python.training.model
import
global_model
_LOSS_NAMES
=
[
'contrastive'
,
'triplet'
]
_MODEL_NAMES
=
global_features_utils
.
get_standard_keras_models
()
_OPTIMIZER_NAMES
=
[
'sgd'
,
'adam'
]
_POOL_NAMES
=
[
'mac'
,
'spoc'
,
'gem'
]
_PRECOMPUTE_WHITEN_NAMES
=
[
'retrieval-SfM-30k'
,
'retrieval-SfM-120k'
]
_TEST_DATASET_NAMES
=
[
'roxford5k'
,
'rparis6k'
]
_TRAINING_DATASET_NAMES
=
[
'retrieval-SfM-120k'
]
_VALIDATION_TYPES
=
[
'standard'
,
'eccv2020'
]
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_boolean
(
'debug'
,
False
,
'Debug mode.'
)
# Export directory, training and val datasets, test datasets.
flags
.
DEFINE_string
(
'data_root'
,
"data"
,
'Absolute path to the folder containing training data.'
)
flags
.
DEFINE_string
(
'directory'
,
"data"
,
'Destination where trained network should be saved.'
)
flags
.
DEFINE_enum
(
'training_dataset'
,
'retrieval-SfM-120k'
,
_TRAINING_DATASET_NAMES
,
'Training dataset: '
+
' | '
.
join
(
_TRAINING_DATASET_NAMES
)
+
'.'
)
flags
.
DEFINE_enum
(
'validation_type'
,
None
,
_VALIDATION_TYPES
,
'Type of the evaluation to use. Either `None`, `standard` '
'or `eccv2020`.'
)
flags
.
DEFINE_list
(
'test_datasets'
,
'roxford5k,rparis6k'
,
'Comma separated list of test datasets: '
+
' | '
.
join
(
_TEST_DATASET_NAMES
)
+
'.'
)
flags
.
DEFINE_enum
(
'precompute_whitening'
,
None
,
_PRECOMPUTE_WHITEN_NAMES
,
'Dataset used to learn whitening: '
+
' | '
.
join
(
_PRECOMPUTE_WHITEN_NAMES
)
+
'.'
)
flags
.
DEFINE_integer
(
'test_freq'
,
5
,
'Run test evaluation every N epochs.'
)
flags
.
DEFINE_list
(
'multiscale'
,
[
1.
],
'Use multiscale vectors for testing, '
+
' examples: 1 | 1,1/2**(1/2),1/2 | 1,2**(1/2),1/2**(1/2)]. '
'Pass as a string of comma separated values.'
)
# Network architecture and initialization options.
flags
.
DEFINE_enum
(
'arch'
,
'ResNet101'
,
_MODEL_NAMES
,
'Model architecture: '
+
' | '
.
join
(
_MODEL_NAMES
)
+
'.'
)
flags
.
DEFINE_enum
(
'pool'
,
'gem'
,
_POOL_NAMES
,
'Pooling options: '
+
' | '
.
join
(
_POOL_NAMES
)
+
'.'
)
flags
.
DEFINE_bool
(
'whitening'
,
False
,
'Whether to train model with learnable whitening ('
'linear layer) after the pooling.'
)
flags
.
DEFINE_bool
(
'pretrained'
,
True
,
'Whether to initialize model with random weights ('
'default: pretrained on imagenet).'
)
flags
.
DEFINE_enum
(
'loss'
,
'contrastive'
,
_LOSS_NAMES
,
'Training loss options: '
+
' | '
.
join
(
_LOSS_NAMES
)
+
'.'
)
flags
.
DEFINE_float
(
'loss_margin'
,
0.7
,
'Loss margin.'
)
# train/val options specific for image retrieval learning.
flags
.
DEFINE_integer
(
'image_size'
,
1024
,
'Maximum size of longer image side used for training.'
)
flags
.
DEFINE_integer
(
'neg_num'
,
5
,
'Number of negative images per train/val '
'tuple.'
)
flags
.
DEFINE_integer
(
'query_size'
,
2000
,
'Number of queries randomly drawn per one training epoch.'
)
flags
.
DEFINE_integer
(
'pool_size'
,
20000
,
'Size of the pool for hard negative mining.'
)
# Standard training/validation options.
flags
.
DEFINE_string
(
'gpu_id'
,
'0'
,
'GPU id used for training.'
)
flags
.
DEFINE_integer
(
'epochs'
,
100
,
'Number of total epochs to run.'
)
flags
.
DEFINE_integer
(
'batch_size'
,
5
,
'Number of (q,p,n1,...,nN) tuples in a mini-batch.'
)
flags
.
DEFINE_integer
(
'update_every'
,
1
,
'Update model weights every N batches, used to handle '
'relatively large batches, batch_size effectively '
'becomes update_every `x` batch_size.'
)
flags
.
DEFINE_enum
(
'optimizer'
,
'adam'
,
_OPTIMIZER_NAMES
,
'Optimizer options: '
+
' | '
.
join
(
_OPTIMIZER_NAMES
)
+
'.'
)
flags
.
DEFINE_float
(
'lr'
,
1e-6
,
'Initial learning rate.'
)
flags
.
DEFINE_float
(
'momentum'
,
0.9
,
'Momentum.'
)
flags
.
DEFINE_float
(
'weight_decay'
,
1e-6
,
'Weight decay.'
)
flags
.
DEFINE_bool
(
'resume'
,
False
,
'Whether to start from the latest checkpoint in the logdir.'
)
flags
.
DEFINE_bool
(
'launch_tensorboard'
,
False
,
'Whether to launch tensorboard.'
)
def main(argv):
  if len(argv) > 1:
    raise RuntimeError('Too many command-line arguments.')

  # Manually check if there are unknown test datasets and if the dataset
  # ground truth files are downloaded.
  for dataset in FLAGS.test_datasets:
    if dataset not in _TEST_DATASET_NAMES:
      raise ValueError('Unsupported or unknown test dataset: {}.'.format(
          dataset))

    test_data_config = os.path.join(FLAGS.data_root,
                                    'gnd_{}.pkl'.format(dataset))
    if not tf.io.gfile.exists(test_data_config):
      raise ValueError(
          '{} ground truth file at {} not found. Please download it '
          'according to the DELG instructions.'.format(dataset,
                                                       FLAGS.data_root))

  # Check if train dataset is downloaded and download it if not found.
  dataset_download.download_train(FLAGS.data_root)

  # Creating model export directory if it does not exist.
  model_directory = global_features_utils.create_model_directory(
      FLAGS.training_dataset, FLAGS.arch, FLAGS.pool, FLAGS.whitening,
      FLAGS.pretrained, FLAGS.loss, FLAGS.loss_margin, FLAGS.optimizer,
      FLAGS.lr, FLAGS.weight_decay, FLAGS.neg_num, FLAGS.query_size,
      FLAGS.pool_size, FLAGS.batch_size, FLAGS.update_every,
      FLAGS.image_size, FLAGS.directory)

  # Setting up logging directory, same as where the model is stored.
  logging.get_absl_handler().use_absl_log_file('absl_logging', model_directory)

  # Set cuda visible device.
  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_id
  global_features_utils.debug_and_log(
      '>> Num GPUs Available: {}'.format(
          len(tf.config.experimental.list_physical_devices('GPU'))),
      FLAGS.debug)

  # Set random seeds.
  tf.random.set_seed(0)
  np.random.seed(0)

  # Initialize the model.
  if FLAGS.pretrained:
    global_features_utils.debug_and_log(
        '>> Using pre-trained model \'{}\''.format(FLAGS.arch))
  else:
    global_features_utils.debug_and_log(
        '>> Using model from scratch (random weights) \'{}\'.'.format(
            FLAGS.arch))

  model_params = {'architecture': FLAGS.arch, 'pooling': FLAGS.pool,
                  'whitening': FLAGS.whitening, 'pretrained': FLAGS.pretrained,
                  'data_root': FLAGS.data_root}
  model = global_model.GlobalFeatureNet(**model_params)
  # Freeze running mean and std in batch normalization layers.
  # We train on one image at a time to reduce the memory requirements of
  # the network; therefore, the computed statistics would not be per batch.
  # Instead, we choose freezing - setting the parameters of all batch norm
  # layers in the network to non-trainable (i.e., using the original
  # imagenet statistics).
  for layer in model.feature_extractor.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
      layer.trainable = False

  global_features_utils.debug_and_log('>> Network initialized.')
  global_features_utils.debug_and_log('>> Loss: {}.'.format(FLAGS.loss))

  # Define the loss function.
  if FLAGS.loss == 'contrastive':
    criterion = ranking_losses.ContrastiveLoss(margin=FLAGS.loss_margin)
  elif FLAGS.loss == 'triplet':
    criterion = ranking_losses.TripletLoss(margin=FLAGS.loss_margin)
  else:
    raise ValueError('Loss {} not available.'.format(FLAGS.loss))

  # Defining parameters for the training.
  # When pre-computing whitening, we run evaluation before the network training
  # and the `start_epoch` is set to 0. In other cases, we start from epoch 1.
  start_epoch = 1
  exp_decay = math.exp(-0.01)
  decay_steps = FLAGS.query_size / FLAGS.batch_size

  # Define learning rate decay schedule.
  lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=FLAGS.lr,
      decay_steps=decay_steps,
      decay_rate=exp_decay)
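  # Note (added for clarity): with this schedule the learning rate at optimizer
  # step `s` is lr * exp_decay**(s / decay_steps) = lr * exp(-0.01 * s /
  # decay_steps). Since decay_steps equals the number of steps in one epoch
  # (query_size / batch_size), the learning rate shrinks by a factor of
  # exp(-0.01), i.e. roughly 1%, per epoch.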
  # Define the optimizer.
  if FLAGS.optimizer == 'sgd':
    opt = tfa.optimizers.extend_with_decoupled_weight_decay(
        tf.keras.optimizers.SGD)
    optimizer = opt(weight_decay=FLAGS.weight_decay,
                    learning_rate=lr_scheduler,
                    momentum=FLAGS.momentum)
  elif FLAGS.optimizer == 'adam':
    opt = tfa.optimizers.extend_with_decoupled_weight_decay(
        tf.keras.optimizers.Adam)
    optimizer = opt(weight_decay=FLAGS.weight_decay,
                    learning_rate=lr_scheduler)
  else:
    raise ValueError('Optimizer {} not available.'.format(FLAGS.optimizer))
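  # Note (added for clarity): `extend_with_decoupled_weight_decay` from
  # TensorFlow Addons wraps the base Keras optimizer so that weight decay is
  # applied directly to the variables (as in SGDW/AdamW) rather than being
  # folded into the loss as an L2 penalty.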
  # Initializing logging.
  writer = tf.summary.create_file_writer(model_directory)
  tf.summary.experimental.set_step(1)

  # Setting up the checkpoint manager.
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(
      checkpoint, model_directory, max_to_keep=10,
      keep_checkpoint_every_n_hours=3)
  if FLAGS.resume:
    # Restores the checkpoint, if existing.
    global_features_utils.debug_and_log('>> Continuing from a checkpoint.')
    checkpoint.restore(manager.latest_checkpoint)

  # Launching tensorboard if required.
  if FLAGS.launch_tensorboard:
    tensorboard = tf.keras.callbacks.TensorBoard(model_directory)
    tensorboard.set_model(model=model)
    tensorboard_utils.launch_tensorboard(log_dir=model_directory)

  # Log flags used.
  global_features_utils.debug_and_log('>> Running training script with:')
  global_features_utils.debug_and_log('>> logdir = {}'.format(model_directory))

  if FLAGS.training_dataset.startswith('retrieval-SfM-120k'):
    train_dataset = sfm120k.CreateDataset(
        data_root=FLAGS.data_root,
        mode='train',
        imsize=FLAGS.image_size,
        num_negatives=FLAGS.neg_num,
        num_queries=FLAGS.query_size,
        pool_size=FLAGS.pool_size)
    if FLAGS.validation_type is not None:
      val_dataset = sfm120k.CreateDataset(
          data_root=FLAGS.data_root,
          mode='val',
          imsize=FLAGS.image_size,
          num_negatives=FLAGS.neg_num,
          num_queries=float('Inf'),
          pool_size=float('Inf'),
          eccv2020=True if FLAGS.validation_type == 'eccv2020' else False)

  train_dataset_output_types = [tf.float32 for i in range(2 + FLAGS.neg_num)]
  train_dataset_output_types.append(tf.int32)
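  # Note (added for clarity): each element yielded by the tuple datasets is
  # assumed to be 2 + neg_num image tensors (query, positive and neg_num
  # negatives) followed by a single integer label tensor, hence the output
  # signature of (2 + neg_num) tf.float32 entries plus one tf.int32 entry.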
  global_features_utils.debug_and_log(
      '>> Training the {} network'.format(model_directory))
  global_features_utils.debug_and_log('>> GPU ids: {}'.format(FLAGS.gpu_id))

  with writer.as_default():
    # Precompute whitening if needed.
    if FLAGS.precompute_whitening is not None:
      epoch = 0
      train_utils.test_retrieval(
          FLAGS.test_datasets, model, writer=writer, epoch=epoch,
          model_directory=model_directory,
          precompute_whitening=FLAGS.precompute_whitening,
          data_root=FLAGS.data_root,
          multiscale=FLAGS.multiscale)

    for epoch in range(start_epoch, FLAGS.epochs + 1):
      # Set manual seeds per epoch.
      np.random.seed(epoch)
      tf.random.set_seed(epoch)
      # Find hard-negatives.
      # While hard-positive examples are fixed during the whole training
      # process and are randomly chosen every epoch, hard-negatives depend
      # on the current CNN parameters and are re-mined once per epoch.
      avg_neg_distance = train_dataset.create_epoch_tuples(model)
      def _train_gen():
        return (inst for inst in train_dataset)

      train_loader = tf.data.Dataset.from_generator(
          _train_gen, output_types=tuple(train_dataset_output_types))
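      # Note (added for clarity): the tf.data pipeline is rebuilt from a fresh
      # generator every epoch because create_epoch_tuples() above re-mines the
      # tuples, so a dataset built in a previous epoch would no longer reflect
      # the current hard negatives.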
      loss = train_utils.train_val_one_epoch(
          loader=iter(train_loader), model=model,
          criterion=criterion, optimizer=optimizer, epoch=epoch,
          batch_size=FLAGS.batch_size, query_size=FLAGS.query_size,
          neg_num=FLAGS.neg_num, update_every=FLAGS.update_every,
          debug=FLAGS.debug)

      # Write a scalar summary.
      tf.summary.scalar('train_epoch_loss', loss, step=epoch)
      # Forces summary writer to send any buffered data to storage.
      writer.flush()

      # Evaluate on validation set.
      if FLAGS.validation_type is not None and (epoch % FLAGS.test_freq == 0 or
                                                epoch == 1):
        avg_neg_distance = val_dataset.create_epoch_tuples(
            model, model_directory)

        def _val_gen():
          return (inst for inst in val_dataset)

        val_loader = tf.data.Dataset.from_generator(
            _val_gen, output_types=tuple(train_dataset_output_types))

        loss = train_utils.train_val_one_epoch(
            loader=iter(val_loader), model=model, criterion=criterion,
            optimizer=None, epoch=epoch, train=False,
            batch_size=FLAGS.batch_size, query_size=FLAGS.query_size,
            neg_num=FLAGS.neg_num, update_every=FLAGS.update_every,
            debug=FLAGS.debug)

        tf.summary.scalar('val_epoch_loss', loss, step=epoch)
        writer.flush()

      # Evaluate on test datasets every test_freq epochs.
      if epoch == 1 or epoch % FLAGS.test_freq == 0:
        train_utils.test_retrieval(
            FLAGS.test_datasets, model, writer=writer, epoch=epoch,
            model_directory=model_directory,
            precompute_whitening=FLAGS.precompute_whitening,
            data_root=FLAGS.data_root, multiscale=FLAGS.multiscale)

      # Saving checkpoints and model weights.
      try:
        save_path = manager.save(checkpoint_number=epoch)
        global_features_utils.debug_and_log(
            'Saved ({}) at {}'.format(epoch, save_path))

        filename = os.path.join(model_directory,
                                'checkpoint_epoch_{}.h5'.format(epoch))
        model.save_weights(filename, save_format='h5')
        global_features_utils.debug_and_log(
            'Saved weights ({}) at {}'.format(epoch, filename))
      except Exception as ex:
        global_features_utils.debug_and_log(
            'Could not save checkpoint: {}'.format(ex))


if __name__ == '__main__':
  app.run(main)