ModelZoo / ResNet50_tensorflow / Commits / 5ffcc5b6

Unverified commit 5ffcc5b6, authored Jul 21, 2021 by Anirudh Vegesana, committed by GitHub on Jul 21, 2021.

Merge branch 'purdue-yolo' into detection_generator_pr

Parents: 0b81a843, 76e0c014

Showing 20 changed files with 828 additions and 49 deletions (+828, -49).
official/vision/beta/serving/detection.py                        +24   -3
official/vision/beta/serving/export_saved_model.py                +6   -2
official/vision/beta/serving/export_saved_model_lib.py           +12   -3
official/vision/beta/serving/image_classification.py              +1   -1
official/vision/beta/serving/video_classification.py            +191   -0
official/vision/beta/serving/video_classification_test.py       +114   -0
official/vision/beta/tasks/image_classification.py                +2   -6
official/vision/beta/tasks/retinanet.py                           +2   -6
official/vision/beta/tasks/semantic_segmentation.py               +2   -6
official/vision/beta/train.py                                     +1   -0
official/vision/beta/train_spatial_partitioning.py               +22   -6
official/vision/image_classification/README.md                    +3   -0
official/vision/keras_cv/ops/iou_similarity.py                    +3   -0
orbit/actions/export_saved_model.py                              +18  -16
orbit/actions/export_saved_model_test.py                         +17   -0
research/delf/delf/python/datasets/generic_dataset.py            +81   -0
research/delf/delf/python/datasets/generic_dataset_test.py       +60   -0
research/delf/delf/python/datasets/sfm120k/__init__.py           +23   -0
research/delf/delf/python/datasets/sfm120k/dataset_download.py  +103   -0
research/delf/delf/python/datasets/sfm120k/sfm120k.py           +143   -0
official/vision/beta/serving/detection.py  (+24, -3)

@@ -20,6 +20,7 @@ import tensorflow as tf
 from official.vision.beta import configs
 from official.vision.beta.modeling import factory
 from official.vision.beta.ops import anchor
+from official.vision.beta.ops import box_ops
 from official.vision.beta.ops import preprocess_ops
 from official.vision.beta.serving import export_base
@@ -130,6 +131,28 @@ class DetectionModule(export_base.ExportModule):
         training=False)

     if self.params.task.model.detection_generator.apply_nms:
+      # For RetinaNet model, apply export_config.
+      # TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
+      if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
+        export_config = self.params.task.export_config
+        # Normalize detection box coordinates to [0, 1].
+        if export_config.output_normalized_coordinates:
+          detection_boxes = (
+              detections['detection_boxes'] /
+              tf.tile(image_info[:, 2:3, :], [1, 1, 2]))
+          detections['detection_boxes'] = box_ops.normalize_boxes(
+              detection_boxes, image_info[:, 0:1, :])
+
+        # Cast num_detections and detection_classes to float. This allows the
+        # model inference to work on chain (go/chain) as chain requires floating
+        # point outputs.
+        if export_config.cast_num_detections_to_float:
+          detections['num_detections'] = tf.cast(
+              detections['num_detections'], dtype=tf.float32)
+        if export_config.cast_detection_classes_to_float:
+          detections['detection_classes'] = tf.cast(
+              detections['detection_classes'], dtype=tf.float32)
+
       final_outputs = {
           'detection_boxes': detections['detection_boxes'],
           'detection_scores': detections['detection_scores'],
@@ -139,9 +162,7 @@ class DetectionModule(export_base.ExportModule):
     else:
       final_outputs = {
           'decoded_boxes': detections['decoded_boxes'],
-          'decoded_box_scores': detections['decoded_box_scores'],
-          'cls_outputs': detections['cls_outputs'],
-          'box_outputs': detections['box_outputs']
+          'decoded_box_scores': detections['decoded_box_scores']
       }

     if 'detection_masks' in detections.keys():
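The new export_config block first undoes the preprocessing resize scale and then normalizes the boxes by the original image size. A minimal, self-contained sketch of that arithmetic is below; it uses an inline division in place of `box_ops.normalize_boxes`, and it assumes the usual Model Garden `image_info` layout of `[[original size], [resized size], [scale], [offset]]`, which you should verify against the actual pipeline.

# Hedged sketch: normalizing absolute detection boxes to [0, 1].
# The image_info row layout below is an assumption about the convention used.
import tensorflow as tf

# [y_min, x_min, y_max, x_max] in the resized image's pixel coordinates.
detection_boxes = tf.constant([[[64., 64., 128., 128.],
                                [0., 0., 256., 256.]]])          # [batch, N, 4]
image_info = tf.constant([[[512., 512.],    # original height, width
                           [256., 256.],    # resized height, width
                           [0.5, 0.5],      # y_scale, x_scale
                           [0., 0.]]])      # y_offset, x_offset

# Undo the resize scale: divide by [y_scale, x_scale, y_scale, x_scale].
scale = tf.tile(image_info[:, 2:3, :], [1, 1, 2])   # [batch, 1, 4]
boxes_original = detection_boxes / scale            # back to original pixels

# Normalize by the original height/width: [h, w, h, w].
hw = tf.tile(image_info[:, 0:1, :], [1, 1, 2])      # [batch, 1, 4]
normalized_boxes = boxes_original / hw              # values now in [0, 1]
print(normalized_boxes.numpy())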
official/vision/beta/serving/export_saved_model.py  (+6, -2)

@@ -73,6 +73,10 @@ flags.DEFINE_string(
     'input_image_size', '224,224',
     'The comma-separated string of two integers representing the height,width '
     'of the input to the model.')
+flags.DEFINE_string('export_checkpoint_subdir', 'checkpoint',
+                    'The subdirectory for checkpoints.')
+flags.DEFINE_string('export_saved_model_subdir', 'saved_model',
+                    'The subdirectory for saved model.')


 def main(_):
@@ -95,8 +99,8 @@ def main(_):
       params=params,
       checkpoint_path=FLAGS.checkpoint_path,
       export_dir=FLAGS.export_dir,
-      export_checkpoint_subdir='checkpoint',
-      export_saved_model_subdir='saved_model')
+      export_checkpoint_subdir=FLAGS.export_checkpoint_subdir,
+      export_saved_model_subdir=FLAGS.export_saved_model_subdir)


 if __name__ == '__main__':
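This change replaces the hard-coded 'checkpoint' and 'saved_model' subdirectory names with two new flags, so the export layout becomes configurable from the command line. The snippet below is a stripped-down, generic illustration of that absl flag pattern; only the flag names come from the diff, everything else is a stand-in rather than the actual export script.

# Minimal stand-in showing how the new flags are defined and consumed.
from absl import app, flags

FLAGS = flags.FLAGS

flags.DEFINE_string('export_checkpoint_subdir', 'checkpoint',
                    'The subdirectory for checkpoints.')
flags.DEFINE_string('export_saved_model_subdir', 'saved_model',
                    'The subdirectory for saved model.')


def main(_):
  # Previously these values were hard-coded; now callers can override them,
  # e.g. --export_checkpoint_subdir=ckpt --export_saved_model_subdir=serving
  print(FLAGS.export_checkpoint_subdir, FLAGS.export_saved_model_subdir)


if __name__ == '__main__':
  app.run(main)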
official/vision/beta/serving/export_saved_model_lib.py  (+12, -3)

@@ -27,6 +27,7 @@ from official.vision.beta import configs
 from official.vision.beta.serving import detection
 from official.vision.beta.serving import image_classification
 from official.vision.beta.serving import semantic_segmentation
+from official.vision.beta.serving import video_classification


 def export_inference_graph(
@@ -68,7 +69,7 @@ def export_inference_graph(
     output_checkpoint_directory = os.path.join(
         export_dir, export_checkpoint_subdir)
   else:
-    output_checkpoint_directory = export_dir
+    output_checkpoint_directory = None

   if export_saved_model_subdir:
     output_saved_model_directory = os.path.join(
@@ -99,6 +100,13 @@ def export_inference_graph(
         batch_size=batch_size,
         input_image_size=input_image_size,
         num_channels=num_channels)
+  elif isinstance(params.task,
+                  configs.video_classification.VideoClassificationTask):
+    export_module = video_classification.VideoClassificationModule(
+        params=params,
+        batch_size=batch_size,
+        input_image_size=input_image_size,
+        num_channels=num_channels)
   else:
     raise ValueError('Export module not implemented for {} task.'.format(
         type(params.task)))
@@ -111,6 +119,7 @@ def export_inference_graph(
       timestamped=False,
       save_options=save_options)

-  ckpt = tf.train.Checkpoint(model=export_module.model)
-  ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
+  if output_checkpoint_directory:
+    ckpt = tf.train.Checkpoint(model=export_module.model)
+    ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
   train_utils.serialize_config(params, export_dir)
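Two behavioral changes stand out here: when no `export_checkpoint_subdir` is given, the checkpoint directory is now `None` instead of falling back to the export root, and the checkpoint is only written when that directory is set. A small sketch of the guarded-save pattern with plain TensorFlow follows; a toy Keras model stands in for `export_module.model`.

# Sketch of the guarded checkpoint save added to export_inference_graph.
import os
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])  # toy stand-in model
model(tf.zeros([1, 8]))  # build the variables

output_checkpoint_directory = None  # e.g. when export_checkpoint_subdir is unset

# Old behavior: always saved, using export_dir as a fallback directory.
# New behavior: skip checkpointing entirely when no directory was requested.
if output_checkpoint_directory:
  ckpt = tf.train.Checkpoint(model=model)
  ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
else:
  print('No checkpoint subdirectory requested; only the SavedModel is written.')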
official/vision/beta/serving/image_classification.py  (+1, -1)

@@ -13,7 +13,7 @@
 # limitations under the License.

 # Lint as: python3
-"""Detection input and model functions for serving/inference."""
+"""Image classification input and model functions for serving/inference."""

 import tensorflow as tf
official/vision/beta/serving/video_classification.py  (new file, +191)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Video classification input and model functions for serving/inference."""

from typing import Mapping, Dict, Text

import tensorflow as tf

from official.vision.beta.dataloaders import video_input
from official.vision.beta.serving import export_base
from official.vision.beta.tasks import video_classification

MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


class VideoClassificationModule(export_base.ExportModule):
  """Video classification Module."""

  def _build_model(self):
    input_params = self.params.task.train_data
    self._num_frames = input_params.feature_shape[0]
    self._stride = input_params.temporal_stride
    self._min_resize = input_params.min_image_size
    self._crop_size = input_params.feature_shape[1]
    self._output_audio = input_params.output_audio
    task = video_classification.VideoClassificationTask(self.params.task)
    return task.build_model()

  def _decode_tf_example(self, encoded_inputs: tf.Tensor):
    sequence_description = {
        # Each image is a string encoding JPEG.
        video_input.IMAGE_KEY: tf.io.FixedLenSequenceFeature((), tf.string),
    }
    if self._output_audio:
      sequence_description[self._params.task.validation_data.audio_feature] = (
          tf.io.VarLenFeature(dtype=tf.float32))
    _, decoded_tensors = tf.io.parse_single_sequence_example(
        encoded_inputs, {}, sequence_description)
    for key, value in decoded_tensors.items():
      if isinstance(value, tf.SparseTensor):
        decoded_tensors[key] = tf.sparse.to_dense(value)
    return decoded_tensors

  def _preprocess_image(self, image):
    image = video_input.process_image(
        image=image,
        is_training=False,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=1,
        min_resize=self._min_resize,
        crop_size=self._crop_size,
        num_crops=1)
    image = tf.cast(image, tf.float32)  # Use config.
    features = {'image': image}
    return features

  def _preprocess_audio(self, audio):
    features = {}
    audio = tf.cast(audio, dtype=tf.float32)  # Use config.
    audio = video_input.preprocess_ops_3d.sample_sequence(
        audio, 20, random=False, stride=1)
    audio = tf.ensure_shape(
        audio, self._params.task.validation_data.audio_feature_shape)
    features['audio'] = audio
    return features

  @tf.function
  def inference_from_tf_example(
      self, encoded_inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
    with tf.device('cpu:0'):
      if self._output_audio:
        inputs = tf.map_fn(
            self._decode_tf_example, (encoded_inputs),
            fn_output_signature={
                video_input.IMAGE_KEY: tf.string,
                self._params.task.validation_data.audio_feature: tf.float32
            })
        return self.serve(inputs['image'], inputs['audio'])
      else:
        inputs = tf.map_fn(
            self._decode_tf_example, (encoded_inputs),
            fn_output_signature={
                video_input.IMAGE_KEY: tf.string,
            })
        return self.serve(inputs[video_input.IMAGE_KEY], tf.zeros([1, 1]))

  @tf.function
  def inference_from_image_tensors(
      self, input_frames: tf.Tensor) -> Mapping[str, tf.Tensor]:
    return self.serve(input_frames, tf.zeros([1, 1]))

  @tf.function
  def inference_from_image_audio_tensors(
      self, input_frames: tf.Tensor,
      input_audio: tf.Tensor) -> Mapping[str, tf.Tensor]:
    return self.serve(input_frames, input_audio)

  @tf.function
  def inference_from_image_bytes(self, inputs: tf.Tensor):
    raise NotImplementedError(
        'Video classification does not support image bytes input.')

  def serve(self, input_frames: tf.Tensor, input_audio: tf.Tensor):
    """Cast image to float and run inference.

    Args:
      input_frames: uint8 Tensor of shape [batch_size, None, None, 3]
      input_audio: float32

    Returns:
      Tensor holding classification output logits.
    """
    with tf.device('cpu:0'):
      inputs = tf.map_fn(
          self._preprocess_image, (input_frames),
          fn_output_signature={
              'image': tf.float32,
          })
      if self._output_audio:
        inputs.update(
            tf.map_fn(
                self._preprocess_audio, (input_audio),
                fn_output_signature={'audio': tf.float32}))
    logits = self.inference_step(inputs)
    if self.params.task.train_data.is_multilabel:
      probs = tf.math.sigmoid(logits)
    else:
      probs = tf.nn.softmax(logits)
    return {'logits': logits, 'probs': probs}

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Gets defined function signatures.

    Args:
      function_keys: A dictionary with keys as the function to create signature
        for and values as the signature keys when returns.

    Returns:
      A dictionary with key as signature key and value as concrete functions
        that can be used for tf.saved_model.save.
    """
    signatures = {}
    for key, def_name in function_keys.items():
      if key == 'image_tensor':
        input_signature = tf.TensorSpec(
            shape=[self._batch_size] + self._input_image_size + [3],
            dtype=tf.uint8,
            name='INPUT_FRAMES')
        signatures[
            def_name] = self.inference_from_image_tensors.get_concrete_function(
                input_signature)
      elif key == 'frames_audio':
        input_signature = [
            tf.TensorSpec(
                shape=[self._batch_size] + self._input_image_size + [3],
                dtype=tf.uint8,
                name='INPUT_FRAMES'),
            tf.TensorSpec(
                shape=[self._batch_size] +
                self.params.task.train_data.audio_feature_shape,
                dtype=tf.float32,
                name='INPUT_AUDIO')
        ]
        signatures[
            def_name] = self.inference_from_image_audio_tensors.get_concrete_function(
                input_signature)
      elif key == 'serve_examples' or key == 'tf_example':
        input_signature = tf.TensorSpec(shape=[self._batch_size], dtype=tf.string)
        signatures[
            def_name] = self.inference_from_tf_example.get_concrete_function(
                input_signature)
      else:
        raise ValueError('Unrecognized `input_type`')
    return signatures
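The test file below exercises this module end to end; the export flow it follows looks roughly like the sketch here. The experiment name, shapes, backbone size, and signature key are taken from that test, so treat them as one example configuration, and the output path is a placeholder.

# Rough export flow for the new VideoClassificationModule, mirroring the test.
import tensorflow as tf

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.serving import video_classification

params = exp_factory.get_exp_config('video_classification_ucf101')
params.task.train_data.feature_shape = (8, 64, 64, 3)       # small test config
params.task.validation_data.feature_shape = (8, 64, 64, 3)
params.task.model.backbone.resnet_3d.model_id = 50

module = video_classification.VideoClassificationModule(
    params, batch_size=1, input_image_size=[8, 64, 64])

# Map an input type to a serving signature name and write the SavedModel.
signatures = module.get_inference_signatures({'image_tensor': 'serving_default'})
tf.saved_model.save(module, '/tmp/video_classification_export',
                    signatures=signatures)

# The exported model can then be reloaded and called on uint8 frame tensors.
imported = tf.saved_model.load('/tmp/video_classification_export')
serving_fn = imported.signatures['serving_default']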
official/vision/beta/serving/video_classification_test.py  (new file, +114)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
# import io
import os
import random

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.serving import video_classification


class VideoClassificationTest(tf.test.TestCase, parameterized.TestCase):

  def _get_classification_module(self):
    params = exp_factory.get_exp_config('video_classification_ucf101')
    params.task.train_data.feature_shape = (8, 64, 64, 3)
    params.task.validation_data.feature_shape = (8, 64, 64, 3)
    params.task.model.backbone.resnet_3d.model_id = 50
    classification_module = video_classification.VideoClassificationModule(
        params, batch_size=1, input_image_size=[8, 64, 64])
    return classification_module

  def _export_from_module(self, module, input_type, save_directory):
    signatures = module.get_inference_signatures(
        {input_type: 'serving_default'})
    tf.saved_model.save(module, save_directory, signatures=signatures)

  def _get_dummy_input(self, input_type, module=None):
    """Get dummy input for the given input type."""
    if input_type == 'image_tensor':
      images = np.random.randint(
          low=0, high=255, size=(1, 8, 64, 64, 3), dtype=np.uint8)
      # images = np.zeros((1, 8, 64, 64, 3), dtype=np.uint8)
      return images, images
    elif input_type == 'tf_example':
      example = tfexample_utils.make_video_test_example(
          image_shape=(64, 64, 3),
          audio_shape=(20, 128),
          label=random.randint(0, 100)).SerializeToString()
      images = tf.nest.map_structure(
          tf.stop_gradient,
          tf.map_fn(
              module._decode_tf_example,
              elems=tf.constant([example]),
              fn_output_signature={
                  video_classification.video_input.IMAGE_KEY: tf.string,
              }))
      images = images[video_classification.video_input.IMAGE_KEY]
      return [example], images
    else:
      raise ValueError(f'{input_type}')

  @parameterized.parameters(
      {'input_type': 'image_tensor'},
      {'input_type': 'tf_example'},
  )
  def test_export(self, input_type):
    tmp_dir = self.get_temp_dir()
    module = self._get_classification_module()
    self._export_from_module(module, input_type, tmp_dir)

    self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
    self.assertTrue(
        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
    self.assertTrue(
        os.path.exists(
            os.path.join(tmp_dir, 'variables',
                         'variables.data-00000-of-00001')))

    imported = tf.saved_model.load(tmp_dir)
    classification_fn = imported.signatures['serving_default']

    images, images_tensor = self._get_dummy_input(input_type, module)
    processed_images = tf.nest.map_structure(
        tf.stop_gradient,
        tf.map_fn(
            module._preprocess_image,
            elems=images_tensor,
            fn_output_signature={
                'image': tf.float32,
            }))
    expected_logits = module.model(processed_images, training=False)
    expected_prob = tf.nn.softmax(expected_logits)
    out = classification_fn(tf.constant(images))

    # The imported model should contain any trackable attrs that the original
    # model had.
    self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
    self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/tasks/image_classification.py  (+2, -6)

@@ -24,7 +24,7 @@ from official.modeling import tf_utils
 from official.vision.beta.configs import image_classification as exp_cfg
 from official.vision.beta.dataloaders import classification_input
 from official.vision.beta.dataloaders import input_reader_factory
-from official.vision.beta.dataloaders import tfds_classification_decoders
+from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.modeling import factory
@@ -89,11 +89,7 @@ class ImageClassificationTask(base_task.Task):
     is_multilabel = self.task_config.train_data.is_multilabel
     if params.tfds_name:
-      if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP:
-        decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[
-            params.tfds_name]()
-      else:
-        raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
+      decoder = tfds_factory.get_classification_decoder(params.tfds_name)
     else:
       decoder = classification_input.Decoder(
           image_field_key=image_field_key, label_field_key=label_field_key,
official/vision/beta/tasks/retinanet.py  (+2, -6)

@@ -25,7 +25,7 @@ from official.vision.beta.configs import retinanet as exp_cfg
 from official.vision.beta.dataloaders import input_reader_factory
 from official.vision.beta.dataloaders import retinanet_input
 from official.vision.beta.dataloaders import tf_example_decoder
-from official.vision.beta.dataloaders import tfds_detection_decoders
+from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.dataloaders import tf_example_label_map_decoder
 from official.vision.beta.evaluation import coco_evaluator
 from official.vision.beta.modeling import factory
@@ -90,11 +90,7 @@ class RetinaNetTask(base_task.Task):
     """Build input dataset."""
     if params.tfds_name:
-      if params.tfds_name in tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP:
-        decoder = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP[
-            params.tfds_name]()
-      else:
-        raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
+      decoder = tfds_factory.get_detection_decoder(params.tfds_name)
     else:
       decoder_cfg = params.decoder.get()
       if params.decoder.type == 'simple_decoder':
official/vision/beta/tasks/semantic_segmentation.py  (+2, -6)

@@ -23,7 +23,7 @@ from official.core import task_factory
 from official.vision.beta.configs import semantic_segmentation as exp_cfg
 from official.vision.beta.dataloaders import input_reader_factory
 from official.vision.beta.dataloaders import segmentation_input
-from official.vision.beta.dataloaders import tfds_segmentation_decoders
+from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.evaluation import segmentation_metrics
 from official.vision.beta.losses import segmentation_losses
 from official.vision.beta.modeling import factory
@@ -87,11 +87,7 @@ class SemanticSegmentationTask(base_task.Task):
     ignore_label = self.task_config.losses.ignore_label
     if params.tfds_name:
-      if params.tfds_name in tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP:
-        decoder = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP[
-            params.tfds_name]()
-      else:
-        raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
+      decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
     else:
       decoder = segmentation_input.Decoder()
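All three task files above swap a per-task `TFDS_ID_TO_DECODER_MAP` lookup, each with its own inline ValueError, for a single call into `tfds_factory`. The sketch below is a generic, hypothetical illustration of that factory pattern; the names in it are placeholders, not the Model Garden API.

# Hypothetical illustration of the decoder-factory refactor; placeholder names.
_TFDS_ID_TO_DECODER_MAP = {
    'imagenet2012': lambda: 'ImageNetDecoder()',
    'cifar10': lambda: 'Cifar10Decoder()',
}


def get_classification_decoder(tfds_name):
  """Centralizes the lookup and error handling each task used to inline."""
  if tfds_name not in _TFDS_ID_TO_DECODER_MAP:
    raise ValueError('TFDS {} is not supported'.format(tfds_name))
  return _TFDS_ID_TO_DECODER_MAP[tfds_name]()


# Each task body then shrinks to a single line:
decoder = get_classification_decoder('cifar10')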
official/vision/beta/train.py  (+1, -0)

@@ -66,4 +66,5 @@ def main(_):
 if __name__ == '__main__':
   tfm_flags.define_flags()
+  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
   app.run(main)
official/vision/beta/train_spatial_partitioning.py  (+22, -6)

@@ -14,6 +14,7 @@
 # Lint as: python3
 """TensorFlow Model Garden Vision training driver with spatial partitioning."""
+from typing import Sequence

 from absl import app
 from absl import flags
@@ -33,19 +34,34 @@ from official.modeling import performance
 FLAGS = flags.FLAGS


-def get_computation_shape_for_model_parallelism(input_partition_dims):
-  """Return computation shape to be used for TPUStrategy spatial partition."""
+def get_computation_shape_for_model_parallelism(
+    input_partition_dims: Sequence[int]) -> Sequence[int]:
+  """Returns computation shape to be used for TPUStrategy spatial partition.
+
+  Args:
+    input_partition_dims: The number of partitions along each dimension.
+
+  Returns:
+    A list of integers specifying the computation shape.
+
+  Raises:
+    ValueError: If the number of logical devices is not supported.
+  """
   num_logical_devices = np.prod(input_partition_dims)
   if num_logical_devices == 1:
     return [1, 1, 1, 1]
-  if num_logical_devices == 2:
+  elif num_logical_devices == 2:
     return [1, 1, 1, 2]
-  if num_logical_devices == 4:
+  elif num_logical_devices == 4:
     return [1, 2, 1, 2]
-  if num_logical_devices == 8:
+  elif num_logical_devices == 8:
     return [2, 2, 1, 2]
-  if num_logical_devices == 16:
+  elif num_logical_devices == 16:
     return [4, 2, 1, 2]
   else:
     raise ValueError(
         'The number of logical devices %d is not supported. Supported numbers '
         'are 1, 2, 4, 8, 16' % num_logical_devices)


 def create_distribution_strategy(distribution_strategy,
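The rewritten helper maps the product of the partition dims to a fixed 4-element computation shape. A standalone reproduction of that mapping (same table as the diff above, condensed into a dict) makes the behavior easy to check without importing the training driver.

# Standalone copy of the mapping for quick inspection; mirrors the diff above.
import numpy as np


def computation_shape(input_partition_dims):
  num_logical_devices = np.prod(input_partition_dims)
  shapes = {1: [1, 1, 1, 1], 2: [1, 1, 1, 2], 4: [1, 2, 1, 2],
            8: [2, 2, 1, 2], 16: [4, 2, 1, 2]}
  if num_logical_devices not in shapes:
    raise ValueError(
        'The number of logical devices %d is not supported. Supported numbers '
        'are 1, 2, 4, 8, 16' % num_logical_devices)
  return shapes[num_logical_devices]


print(computation_shape([1, 2, 1, 2]))  # 4 logical devices  -> [1, 2, 1, 2]
print(computation_shape([2, 2, 2, 2]))  # 16 logical devices -> [4, 2, 1, 2]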
official/vision/image_classification/README.md  (+3, -0)

 # Image Classification

+**Warning:** the features in the `image_classification/` folder have been fully
+integrated into vision/beta. Please use the [new code base](../beta/README.md).
+
 This folder contains TF 2.0 model examples for image classification:

 * [MNIST](#mnist)
official/vision/keras_cv/ops/iou_similarity.py  (+3, -0)

@@ -132,6 +132,9 @@ class IouSimilarity:
       Output shape:
         [M, N], or [B, M, N]
     """
+    boxes_1 = tf.cast(boxes_1, tf.float32)
+    boxes_2 = tf.cast(boxes_2, tf.float32)
+
     boxes_1_rank = len(boxes_1.shape)
     boxes_2_rank = len(boxes_2.shape)
     if boxes_1_rank < 2 or boxes_1_rank > 3:
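Casting both box sets to float32 up front avoids dtype mismatches (for example integer corner coordinates meeting float ones) in the IoU arithmetic that follows. Below is a small self-contained pairwise-IoU sketch with the same cast, written directly against TF ops rather than the keras_cv class; it is an illustration of the idea, not the library's implementation.

# Self-contained pairwise IoU with the float32 cast applied first.
import tensorflow as tf


def pairwise_iou(boxes_1, boxes_2):
  """boxes_*: [M, 4] and [N, 4] as [y_min, x_min, y_max, x_max]."""
  boxes_1 = tf.cast(boxes_1, tf.float32)
  boxes_2 = tf.cast(boxes_2, tf.float32)
  y_min1, x_min1, y_max1, x_max1 = tf.split(boxes_1, 4, axis=-1)
  y_min2, x_min2, y_max2, x_max2 = tf.split(boxes_2, 4, axis=-1)
  # Intersection corners, broadcast to [M, N].
  inter_ymin = tf.maximum(y_min1, tf.transpose(y_min2))
  inter_xmin = tf.maximum(x_min1, tf.transpose(x_min2))
  inter_ymax = tf.minimum(y_max1, tf.transpose(y_max2))
  inter_xmax = tf.minimum(x_max1, tf.transpose(x_max2))
  inter = tf.maximum(inter_ymax - inter_ymin, 0.) * tf.maximum(
      inter_xmax - inter_xmin, 0.)
  area1 = (y_max1 - y_min1) * (x_max1 - x_min1)   # [M, 1]
  area2 = (y_max2 - y_min2) * (x_max2 - x_min2)   # [N, 1]
  union = area1 + tf.transpose(area2) - inter
  return inter / union


# Integer boxes would previously have tripped up the float arithmetic.
print(pairwise_iou(tf.constant([[0, 0, 10, 10]]),
                   tf.constant([[0, 0, 5, 10]])).numpy())  # -> [[0.5]]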
orbit/actions/export_saved_model.py  (+18, -16)

@@ -14,24 +14,32 @@
 """Provides the `ExportSavedModel` action and associated helper classes."""

+import re
 from typing import Callable, Optional

 import tensorflow as tf


+def _id_key(filename):
+  _, id_num = filename.rsplit('-', maxsplit=1)
+  return int(id_num)
+
+
+def _find_managed_files(base_name):
+  r"""Returns all files matching '{base_name}-\d+', in sorted order."""
+  managed_file_regex = re.compile(rf'{re.escape(base_name)}-\d+$')
+  filenames = tf.io.gfile.glob(f'{base_name}-*')
+  filenames = filter(managed_file_regex.match, filenames)
+  return sorted(filenames, key=_id_key)
+
+
 class _CounterIdFn:
   """Implements a counter-based ID function for `ExportFileManager`."""

   def __init__(self, base_name: str):
-    filenames = tf.io.gfile.glob(f'{base_name}-*')
-    max_counter = -1
-    for filename in filenames:
-      try:
-        _, file_number = filename.rsplit('-', maxsplit=1)
-        max_counter = max(max_counter, int(file_number))
-      except ValueError:
-        continue
-    self.value = max_counter + 1
+    managed_files = _find_managed_files(base_name)
+    self.value = _id_key(managed_files[-1]) + 1 if managed_files else 0

   def __call__(self):
     output = self.value
@@ -82,13 +90,7 @@ class ExportFileManager:
       `ExportFileManager` instance, sorted in increasing integer order of the
       IDs returned by `next_id_fn`.
     """
-    def id_key(name):
-      _, id_num = name.rsplit('-', maxsplit=1)
-      return int(id_num)
-
-    filenames = tf.io.gfile.glob(f'{self._base_name}-*')
-    return sorted(filenames, key=id_key)
+    return _find_managed_files(self._base_name)

   def clean_up(self):
     """Cleans up old files matching `{base_name}-*`.
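The new `_find_managed_files` helper only treats names matching `{base_name}-<digits>` as managed, which is why `basename-10-suffix` in the test below is ignored, and sorting is numeric rather than lexicographic. The filtering logic can be checked in isolation with plain `re` and an in-memory list of names (no `tf.io.gfile` or real filesystem involved).

# Isolated check of the managed-file matching and sorting behavior.
import re

base_name = '/tmp/export/basename'
candidates = [
    f'{base_name}-5', f'{base_name}-10', f'{base_name}-50',
    f'{base_name}-1000', f'{base_name}-9', f'{base_name}-10-suffix',
]

managed_file_regex = re.compile(rf'{re.escape(base_name)}-\d+$')
managed = sorted(
    filter(managed_file_regex.match, candidates),
    key=lambda name: int(name.rsplit('-', maxsplit=1)[1]))

print(managed)
# Numeric order: ...-5, ...-9, ...-10, ...-50, ...-1000; '-10-suffix' excluded.
# The next counter id would be 1000 + 1 = 1001, as asserted in the test below.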
orbit/actions/export_saved_model_test.py  (+17, -0)

@@ -105,6 +105,23 @@ class ExportSavedModelTest(tf.test.TestCase):
         _id_sorted_file_base_names(directory.full_path),
         ['basename-200', 'basename-1000'])

+  def test_export_file_manager_managed_files(self):
+    directory = self.create_tempdir()
+    directory.create_file('basename-5')
+    directory.create_file('basename-10')
+    directory.create_file('basename-50')
+    directory.create_file('basename-1000')
+    directory.create_file('basename-9')
+    directory.create_file('basename-10-suffix')
+    base_name = os.path.join(directory.full_path, 'basename')
+    manager = actions.ExportFileManager(base_name, max_to_keep=3)
+    self.assertLen(manager.managed_files, 5)
+    self.assertEqual(manager.next_name(), f'{base_name}-1001')
+    manager.clean_up()
+    self.assertEqual(manager.managed_files,
+                     [f'{base_name}-10', f'{base_name}-50', f'{base_name}-1000'])
+
   def test_export_saved_model(self):
     directory = self.create_tempdir()
     base_name = os.path.join(directory.full_path, 'basename')
research/delf/delf/python/datasets/generic_dataset.py  (new file, +81)

# Lint as: python3
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for generic image dataset creation."""

import os

from delf.python.datasets import utils


class ImagesFromList():
  """A generic data loader that loads images from a list.

  Supports images of different sizes.
  """

  def __init__(self, root, image_paths, imsize=None, bounding_boxes=None,
               loader=utils.default_loader):
    """ImagesFromList object initialization.

    Args:
      root: String, root directory path.
      image_paths: List, relative image paths as strings.
      imsize: Integer, defines the maximum size of longer image side.
      bounding_boxes: List of (x1,y1,x2,y2) tuples to crop the query images.
      loader: Callable, a function to load an image given its path.

    Raises:
      ValueError: Raised if `image_paths` list is empty.
    """
    # List of the full image filenames.
    images_filenames = [os.path.join(root, image_path) for image_path in
                        image_paths]

    if not images_filenames:
      raise ValueError("Dataset contains 0 images.")

    self.root = root
    self.images = image_paths
    self.imsize = imsize
    self.images_filenames = images_filenames
    self.bounding_boxes = bounding_boxes
    self.loader = loader

  def __getitem__(self, index):
    """Called to load an image at the given `index`.

    Args:
      index: Integer, image index.

    Returns:
      image: Tensor, loaded image.
    """
    path = self.images_filenames[index]
    if self.bounding_boxes is not None:
      img = self.loader(path, self.imsize, self.bounding_boxes[index])
    else:
      img = self.loader(path, self.imsize)
    return img

  def __len__(self):
    """Implements the built-in function len().

    Returns:
      len: Number of images in the dataset.
    """
    return len(self.images_filenames)
research/delf/delf/python/datasets/generic_dataset_test.py  (new file, +60)

# Lint as: python3
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for generic dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import flags
import numpy as np
from PIL import Image
import tensorflow as tf

from delf.python.datasets import generic_dataset

FLAGS = flags.FLAGS


class GenericDatasetTest(tf.test.TestCase):
  """Test functions for generic dataset."""

  def testGenericDataset(self):
    """Tests loading dummy images from list."""
    # Number of images to be created.
    n = 2
    image_names = []

    # Create and save `n` dummy images.
    for i in range(n):
      dummy_image = np.random.rand(1024, 750, 3) * 255
      img_out = Image.fromarray(dummy_image.astype('uint8')).convert('RGB')
      filename = os.path.join(FLAGS.test_tmpdir, 'test_image_{}.jpg'.format(i))
      img_out.save(filename)
      image_names.append('test_image_{}.jpg'.format(i))

    data = generic_dataset.ImagesFromList(
        root=FLAGS.test_tmpdir, image_paths=image_names, imsize=1024)
    self.assertLen(data, n)


if __name__ == '__main__':
  tf.test.main()
research/delf/delf/python/datasets/sfm120k/__init__.py  (new file, +23)

# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module exposing Sfm120k dataset for training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
from delf.python.datasets.sfm120k import sfm120k
# pylint: enable=unused-import
research/delf/delf/python/datasets/sfm120k/dataset_download.py  (new file, +103)

# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Structure-from-Motion dataset (Sfm120k) download function."""

import os

import tensorflow as tf


def download_train(data_dir):
  """Checks, and, if required, downloads the necessary files for the training.

  Checks if the data necessary for running the example training script exist.
  If not, it downloads it in the following folder structure:
    DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db
      files.
    DATA_ROOT/train/retrieval-SfM-30k/ : folder with rsfm30k images and db
      files.
  """

  # Create data folder if does not exist.
  if not tf.io.gfile.exists(data_dir):
    tf.io.gfile.mkdir(data_dir)

  # Create datasets folder if does not exist.
  datasets_dir = os.path.join(data_dir, 'train')
  if not tf.io.gfile.exists(datasets_dir):
    tf.io.gfile.mkdir(datasets_dir)

  # Download folder train/retrieval-SfM-120k/.
  src_dir = 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/train/ims'
  dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
  download_file = 'ims.tar.gz'
  if not tf.io.gfile.exists(dst_dir):
    src_file = os.path.join(src_dir, download_file)
    dst_file = os.path.join(dst_dir, download_file)
    print('>> Image directory does not exist. Creating: {}'.format(dst_dir))
    tf.io.gfile.makedirs(dst_dir)
    print('>> Downloading ims.tar.gz...')
    os.system('wget {} -O {}'.format(src_file, dst_file))
    print('>> Extracting {}...'.format(dst_file))
    os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir))
    print('>> Extracted, deleting {}...'.format(dst_file))
    os.system('rm {}'.format(dst_file))

  # Create symlink for train/retrieval-SfM-30k/.
  dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
  dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims')
  if not (tf.io.gfile.exists(dst_dir) or os.path.islink(dst_dir)):
    tf.io.gfile.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k'))
    os.system('ln -s {} {}'.format(dst_dir_old, dst_dir))
    print('>> Created symbolic link from retrieval-SfM-120k/ims to '
          'retrieval-SfM-30k/ims')

  # Download db files.
  src_dir = 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/train/dbs'
  datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k']
  for dataset in datasets:
    dst_dir = os.path.join(datasets_dir, dataset)
    if dataset == 'retrieval-SfM-120k':
      download_files = ['{}.pkl'.format(dataset),
                        '{}-whiten.pkl'.format(dataset)]
      download_eccv2020 = '{}-val-eccv2020.pkl'.format(dataset)
    elif dataset == 'retrieval-SfM-30k':
      download_files = ['{}-whiten.pkl'.format(dataset)]
      download_eccv2020 = None

    if not tf.io.gfile.exists(dst_dir):
      print('>> Dataset directory does not exist. Creating: {}'.format(dst_dir))
      tf.io.gfile.mkdir(dst_dir)

    for i in range(len(download_files)):
      src_file = os.path.join(src_dir, download_files[i])
      dst_file = os.path.join(dst_dir, download_files[i])
      if not os.path.isfile(dst_file):
        print('>> DB file {} does not exist. Downloading...'.format(
            download_files[i]))
        os.system('wget {} -O {}'.format(src_file, dst_file))

    if download_eccv2020:
      eccv2020_dst_file = os.path.join(dst_dir, download_eccv2020)
      if not os.path.isfile(eccv2020_dst_file):
        eccv2020_src_dir = \
          "http://ptak.felk.cvut.cz/personal/toliageo/share/how/dataset/"
        eccv2020_dst_file = os.path.join(dst_dir, download_eccv2020)
        eccv2020_src_file = os.path.join(eccv2020_src_dir, download_eccv2020)
        os.system('wget {} -O {}'.format(eccv2020_src_file, eccv2020_dst_file))
research/delf/delf/python/datasets/sfm120k/sfm120k.py  (new file, +143)

# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Structure-from-Motion dataset (Sfm120k) module.

[1] From Single Image Query to Detailed 3D Reconstruction.
Johannes L. Schonberger, Filip Radenovic, Ondrej Chum, Jan-Michael Frahm.
The related paper can be found at: https://ieeexplore.ieee.org/document/7299148.
"""

import os
import pickle

import tensorflow as tf

from delf.python.datasets import tuples_dataset
from delf.python.datasets import utils


def id2filename(image_id, prefix):
  """Creates a training image path out of its id name.

  Used for the image mapping in the Sfm120k dataset.

  Args:
    image_id: String, image id.
    prefix: String, root directory where images are saved.

  Returns:
    filename: String, full image filename.
  """
  if prefix:
    return os.path.join(prefix, image_id[-2:], image_id[-4:-2],
                        image_id[-6:-4], image_id)
  else:
    return os.path.join(image_id[-2:], image_id[-4:-2], image_id[-6:-4],
                        image_id)


class _Sfm120k(tuples_dataset.TuplesDataset):
  """Structure-from-Motion (Sfm120k) dataset instance.

  The dataset contains the image names lists for training and validation,
  the cluster ID (3D model ID) for each image and indices forming
  query-positive pairs of images. The images are loaded per epoch and resized
  on the fly to the desired dimensionality.
  """

  def __init__(self, mode, data_root, imsize=None, num_negatives=5,
               num_queries=2000, pool_size=20000, loader=utils.default_loader,
               eccv2020=False):
    """Structure-from-Motion (Sfm120k) dataset initialization.

    Args:
      mode: Either 'train' or 'val'.
      data_root: Path to the root directory of the dataset.
      imsize: Integer, defines the maximum size of longer image side.
      num_negatives: Integer, number of negative images per one query.
      num_queries: Integer, number of query images.
      pool_size: Integer, size of the negative image pool, from where the
        hard-negative images are chosen.
      loader: Callable, a function to load an image given its path.
      eccv2020: Bool, whether to use a new validation dataset used with ECCV
        2020 paper (https://arxiv.org/abs/2007.13172).

    Raises:
      ValueError: Raised if `mode` is not one of 'train' or 'val'.
    """
    if mode not in ['train', 'val']:
      raise ValueError(
          "`mode` argument should be either 'train' or 'val', passed as a "
          "String.")

    # Setting up the paths for the dataset.
    if eccv2020:
      name = "retrieval-SfM-120k-val-eccv2020"
    else:
      name = "retrieval-SfM-120k"
    db_root = os.path.join(data_root, 'train/retrieval-SfM-120k')
    ims_root = os.path.join(db_root, 'ims/')

    # Loading the dataset db file.
    db_filename = os.path.join(db_root, '{}.pkl'.format(name))
    with tf.io.gfile.GFile(db_filename, 'rb') as f:
      db = pickle.load(f)[mode]

    # Setting full paths for the dataset images.
    self.images = [id2filename(img_name, None) for img_name in db['cids']]

    # Initializing tuples dataset.
    super().__init__(name, mode, db_root, imsize, num_negatives, num_queries,
                     pool_size, loader, ims_root)

  def Sfm120kInfo(self):
    """Metadata for the Sfm120k dataset.

    The dataset contains the image names lists for training and
    validation, the cluster ID (3D model ID) for each image and indices
    forming query-positive pairs of images. The images are loaded per epoch
    and resized on the fly to the desired dimensionality.

    Returns:
      info: dictionary with the dataset parameters.
    """
    info = {'train': {'clusters': 91642, 'pidxs': 181697, 'qidxs': 181697},
            'val': {'clusters': 6403, 'pidxs': 1691, 'qidxs': 1691}}
    return info


def CreateDataset(mode, data_root, imsize=None, num_negatives=5,
                  num_queries=2000, pool_size=20000,
                  loader=utils.default_loader, eccv2020=False):
  """Creates Structure-from-Motion (Sfm120k) dataset.

  Args:
    mode: String, either 'train' or 'val'.
    data_root: Path to the root directory of the dataset.
    imsize: Integer, defines the maximum size of longer image side.
    num_negatives: Integer, number of negative images per one query.
    num_queries: Integer, number of query images.
    pool_size: Integer, size of the negative image pool, from where the
      hard-negative images are chosen.
    loader: Callable, a function to load an image given its path.
    eccv2020: Bool, whether to use a new validation dataset used with ECCV
      2020 paper (https://arxiv.org/abs/2007.13172).

  Returns:
    sfm120k: Sfm120k dataset instance.
  """
  return _Sfm120k(mode, data_root, imsize, num_negatives, num_queries,
                  pool_size, loader, eccv2020)
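`id2filename` shards images into nested directories keyed by the trailing digits of the image id. A quick worked example of the slicing, using a made-up id and a standalone copy of the function, shows the resulting layout.

# Worked example of the id -> path sharding used by id2filename.
import os


def id2filename(image_id, prefix):
  """Same slicing as sfm120k.id2filename: last 2 / next 2 / next 2 digits."""
  parts = (image_id[-2:], image_id[-4:-2], image_id[-6:-4], image_id)
  return os.path.join(prefix, *parts) if prefix else os.path.join(*parts)


print(id2filename('123456789', None))
# -> '89/67/45/123456789' (ids with the same trailing digits share a folder)
print(id2filename('123456789', '/data/train/retrieval-SfM-120k/ims'))
# -> '/data/train/retrieval-SfM-120k/ims/89/67/45/123456789'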