ModelZoo / ResNet50_tensorflow · Commits
".github/vscode:/vscode.git/clone" did not exist on "237cebffd80a25b46d0ff76eff9925d800c997d0"
Commit 44e7092c, authored Feb 01, 2021 by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into AXg

Parents: 431a9ca3, 59434199
Changes: 113 files in this merge. This page shows 20 changed files with 540 additions and 38 deletions (+540 −38), page 1 of 6.
official/vision/beta/tasks/video_classification.py (+7 −0)
official/vision/beta/train.py (+2 −0)
official/vision/keras_cv/layers/deeplab.py (+5 −2)
research/object_detection/builders/dataset_builder.py (+21 −6)
research/object_detection/builders/decoder_builder.py (+3 −1)
research/object_detection/builders/decoder_builder_test.py (+24 −0)
research/object_detection/builders/hyperparams_builder.py (+19 −4)
research/object_detection/builders/model_builder.py (+6 −2)
research/object_detection/core/freezable_batch_norm_tf2_test.py (+34 −14)
research/object_detection/core/freezable_sync_batch_norm.py (+70 −0)
research/object_detection/core/model.py (+16 −1)
research/object_detection/core/post_processing.py (+6 −0)
research/object_detection/core/preprocessor.py (+40 −1)
research/object_detection/core/preprocessor_test.py (+64 −0)
research/object_detection/core/standard_fields.py (+11 −0)
research/object_detection/data_decoders/tf_example_decoder.py (+91 −1)
research/object_detection/data_decoders/tf_example_decoder_test.py (+118 −0)
research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf2_test.py (+1 −2)
research/object_detection/dataset_tools/context_rcnn/generate_embedding_data_tf2_test.py (+1 −2)
research/object_detection/exporter_lib_tf2_test.py (+1 −2)
official/vision/beta/tasks/video_classification.py

```diff
@@ -275,4 +275,11 @@ class VideoClassificationTask(base_task.Task):
       outputs = tf.math.sigmoid(outputs)
     else:
       outputs = tf.math.softmax(outputs)
+    num_test_clips = self.task_config.validation_data.num_test_clips
+    num_test_crops = self.task_config.validation_data.num_test_crops
+    num_test_views = num_test_clips * num_test_crops
+    if num_test_views > 1:
+      # Average output probabilities across multiple views.
+      outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
+      outputs = tf.reduce_mean(outputs, axis=1)
     return outputs
```
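Note: the added view-averaging step is easy to sanity-check in isolation. A minimal sketch in plain TensorFlow, with invented batch/view/class sizes (not part of the commit):

```python
import tensorflow as tf

# Hypothetical sizes: 2 videos, 4 test views (2 clips x 2 crops), 5 classes.
num_test_views = 4
outputs = tf.random.uniform([2 * num_test_views, 5])  # [batch * views, classes]

# Fold the views into their own axis and average per-view probabilities.
outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
outputs = tf.reduce_mean(outputs, axis=1)
print(outputs.shape)  # (2, 5)
```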
official/vision/beta/train.py

```diff
@@ -63,6 +63,8 @@ def main(_):
       params=params,
       model_dir=model_dir)
+  train_utils.save_gin_config(FLAGS.mode, model_dir)
+
 if __name__ == '__main__':
   tfm_flags.define_flags()
   app.run(main)
```
official/vision/keras_cv/layers/deeplab.py

```diff
@@ -142,7 +142,10 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
             epsilon=self.batchnorm_epsilon),
         tf.keras.layers.Activation(self.activation),
         tf.keras.layers.experimental.preprocessing.Resizing(
-            height, width, interpolation=self.interpolation)
+            height,
+            width,
+            interpolation=self.interpolation,
+            dtype=tf.float32)
     ]))
     self.aspp_layers.append(pool_sequential)
@@ -165,7 +168,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
       training = tf.keras.backend.learning_phase()
     result = []
     for layer in self.aspp_layers:
-      result.append(layer(inputs, training=training))
+      result.append(tf.cast(layer(inputs, training=training), inputs.dtype))
     result = tf.concat(result, axis=-1)
     result = self.projection(result, training=training)
     return result
```
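Note: the `tf.cast` in the second hunk pairs with pinning `Resizing` to `tf.float32`: under a mixed-precision policy the pooled branch now emits float32 while the other branches stay float16, and `tf.concat` rejects mixed dtypes. A small standalone illustration of the failure mode this avoids (simulated setup, not from the commit):

```python
import tensorflow as tf

# Simulate a mixed-precision model: activations flow in float16, but one
# branch (like the pinned-to-float32 Resizing layer) emits float32.
inputs = tf.random.uniform([1, 8, 8, 3], dtype=tf.float16)
float32_branch = tf.cast(inputs, tf.float32) * 2.0  # dtype: float32
other_branch = inputs * 2.0                         # dtype: float16

# tf.concat raises on mixed dtypes, so cast the float32 branch back first.
aligned = tf.cast(float32_branch, inputs.dtype)
merged = tf.concat([aligned, other_branch], axis=-1)  # float16, OK
```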
research/object_detection/builders/dataset_builder.py

```diff
@@ -27,6 +27,7 @@ from __future__ import division
 from __future__ import print_function

 import functools
+import math
 import tensorflow.compat.v1 as tf

 from object_detection.builders import decoder_builder
@@ -52,6 +53,7 @@ def make_initializable_iterator(dataset):
 def _read_dataset_internal(file_read_func,
                            input_files,
+                           num_readers,
                            config,
                            filename_shard_fn=None):
   """Reads a dataset, and handles repetition and shuffling.
@@ -60,6 +62,7 @@ def _read_dataset_internal(file_read_func,
     file_read_func: Function to use in tf_data.parallel_interleave, to read
       every individual file into a tf.data.Dataset.
     input_files: A list of file paths to read.
+    num_readers: Number of readers to use.
     config: A input_reader_builder.InputReader object.
     filename_shard_fn: optional, A function used to shard filenames across
       replicas. This function takes as input a TF dataset of filenames and is
@@ -79,7 +82,6 @@ def _read_dataset_internal(file_read_func,
   if not filenames:
     raise RuntimeError('Did not find any input files matching the glob pattern '
                        '{}'.format(input_files))
-  num_readers = config.num_readers
   if num_readers > len(filenames):
     num_readers = len(filenames)
     tf.logging.warning('num_readers has been reduced to %d to match input file '
@@ -137,17 +139,30 @@ def read_dataset(file_read_func, input_files, config, filename_shard_fn=None):
     tf.logging.info('Sampling from datasets %s with weights %s' %
                     (input_files, config.sample_from_datasets_weights))
     records_datasets = []
-    for input_file in input_files:
+    dataset_weights = []
+    for i, input_file in enumerate(input_files):
+      weight = config.sample_from_datasets_weights[i]
+      num_readers = math.ceil(
+          config.num_readers *
+          weight / sum(config.sample_from_datasets_weights))
+      tf.logging.info('Num readers for dataset [%s]: %d', input_file,
+                      num_readers)
+      if num_readers == 0:
+        tf.logging.info('Skipping dataset due to zero weights: %s', input_file)
+        continue
+      tf.logging.info('Num readers for dataset [%s]: %d', input_file,
+                      num_readers)
       records_dataset = _read_dataset_internal(file_read_func, [input_file],
-                                               config, filename_shard_fn)
+                                               num_readers, config,
+                                               filename_shard_fn)
+      dataset_weights.append(weight)
       records_datasets.append(records_dataset)
-    dataset_weights = list(config.sample_from_datasets_weights)
     return tf.data.experimental.sample_from_datasets(records_datasets,
                                                      dataset_weights)
   else:
     tf.logging.info('Reading unweighted datasets: %s' % input_files)
-    return _read_dataset_internal(file_read_func, input_files, config,
-                                  filename_shard_fn)
+    return _read_dataset_internal(file_read_func, input_files,
+                                  config.num_readers, config,
+                                  filename_shard_fn)


 def shard_function_for_context(input_context):
```
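Note: the new loop splits `config.num_readers` across datasets in proportion to their sampling weights, skipping zero-weight datasets entirely. A standalone sketch of the same proportional split plus weighted interleaving, with toy datasets and weights assumed for illustration:

```python
import math
import tensorflow as tf

num_readers = 8
weights = [0.7, 0.2, 0.1]
datasets = [tf.data.Dataset.from_tensors(i).repeat() for i in range(3)]

# Proportional reader allocation, mirroring the diff: ceil keeps low-weight
# datasets from silently losing all readers unless their weight is zero.
per_dataset_readers = [
    math.ceil(num_readers * w / sum(weights)) for w in weights]
print(per_dataset_readers)  # [6, 2, 1]

# Weighted interleave of the per-dataset record streams.
mixed = tf.data.experimental.sample_from_datasets(datasets, weights)
for value in mixed.take(5):
  print(value.numpy())
```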
research/object_detection/builders/decoder_builder.py

```diff
@@ -60,7 +60,9 @@ def build(input_reader_config):
         num_keypoints=input_reader_config.num_keypoints,
         expand_hierarchy_labels=input_reader_config.expand_labels_hierarchy,
         load_dense_pose=input_reader_config.load_dense_pose,
-        load_track_id=input_reader_config.load_track_id)
+        load_track_id=input_reader_config.load_track_id,
+        load_keypoint_depth_features=input_reader_config.
+        load_keypoint_depth_features)
     return decoder
   elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
     decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
```
research/object_detection/builders/decoder_builder_test.py

```diff
@@ -65,6 +65,8 @@ class DecoderBuilderTest(test_case.TestCase):
         'image/object/bbox/ymax': dataset_util.float_list_feature([1.0]),
         'image/object/class/label': dataset_util.int64_list_feature([2]),
         'image/object/mask': dataset_util.float_list_feature(flat_mask),
+        'image/object/keypoint/x': dataset_util.float_list_feature([1.0, 1.0]),
+        'image/object/keypoint/y': dataset_util.float_list_feature([1.0, 1.0])
     }
     if has_additional_channels:
       additional_channels_key = 'image/additional_channels/encoded'
@@ -188,6 +190,28 @@ class DecoderBuilderTest(test_case.TestCase):
     masks = self.execute_cpu(graph_fn, [])
     self.assertAllEqual((1, 4, 5), masks.shape)

+  def test_build_tf_record_input_reader_and_load_keypoint_depth(self):
+    input_reader_text_proto = """
+      load_keypoint_depth_features: true
+      num_keypoints: 2
+      tf_record_input_reader {}
+    """
+    input_reader_proto = input_reader_pb2.InputReader()
+    text_format.Parse(input_reader_text_proto, input_reader_proto)
+    decoder = decoder_builder.build(input_reader_proto)
+    serialized_example = self._make_serialized_tf_example()
+
+    def graph_fn():
+      tensor_dict = decoder.decode(serialized_example)
+      return (tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths],
+              tensor_dict[
+                  fields.InputDataFields.groundtruth_keypoint_depth_weights])
+
+    (kpts_depths, kpts_depth_weights) = self.execute_cpu(graph_fn, [])
+    self.assertAllEqual((1, 2), kpts_depths.shape)
+    self.assertAllEqual((1, 2), kpts_depth_weights.shape)
+
 if __name__ == '__main__':
   tf.test.main()
```
research/object_detection/builders/hyperparams_builder.py

```diff
@@ -20,7 +20,11 @@ import tf_slim as slim
 from object_detection.core import freezable_batch_norm
 from object_detection.protos import hyperparams_pb2
 from object_detection.utils import context_manager
 from object_detection.utils import tf_version
+
+# pylint: disable=g-import-not-at-top
+if tf_version.is_tf2():
+  from object_detection.core import freezable_sync_batch_norm
+# pylint: enable=g-import-not-at-top
@@ -60,9 +64,14 @@ class KerasLayerHyperparams(object):
           'hyperparams_pb.Hyperparams.')

     self._batch_norm_params = None
+    self._use_sync_batch_norm = False
     if hyperparams_config.HasField('batch_norm'):
       self._batch_norm_params = _build_keras_batch_norm_params(
           hyperparams_config.batch_norm)
+    elif hyperparams_config.HasField('sync_batch_norm'):
+      self._use_sync_batch_norm = True
+      self._batch_norm_params = _build_keras_batch_norm_params(
+          hyperparams_config.sync_batch_norm)

     self._force_use_bias = hyperparams_config.force_use_bias
     self._activation_fn = _build_activation_fn(hyperparams_config.activation)
@@ -133,10 +142,12 @@ class KerasLayerHyperparams(object):
       is False)
     """
     if self.use_batch_norm():
-      return freezable_batch_norm.FreezableBatchNorm(
-          training=training,
-          **self.batch_norm_params(**overrides)
-      )
+      if self._use_sync_batch_norm:
+        return freezable_sync_batch_norm.FreezableSyncBatchNorm(
+            training=training, **self.batch_norm_params(**overrides))
+      else:
+        return freezable_batch_norm.FreezableBatchNorm(
+            training=training, **self.batch_norm_params(**overrides))
     else:
       return tf.keras.layers.Lambda(tf.identity)
@@ -219,6 +230,10 @@ def build(hyperparams_config, is_training):
     raise ValueError('Hyperparams force_use_bias only supported by '
                      'KerasLayerHyperparams.')

+  if hyperparams_config.HasField('sync_batch_norm'):
+    raise ValueError('Hyperparams sync_batch_norm only supported by '
+                     'KerasLayerHyperparams.')
+
   normalizer_fn = None
   batch_norm_params = None
   if hyperparams_config.HasField('batch_norm'):
```
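Note: with these hunks, a Keras-based model can request synchronized batch norm through a `sync_batch_norm` block in its hyperparams proto. A hedged usage sketch follows; the `build_batch_norm` method name and the exact proto field spellings are assumptions based on the hunks above, and the numeric values are invented:

```python
from google.protobuf import text_format
from object_detection.builders import hyperparams_builder
from object_detection.protos import hyperparams_pb2

# Assumed config fragment; `sync_batch_norm` reuses the BatchNorm message,
# so its fields mirror a regular `batch_norm` block.
hyperparams_text = """
  regularizer { l2_regularizer {} }
  initializer { truncated_normal_initializer {} }
  sync_batch_norm { decay: 0.997 epsilon: 0.001 scale: true }
"""
config = hyperparams_pb2.Hyperparams()
text_format.Parse(hyperparams_text, config)

hyperparams = hyperparams_builder.KerasLayerHyperparams(config)
# With `sync_batch_norm` set, the batch-norm factory returns a
# FreezableSyncBatchNorm layer instead of FreezableBatchNorm.
batch_norm_layer = hyperparams.build_batch_norm(training=True)
```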
research/object_detection/builders/model_builder.py

```diff
@@ -1039,7 +1039,10 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
   if center_net_config.HasField('temporal_offset_task'):
     temporal_offset_params = temporal_offset_proto_to_params(
         center_net_config.temporal_offset_task)
-
+  non_max_suppression_fn = None
+  if center_net_config.HasField('post_processing'):
+    non_max_suppression_fn, _ = post_processing_builder.build(
+        center_net_config.post_processing)
   return center_net_meta_arch.CenterNetMetaArch(
       is_training=is_training,
       add_summaries=add_summaries,
@@ -1054,7 +1057,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
       track_params=track_params,
       temporal_offset_params=temporal_offset_params,
       use_depthwise=center_net_config.use_depthwise,
-      compute_heatmap_sparse=center_net_config.compute_heatmap_sparse)
+      compute_heatmap_sparse=center_net_config.compute_heatmap_sparse,
+      non_max_suppression_fn=non_max_suppression_fn)


 def _build_center_net_feature_extractor(
```
research/object_detection/core/freezable_batch_norm_tf2_test.py

```diff
@@ -17,25 +17,40 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import unittest
+from absl.testing import parameterized
 import numpy as np
 from six.moves import zip
-import tensorflow.compat.v1 as tf
+import tensorflow as tf

 from object_detection.core import freezable_batch_norm
+from object_detection.utils import tf_version
+
+# pylint: disable=g-import-not-at-top
+if tf_version.is_tf2():
+  from object_detection.core import freezable_sync_batch_norm
+# pylint: enable=g-import-not-at-top


-class FreezableBatchNormTest(tf.test.TestCase):
+@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
+class FreezableBatchNormTest(tf.test.TestCase, parameterized.TestCase):
   """Tests for FreezableBatchNorm operations."""

-  def _build_model(self, training=None):
+  def _build_model(self, use_sync_batch_norm, training=None):
     model = tf.keras.models.Sequential()
-    norm = freezable_batch_norm.FreezableBatchNorm(training=training,
-                                                   input_shape=(10,),
-                                                   momentum=0.8)
+    norm = None
+    if use_sync_batch_norm:
+      norm = freezable_sync_batch_norm.FreezableSyncBatchNorm(
+          training=training, input_shape=(10,), momentum=0.8)
+    else:
+      norm = freezable_batch_norm.FreezableBatchNorm(
+          training=training, input_shape=(10,), momentum=0.8)
     model.add(norm)
     return model, norm
@@ -43,8 +58,9 @@ class FreezableBatchNormTest(tf.test.TestCase):
     for source, target in zip(source_weights, target_weights):
       target.assign(source)

-  def _train_freezable_batch_norm(self, training_mean, training_var):
-    model, _ = self._build_model()
+  def _train_freezable_batch_norm(self, training_mean, training_var,
+                                  use_sync_batch_norm):
+    model, _ = self._build_model(use_sync_batch_norm=use_sync_batch_norm)
     model.compile(loss='mse', optimizer='sgd')

     # centered on training_mean, variance training_var
@@ -72,7 +88,8 @@ class FreezableBatchNormTest(tf.test.TestCase):
     np.testing.assert_allclose(out.numpy().mean(), 0.0, atol=1.5e-1)
     np.testing.assert_allclose(out.numpy().std(), 1.0, atol=1.5e-1)

-  def test_batchnorm_freezing_training_none(self):
+  @parameterized.parameters(True, False)
+  def test_batchnorm_freezing_training_none(self, use_sync_batch_norm):
     training_mean = 5.0
     training_var = 10.0
@@ -81,12 +98,13 @@ class FreezableBatchNormTest(tf.test.TestCase):
     # Initially train the batch norm, and save the weights
     trained_weights = self._train_freezable_batch_norm(training_mean,
-                                                       training_var)
+                                                       training_var,
+                                                       use_sync_batch_norm)

     # Load the batch norm weights, freezing training to True.
     # Apply the batch norm layer to testing data and ensure it is normalized
     # according to the batch statistics.
-    model, norm = self._build_model(training=True)
+    model, norm = self._build_model(use_sync_batch_norm, training=True)
     self._copy_weights(trained_weights, model.weights)

     # centered on testing_mean, variance testing_var
@@ -136,7 +154,8 @@ class FreezableBatchNormTest(tf.test.TestCase):
                            testing_mean, testing_var, training_arg,
                            training_mean, training_var)

-  def test_batchnorm_freezing_training_false(self):
+  @parameterized.parameters(True, False)
+  def test_batchnorm_freezing_training_false(self, use_sync_batch_norm):
     training_mean = 5.0
     training_var = 10.0
@@ -145,12 +164,13 @@ class FreezableBatchNormTest(tf.test.TestCase):
     # Initially train the batch norm, and save the weights
     trained_weights = self._train_freezable_batch_norm(training_mean,
-                                                       training_var)
+                                                       training_var,
+                                                       use_sync_batch_norm)

     # Load the batch norm back up, freezing training to False.
     # Apply the batch norm layer to testing data and ensure it is normalized
     # according to the training data's statistics.
-    model, norm = self._build_model(training=False)
+    model, norm = self._build_model(use_sync_batch_norm, training=False)
     self._copy_weights(trained_weights, model.weights)

     # centered on testing_mean, variance testing_var
```
research/object_detection/core/freezable_sync_batch_norm.py (new file, mode 100644)

```python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A freezable batch norm layer that uses Keras sync batch normalization."""
import tensorflow as tf


class FreezableSyncBatchNorm(
    tf.keras.layers.experimental.SyncBatchNormalization):
  """Sync Batch normalization layer (Ioffe and Szegedy, 2014).

  This is a `freezable` batch norm layer that supports setting the `training`
  parameter in the __init__ method rather than having to set it either via
  the Keras learning phase or via the `call` method parameter. This layer will
  forward all other parameters to the Keras `SyncBatchNormalization` layer.

  This class is necessary because Object Detection model training sometimes
  requires batch normalization layers to be `frozen` and used as if it were
  evaluation time, despite still training (and potentially using dropout
  layers).

  Like the default Keras SyncBatchNormalization layer, this will normalize the
  activations of the previous layer at each batch, i.e. applies a
  transformation that maintains the mean activation close to 0 and the
  activation standard deviation close to 1.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does
    not include the samples axis) when using this layer as the first layer in
    a model.

  Output shape:
    Same shape as input.

  References:
    - [Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
  """

  def __init__(self, training=None, **kwargs):
    """Constructor.

    Args:
      training: If False, the layer will normalize using the moving average and
        std. dev, without updating the learned avg and std. dev.
        If None or True, the layer will follow the keras SyncBatchNormalization
        layer strategy of checking the Keras learning phase at `call` time to
        decide what to do.
      **kwargs: The keyword arguments to forward to the keras
        SyncBatchNormalization layer constructor.
    """
    super(FreezableSyncBatchNorm, self).__init__(**kwargs)
    self._training = training

  def call(self, inputs, training=None):
    # Override the call arg only if the batchnorm is frozen. (Ignore None.)
    if self._training is False:  # pylint: disable=g-bool-id-comparison
      training = self._training
    return super(FreezableSyncBatchNorm, self).call(inputs, training=training)
```
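Note: a quick usage sketch of the freezing behaviour (illustrative shapes and values, not part of the commit):

```python
import tensorflow as tf
from object_detection.core import freezable_sync_batch_norm

# Build a frozen sync batch norm: it always normalizes with its moving
# mean/variance, even when the surrounding model runs with training=True.
frozen_bn = freezable_sync_batch_norm.FreezableSyncBatchNorm(
    training=False, input_shape=(10,), momentum=0.8)

model = tf.keras.models.Sequential([frozen_bn])
x = tf.random.normal([4, 10])
# The layer ignores the call-time training flag because it was frozen above,
# so its moving statistics are neither used for batch stats nor updated.
y = model(x, training=True)
```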
research/object_detection/core/model.py

```diff
@@ -315,7 +315,9 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
                           is_annotated_list=None,
                           groundtruth_labeled_classes=None,
                           groundtruth_verified_neg_classes=None,
-                          groundtruth_not_exhaustive_classes=None):
+                          groundtruth_not_exhaustive_classes=None,
+                          groundtruth_keypoint_depths_list=None,
+                          groundtruth_keypoint_depth_weights_list=None):
     """Provide groundtruth tensors.

     Args:
@@ -379,6 +381,11 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
       groundtruth_not_exhaustive_classes: A list of 1-D tf.float32 tensors of
         shape [num_classes], containing a K-hot representation of classes
        which don't have all of their instances marked exhaustively.
+      groundtruth_keypoint_depths_list: a list of 2-D tf.float32 tensors
+        of shape [num_boxes, num_keypoints] containing keypoint relative depths.
+      groundtruth_keypoint_depth_weights_list: a list of 2-D tf.float32 tensors
+        of shape [num_boxes, num_keypoints] containing the weights of the
+        relative depths.
     """
     self._groundtruth_lists[fields.BoxListFields.boxes] = groundtruth_boxes_list
     self._groundtruth_lists[
@@ -399,6 +406,14 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
     self._groundtruth_lists[
         fields.BoxListFields.keypoint_visibilities] = (
             groundtruth_keypoint_visibilities_list)
+    if groundtruth_keypoint_depths_list:
+      self._groundtruth_lists[
+          fields.BoxListFields.keypoint_depths] = (
+              groundtruth_keypoint_depths_list)
+    if groundtruth_keypoint_depth_weights_list:
+      self._groundtruth_lists[
+          fields.BoxListFields.keypoint_depth_weights] = (
+              groundtruth_keypoint_depth_weights_list)
     if groundtruth_dp_num_points_list:
       self._groundtruth_lists[fields.BoxListFields.densepose_num_points] = (
```
research/object_detection/core/post_processing.py

```diff
@@ -26,6 +26,7 @@ import tensorflow.compat.v1 as tf
 from object_detection.core import box_list
 from object_detection.core import box_list_ops
+from object_detection.core import keypoint_ops
 from object_detection.core import standard_fields as fields
 from object_detection.utils import shape_utils
@@ -379,6 +380,11 @@ def _clip_window_prune_boxes(sorted_boxes, clip_window, pad_to_max_output_size,
   if change_coordinate_frame:
     sorted_boxes = box_list_ops.change_coordinate_frame(sorted_boxes,
                                                         clip_window)
+    if sorted_boxes.has_field(fields.BoxListFields.keypoints):
+      sorted_keypoints = sorted_boxes.get_field(fields.BoxListFields.keypoints)
+      sorted_keypoints = keypoint_ops.change_coordinate_frame(sorted_keypoints,
+                                                              clip_window)
+      sorted_boxes.set_field(fields.BoxListFields.keypoints, sorted_keypoints)
   return sorted_boxes, num_valid_nms_boxes_cumulative
```
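Note: the new branch keeps keypoints consistent with boxes when re-framing to the clip window. The underlying arithmetic is a shift-and-scale of normalized coordinates; a hand-rolled equivalent for intuition (`keypoint_ops.change_coordinate_frame` is the actual library call, this sketch just mimics it):

```python
import tensorflow as tf

# window = [y_min, x_min, y_max, x_max] in normalized image coordinates.
window = tf.constant([0.25, 0.25, 0.75, 0.75])
keypoints = tf.constant([[[0.5, 0.5], [0.25, 0.75]]])  # [num_inst, num_kp, 2]

win_min = tf.stack([window[0], window[1]])   # top-left corner of the window
win_size = tf.stack([window[2] - window[0],
                     window[3] - window[1]])  # window height and width
new_keypoints = (keypoints - win_min) / win_size
print(new_keypoints.numpy())  # [[[0.5, 0.5], [0.0, 1.0]]]
```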
research/object_detection/core/preprocessor.py

```diff
@@ -571,6 +571,8 @@ def random_horizontal_flip(image,
                            keypoint_visibilities=None,
                            densepose_part_ids=None,
                            densepose_surface_coords=None,
+                           keypoint_depths=None,
+                           keypoint_depth_weights=None,
                            keypoint_flip_permutation=None,
                            probability=0.5,
                            seed=None,
@@ -602,6 +604,12 @@ def random_horizontal_flip(image,
                              (y, x) are the normalized image coordinates for a
                              sampled point, and (v, u) is the surface
                              coordinate for the part.
+    keypoint_depths: (optional) rank 2 float32 tensor with shape
+                     [num_instances, num_keypoints] representing the relative
+                     depth of the keypoints.
+    keypoint_depth_weights: (optional) rank 2 float32 tensor with shape
+                            [num_instances, num_keypoints] representing the
+                            weights of the relative depth of the keypoints.
     keypoint_flip_permutation: rank 1 int32 tensor containing the keypoint flip
                                permutation.
     probability: the probability of performing this augmentation.
@@ -631,6 +639,10 @@ def random_horizontal_flip(image,
                         [num_instances, num_points].
     densepose_surface_coords: rank 3 float32 tensor with shape
                               [num_instances, num_points, 4].
+    keypoint_depths: rank 2 float32 tensor with shape [num_instances,
+                     num_keypoints].
+    keypoint_depth_weights: rank 2 float32 tensor with shape [num_instances,
+                            num_keypoints].

   Raises:
     ValueError: if keypoints are provided but keypoint_flip_permutation is not.
@@ -708,6 +720,21 @@ def random_horizontal_flip(image,
                                     lambda: (densepose_part_ids,
                                              densepose_surface_coords))
       result.extend(densepose_tensors)

+    # flip keypoint depths and weights.
+    if (keypoint_depths is not None and
+        keypoint_flip_permutation is not None):
+      kpt_flip_perm = keypoint_flip_permutation
+      keypoint_depths = tf.cond(
+          do_a_flip_random,
+          lambda: tf.gather(keypoint_depths, kpt_flip_perm, axis=1),
+          lambda: keypoint_depths)
+      keypoint_depth_weights = tf.cond(
+          do_a_flip_random,
+          lambda: tf.gather(keypoint_depth_weights, kpt_flip_perm, axis=1),
+          lambda: keypoint_depth_weights)
+      result.append(keypoint_depths)
+      result.append(keypoint_depth_weights)
+
     return tuple(result)
@@ -4293,7 +4320,8 @@ def get_default_func_arg_map(include_label_weights=True,
                              include_instance_masks=False,
                              include_keypoints=False,
                              include_keypoint_visibilities=False,
-                             include_dense_pose=False):
+                             include_dense_pose=False,
+                             include_keypoint_depths=False):
   """Returns the default mapping from a preprocessor function to its args.

   Args:
@@ -4311,6 +4339,8 @@ def get_default_func_arg_map(include_label_weights=True,
       the keypoint visibilities, too.
     include_dense_pose: If True, preprocessing functions will modify the
       DensePose labels, too.
+    include_keypoint_depths: If True, preprocessing functions will modify the
+      keypoint depth labels, too.

   Returns:
     A map from preprocessing functions to the arguments they receive.
@@ -4353,6 +4383,13 @@ def get_default_func_arg_map(include_label_weights=True,
     groundtruth_dp_part_ids = (
         fields.InputDataFields.groundtruth_dp_part_ids)
     groundtruth_dp_surface_coords = (
         fields.InputDataFields.groundtruth_dp_surface_coords)
+  groundtruth_keypoint_depths = None
+  groundtruth_keypoint_depth_weights = None
+  if include_keypoint_depths:
+    groundtruth_keypoint_depths = (
+        fields.InputDataFields.groundtruth_keypoint_depths)
+    groundtruth_keypoint_depth_weights = (
+        fields.InputDataFields.groundtruth_keypoint_depth_weights)

   prep_func_arg_map = {
       normalize_image: (fields.InputDataFields.image,),
@@ -4364,6 +4401,8 @@ def get_default_func_arg_map(include_label_weights=True,
           groundtruth_keypoint_visibilities,
           groundtruth_dp_part_ids,
           groundtruth_dp_surface_coords,
+          groundtruth_keypoint_depths,
+          groundtruth_keypoint_depth_weights,
       ),
       random_vertical_flip: (
           fields.InputDataFields.image,
```
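Note: unlike keypoint coordinates, depths need no mirroring on a horizontal flip; only the left/right keypoint identities swap, which is a single `tf.gather` along the keypoint axis with the flip permutation. A toy check, using a hypothetical 3-keypoint permutation that keeps keypoint 0 and swaps keypoints 1 and 2 (the same values the test below expects):

```python
import tensorflow as tf

keypoint_depths = tf.constant([[1.0, 0.9, 0.8],
                               [0.7, 0.6, 0.5]])
flip_permutation = [0, 2, 1]  # e.g. nose stays, left/right eyes swap

flipped = tf.gather(keypoint_depths, flip_permutation, axis=1)
print(flipped.numpy())  # [[1.0, 0.8, 0.9], [0.7, 0.5, 0.6]]
```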
research/object_detection/core/preprocessor_test.py

```diff
@@ -105,6 +105,17 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
     ])
     return keypoints, keypoint_visibilities

+  def createTestKeypointDepths(self):
+    keypoint_depths = tf.constant([
+        [1.0, 0.9, 0.8],
+        [0.7, 0.6, 0.5]
+    ], dtype=tf.float32)
+    keypoint_depth_weights = tf.constant([
+        [0.5, 0.6, 0.7],
+        [0.8, 0.9, 1.0]
+    ], dtype=tf.float32)
+    return keypoint_depths, keypoint_depth_weights
+
   def createTestKeypointsInsideCrop(self):
     keypoints = np.array([
         [[0.4, 0.4], [0.5, 0.5], [0.6, 0.6]],
@@ -713,6 +724,59 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
         test_keypoints=True)

+  def testRunRandomHorizontalFlipWithKeypointDepth(self):
+
+    def graph_fn():
+      image_height = 3
+      image_width = 3
+      images = tf.random_uniform([1, image_height, image_width, 3])
+      boxes = self.createTestBoxes()
+      masks = self.createTestMasks()
+      keypoints, keypoint_visibilities = self.createTestKeypoints()
+      keypoint_depths, keypoint_depth_weights = self.createTestKeypointDepths()
+      keypoint_flip_permutation = self.createKeypointFlipPermutation()
+      tensor_dict = {
+          fields.InputDataFields.image: images,
+          fields.InputDataFields.groundtruth_boxes: boxes,
+          fields.InputDataFields.groundtruth_instance_masks: masks,
+          fields.InputDataFields.groundtruth_keypoints: keypoints,
+          fields.InputDataFields.groundtruth_keypoint_visibilities:
+              keypoint_visibilities,
+          fields.InputDataFields.groundtruth_keypoint_depths: keypoint_depths,
+          fields.InputDataFields.groundtruth_keypoint_depth_weights:
+              keypoint_depth_weights,
+      }
+      preprocess_options = [(preprocessor.random_horizontal_flip, {
+          'keypoint_flip_permutation': keypoint_flip_permutation,
+          'probability': 1.0
+      })]
+      preprocessor_arg_map = preprocessor.get_default_func_arg_map(
+          include_instance_masks=True,
+          include_keypoints=True,
+          include_keypoint_visibilities=True,
+          include_dense_pose=False,
+          include_keypoint_depths=True)
+      tensor_dict = preprocessor.preprocess(
+          tensor_dict, preprocess_options, func_arg_map=preprocessor_arg_map)
+      keypoint_depths = tensor_dict[
+          fields.InputDataFields.groundtruth_keypoint_depths]
+      keypoint_depth_weights = tensor_dict[
+          fields.InputDataFields.groundtruth_keypoint_depth_weights]
+      output_tensors = [keypoint_depths, keypoint_depth_weights]
+      return output_tensors
+
+    output_tensors = self.execute_cpu(graph_fn, [])
+    expected_keypoint_depths = [[1.0, 0.8, 0.9], [0.7, 0.5, 0.6]]
+    expected_keypoint_depth_weights = [[0.5, 0.7, 0.6], [0.8, 1.0, 0.9]]
+    self.assertAllClose(expected_keypoint_depths, output_tensors[0])
+    self.assertAllClose(expected_keypoint_depth_weights, output_tensors[1])
+
   def testRandomVerticalFlip(self):

     def graph_fn():
```
research/object_detection/core/standard_fields.py

```diff
@@ -67,6 +67,9 @@ class InputDataFields(object):
     groundtruth_instance_boundaries: ground truth instance boundaries.
     groundtruth_instance_classes: instance mask-level class labels.
     groundtruth_keypoints: ground truth keypoints.
+    groundtruth_keypoint_depths: Relative depth of the keypoints.
+    groundtruth_keypoint_depth_weights: Weights of the relative depth of the
+      keypoints.
     groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
     groundtruth_keypoint_weights: groundtruth weight factor for keypoints.
     groundtruth_label_weights: groundtruth label weights.
@@ -122,6 +125,8 @@ class InputDataFields(object):
   groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
   groundtruth_instance_classes = 'groundtruth_instance_classes'
   groundtruth_keypoints = 'groundtruth_keypoints'
+  groundtruth_keypoint_depths = 'groundtruth_keypoint_depths'
+  groundtruth_keypoint_depth_weights = 'groundtruth_keypoint_depth_weights'
   groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
   groundtruth_keypoint_weights = 'groundtruth_keypoint_weights'
   groundtruth_label_weights = 'groundtruth_label_weights'
@@ -162,6 +167,7 @@ class DetectionResultFields(object):
     detection_boundaries: contains an object boundary for each detection box.
     detection_keypoints: contains detection keypoints for each detection box.
     detection_keypoint_scores: contains detection keypoint scores.
+    detection_keypoint_depths: contains detection keypoint depths.
     num_detections: number of detections in the batch.
     raw_detection_boxes: contains decoded detection boxes without Non-Max
       suppression.
@@ -183,6 +189,7 @@ class DetectionResultFields(object):
   detection_boundaries = 'detection_boundaries'
   detection_keypoints = 'detection_keypoints'
   detection_keypoint_scores = 'detection_keypoint_scores'
+  detection_keypoint_depths = 'detection_keypoint_depths'
   detection_embeddings = 'detection_embeddings'
   detection_offsets = 'detection_temporal_offsets'
   num_detections = 'num_detections'
@@ -205,6 +212,8 @@ class BoxListFields(object):
     keypoints: keypoints per bounding box.
     keypoint_visibilities: keypoint visibilities per bounding box.
     keypoint_heatmaps: keypoint heatmaps per bounding box.
+    keypoint_depths: keypoint depths per bounding box.
+    keypoint_depth_weights: keypoint depth weights per bounding box.
     densepose_num_points: number of DensePose points per bounding box.
     densepose_part_ids: DensePose part ids per bounding box.
     densepose_surface_coords: DensePose surface coordinates per bounding box.
@@ -223,6 +232,8 @@ class BoxListFields(object):
   keypoints = 'keypoints'
   keypoint_visibilities = 'keypoint_visibilities'
   keypoint_heatmaps = 'keypoint_heatmaps'
+  keypoint_depths = 'keypoint_depths'
+  keypoint_depth_weights = 'keypoint_depth_weights'
   densepose_num_points = 'densepose_num_points'
   densepose_part_ids = 'densepose_part_ids'
   densepose_surface_coords = 'densepose_surface_coords'
```
research/object_detection/data_decoders/tf_example_decoder.py

```diff
@@ -139,7 +139,8 @@ class TfExampleDecoder(data_decoder.DataDecoder):
                load_context_features=False,
                expand_hierarchy_labels=False,
                load_dense_pose=False,
-               load_track_id=False):
+               load_track_id=False,
+               load_keypoint_depth_features=False):
     """Constructor sets keys_to_features and items_to_handlers.

     Args:
@@ -172,6 +173,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
         the labels are expanded to descendants.
       load_dense_pose: Whether to load DensePose annotations.
       load_track_id: Whether to load tracking annotations.
+      load_keypoint_depth_features: Whether to load the keypoint depth features
+        including keypoint relative depths and weights. If this field is set to
+        True but no keypoint depth features are in the input tf.Example, then
+        default values will be populated.

     Raises:
       ValueError: If `instance_mask_type` option is not one of
@@ -180,6 +185,7 @@ class TfExampleDecoder(data_decoder.DataDecoder):
       ValueError: If `expand_labels_hierarchy` is True, but the
         `label_map_proto_file` is not provided.
     """
+    # TODO(rathodv): delete unused `use_display_name` argument once we change
+    # other decoders to handle label maps similarly.
     del use_display_name
@@ -331,6 +337,23 @@ class TfExampleDecoder(data_decoder.DataDecoder):
           slim_example_decoder.ItemHandlerCallback(
               ['image/object/keypoint/x', 'image/object/keypoint/visibility'],
               self._reshape_keypoint_visibilities))
+      if load_keypoint_depth_features:
+        self.keys_to_features['image/object/keypoint/z'] = (
+            tf.VarLenFeature(tf.float32))
+        self.keys_to_features['image/object/keypoint/z/weights'] = (
+            tf.VarLenFeature(tf.float32))
+        self.items_to_handlers[
+            fields.InputDataFields.groundtruth_keypoint_depths] = (
+                slim_example_decoder.ItemHandlerCallback(
+                    ['image/object/keypoint/x', 'image/object/keypoint/z'],
+                    self._reshape_keypoint_depths))
+        self.items_to_handlers[
+            fields.InputDataFields.groundtruth_keypoint_depth_weights] = (
+                slim_example_decoder.ItemHandlerCallback(
+                    ['image/object/keypoint/x',
+                     'image/object/keypoint/z/weights'],
+                    self._reshape_keypoint_depth_weights))
     if load_instance_masks:
       if instance_mask_type in (input_reader_pb2.DEFAULT,
                                 input_reader_pb2.NUMERICAL_MASKS):
@@ -601,6 +624,73 @@ class TfExampleDecoder(data_decoder.DataDecoder):
     keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2])
     return keypoints

+  def _reshape_keypoint_depths(self, keys_to_tensors):
+    """Reshape keypoint depths.
+
+    The keypoint depths are reshaped to [num_instances, num_keypoints]. The
+    keypoint depth tensor is expected to have the same shape as the keypoint x
+    (or y) tensors. If not (usually because the example does not have the depth
+    groundtruth), then default depth values (zero) are provided.
+
+    Args:
+      keys_to_tensors: a dictionary from keys to tensors. Expected keys are:
+        'image/object/keypoint/x'
+        'image/object/keypoint/z'
+
+    Returns:
+      A 2-D float tensor of shape [num_instances, num_keypoints] with values
+      representing the keypoint depths.
+    """
+    x = keys_to_tensors['image/object/keypoint/x']
+    z = keys_to_tensors['image/object/keypoint/z']
+    if isinstance(z, tf.SparseTensor):
+      z = tf.sparse_tensor_to_dense(z)
+    if isinstance(x, tf.SparseTensor):
+      x = tf.sparse_tensor_to_dense(x)
+
+    default_z = tf.zeros_like(x)
+    # Use keypoint depth groundtruth if provided, otherwise use the default
+    # depth value.
+    z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
+                true_fn=lambda: z,
+                false_fn=lambda: default_z)
+    z = tf.reshape(z, [-1, self._num_keypoints])
+    return z
+
+  def _reshape_keypoint_depth_weights(self, keys_to_tensors):
+    """Reshape keypoint depth weights.
+
+    The keypoint depth weights are reshaped to [num_instances, num_keypoints].
+    The keypoint depth weights tensor is expected to have the same shape as the
+    keypoint x (or y) tensors. If not (usually because the example does not
+    have the depth weights groundtruth), then default weight values (zero) are
+    provided.
+
+    Args:
+      keys_to_tensors: a dictionary from keys to tensors. Expected keys are:
+        'image/object/keypoint/x'
+        'image/object/keypoint/z/weights'
+
+    Returns:
+      A 2-D float tensor of shape [num_instances, num_keypoints] with values
+      representing the keypoint depth weights.
+    """
+    x = keys_to_tensors['image/object/keypoint/x']
+    z = keys_to_tensors['image/object/keypoint/z/weights']
+    if isinstance(z, tf.SparseTensor):
+      z = tf.sparse_tensor_to_dense(z)
+    if isinstance(x, tf.SparseTensor):
+      x = tf.sparse_tensor_to_dense(x)
+
+    default_z = tf.zeros_like(x)
+    # Use keypoint depth weights if provided, otherwise use the default
+    # values.
+    z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
+                true_fn=lambda: z,
+                false_fn=lambda: default_z)
+    z = tf.reshape(z, [-1, self._num_keypoints])
+    return z
+
   def _reshape_keypoint_visibilities(self, keys_to_tensors):
     """Reshape keypoint visibilities.
```
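Note: both new reshape helpers rely on the same default-fill pattern: compare the element count of the always-present `x` tensor with the optional `z` tensor and fall back to zeros when the example carries no depth annotations. A minimal standalone version of that pattern, TF1-style to match the decoder's `tensorflow.compat.v1` usage (the helper name `depths_or_default` is hypothetical):

```python
import tensorflow.compat.v1 as tf

def depths_or_default(x, z, num_keypoints):
  """Returns z reshaped to [num_instances, num_keypoints], or zeros."""
  default_z = tf.zeros_like(x)
  # If the example carried no depths, tf.size(z) won't match tf.size(x)
  # and the zero default is used instead.
  z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
              true_fn=lambda: z,
              false_fn=lambda: default_z)
  return tf.reshape(z, [-1, num_keypoints])
```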
research/object_detection/data_decoders/tf_example_decoder_test.py

```diff
@@ -275,6 +275,124 @@ class TfExampleDecoderTest(test_case.TestCase):
     self.assertAllEqual(expected_boxes,
                         tensor_dict[fields.InputDataFields.groundtruth_boxes])

+  def testDecodeKeypointDepth(self):
+    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg, _ = self._create_encoded_and_decoded_data(
+        image_tensor, 'jpeg')
+    bbox_ymins = [0.0, 4.0]
+    bbox_xmins = [1.0, 5.0]
+    bbox_ymaxs = [2.0, 6.0]
+    bbox_xmaxs = [3.0, 7.0]
+    keypoint_ys = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
+    keypoint_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+    keypoint_visibility = [1, 2, 0, 1, 0, 2]
+    keypoint_depths = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
+    keypoint_depth_weights = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
+
+    def graph_fn():
+      example = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  'image/encoded':
+                      dataset_util.bytes_feature(encoded_jpeg),
+                  'image/format':
+                      dataset_util.bytes_feature(six.b('jpeg')),
+                  'image/object/bbox/ymin':
+                      dataset_util.float_list_feature(bbox_ymins),
+                  'image/object/bbox/xmin':
+                      dataset_util.float_list_feature(bbox_xmins),
+                  'image/object/bbox/ymax':
+                      dataset_util.float_list_feature(bbox_ymaxs),
+                  'image/object/bbox/xmax':
+                      dataset_util.float_list_feature(bbox_xmaxs),
+                  'image/object/keypoint/y':
+                      dataset_util.float_list_feature(keypoint_ys),
+                  'image/object/keypoint/x':
+                      dataset_util.float_list_feature(keypoint_xs),
+                  'image/object/keypoint/z':
+                      dataset_util.float_list_feature(keypoint_depths),
+                  'image/object/keypoint/z/weights':
+                      dataset_util.float_list_feature(keypoint_depth_weights),
+                  'image/object/keypoint/visibility':
+                      dataset_util.int64_list_feature(keypoint_visibility),
+              })).SerializeToString()
+
+      example_decoder = tf_example_decoder.TfExampleDecoder(
+          num_keypoints=3, load_keypoint_depth_features=True)
+      output = example_decoder.decode(tf.convert_to_tensor(example))
+
+      self.assertAllEqual(
+          (output[fields.InputDataFields.groundtruth_keypoint_depths]
+           .get_shape().as_list()), [2, 3])
+      self.assertAllEqual(
+          (output[fields.InputDataFields.groundtruth_keypoint_depth_weights]
+           .get_shape().as_list()), [2, 3])
+      return output
+
+    tensor_dict = self.execute_cpu(graph_fn, [])
+
+    expected_keypoint_depths = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    self.assertAllClose(
+        expected_keypoint_depths,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths])
+
+    expected_keypoint_depth_weights = [[1.0, 0.9, 0.8], [0.7, 0.6, 0.5]]
+    self.assertAllClose(
+        expected_keypoint_depth_weights,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depth_weights])
+
+  def testDecodeKeypointDepthNoDepth(self):
+    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg, _ = self._create_encoded_and_decoded_data(
+        image_tensor, 'jpeg')
+    bbox_ymins = [0.0, 4.0]
+    bbox_xmins = [1.0, 5.0]
+    bbox_ymaxs = [2.0, 6.0]
+    bbox_xmaxs = [3.0, 7.0]
+    keypoint_ys = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
+    keypoint_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+    keypoint_visibility = [1, 2, 0, 1, 0, 2]
+
+    def graph_fn():
+      example = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  'image/encoded':
+                      dataset_util.bytes_feature(encoded_jpeg),
+                  'image/format':
+                      dataset_util.bytes_feature(six.b('jpeg')),
+                  'image/object/bbox/ymin':
+                      dataset_util.float_list_feature(bbox_ymins),
+                  'image/object/bbox/xmin':
+                      dataset_util.float_list_feature(bbox_xmins),
+                  'image/object/bbox/ymax':
+                      dataset_util.float_list_feature(bbox_ymaxs),
+                  'image/object/bbox/xmax':
+                      dataset_util.float_list_feature(bbox_xmaxs),
+                  'image/object/keypoint/y':
+                      dataset_util.float_list_feature(keypoint_ys),
+                  'image/object/keypoint/x':
+                      dataset_util.float_list_feature(keypoint_xs),
+                  'image/object/keypoint/visibility':
+                      dataset_util.int64_list_feature(keypoint_visibility),
+              })).SerializeToString()
+
+      example_decoder = tf_example_decoder.TfExampleDecoder(
+          num_keypoints=3, load_keypoint_depth_features=True)
+      output = example_decoder.decode(tf.convert_to_tensor(example))
+      return output
+
+    tensor_dict = self.execute_cpu(graph_fn, [])
+
+    expected_keypoints_depth_default = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
+    self.assertAllClose(
+        expected_keypoints_depth_default,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths])
+    self.assertAllClose(
+        expected_keypoints_depth_default,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depth_weights])
+
   def testDecodeKeypoint(self):
     image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
     encoded_jpeg, _ = self._create_encoded_and_decoded_data(
```
research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf2_test.py

```diff
@@ -56,8 +56,7 @@ class FakeModel(model.DetectionModel):
             value=conv_weight_scalar))

   def preprocess(self, inputs):
-    true_image_shapes = []  # Doesn't matter for the fake model.
-    return tf.identity(inputs), true_image_shapes
+    return tf.identity(inputs), exporter_lib_v2.get_true_shapes(inputs)

   def predict(self, preprocessed_inputs, true_image_shapes):
     return {'image': self._conv(preprocessed_inputs)}
```
research/object_detection/dataset_tools/context_rcnn/generate_embedding_data_tf2_test.py

```diff
@@ -54,8 +54,7 @@ class FakeModel(model.DetectionModel):
             value=conv_weight_scalar))

   def preprocess(self, inputs):
-    true_image_shapes = []  # Doesn't matter for the fake model.
-    return tf.identity(inputs), true_image_shapes
+    return tf.identity(inputs), exporter_lib_v2.get_true_shapes(inputs)

   def predict(self, preprocessed_inputs, true_image_shapes):
     return {'image': self._conv(preprocessed_inputs)}
```
research/object_detection/exporter_lib_tf2_test.py

```diff
@@ -51,8 +51,7 @@ class FakeModel(model.DetectionModel):
             value=conv_weight_scalar))

   def preprocess(self, inputs):
-    true_image_shapes = []  # Doesn't matter for the fake model.
-    return tf.identity(inputs), true_image_shapes
+    return tf.identity(inputs), exporter_lib_v2.get_true_shapes(inputs)

   def predict(self, preprocessed_inputs, true_image_shapes, **side_inputs):
     return_dict = {'image': self._conv(preprocessed_inputs)}
```