Unverified Commit 8518d053 authored by pkulzc, committed by GitHub

Open source MnasFPN and minor fixes to OD API (#8484)

310447280  by lzc:

    Internal change

--
310420845  by Zhichao Lu:

    Open source the internal Context RCNN code.

--
310362339  by Zhichao Lu:

    Internal change

--
310259448  by lzc:

    Update required TF version for OD API.

--
310252159  by Zhichao Lu:

    Port patch_ops_test to TF1/TF2 and TPUs.

--
310247180  by Zhichao Lu:

    Ignore keypoint heatmap loss in the regions/bounding boxes with target keypoint
    class but no valid keypoint annotations.

--
310178294  by Zhichao Lu:

    Open source MnasFPN
    https://arxiv.org/abs/1912.01106

--
310094222  by lzc:

    Internal changes.

--
310085250  by lzc:

    Internal Change.

--
310016447  by huizhongc:

    Remove unrecognized classes from labeled_classes.

--
310009470  by rathodv:

    Mark batcher.py as TF1 only.

--
310001984  by rathodv:

    Update core/preprocessor.py to be compatible with TF1/TF2.

--
309455035  by Zhichao Lu:

    Makes the freezable_batch_norm_test run w/ v2 behavior.

    The main change is that in v2, updates happen right away when running batch norm in training mode, so we need to restore the weights between batch norm calls to make sure the numerical checks all start from the same place.
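    A minimal illustrative sketch (assumed, not the test's exact code) of restoring weights between batch norm calls in TF2:

        import tensorflow as tf

        # In TF2, BatchNormalization updates its moving statistics immediately
        # when called with training=True, so a numerical check that calls the
        # layer repeatedly must reset its weights to start from the same state.
        bn = tf.keras.layers.BatchNormalization()
        x = tf.random.normal([4, 8])
        bn(x, training=True)               # build the layer; moving stats update
        saved_weights = bn.get_weights()   # snapshot after the first call
        bn(x, training=True)               # moving stats drift again
        bn.set_weights(saved_weights)      # restore so the next check starts identically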

--
309425881  by Zhichao Lu:

    Make TF1/TF2 optimizer builder tests explicit.

--
309408646  by Zhichao Lu:

    Make dataset builder tests TF1 and TF2 compatible.

--
309246305  by Zhichao Lu:

    Added the functionality of combining the person keypoints and object detection
    annotations in the binary that converts the COCO raw data to TfRecord.

--
309125076  by Zhichao Lu:

    Convert target_assigner_utils to TF1/TF2.

--
308966359  by huizhongc:

    Support SSD training with partially labeled groundtruth.

--
308937159  by rathodv:

    Update core/target_assigner.py to be compatible with TF1/TF2.

--
308774302  by Zhichao Lu:

    Internal

--
308732860  by rathodv:

    Make core/prefetcher.py  compatible with TF1 only.

--
308726984  by rathodv:

    Update core/multiclass_nms_test.py to be TF1/TF2 compatible.

--
308714718  by rathodv:

    Update core/region_similarity_calculator_test.py to be TF1/TF2 compatible.

--
308707960  by rathodv:

    Update core/minibatch_sampler_test.py to be TF1/TF2 compatible.

--
308700595  by rathodv:

    Update core/losses_test.py to be TF1/TF2 compatible and remove losses_test_v2.py

--
308361472  by rathodv:

    Update core/matcher_test.py to be TF1/TF2 compatible.

--
308335846  by Zhichao Lu:

    Updated the COCO evaluation logic and populated the groundtruth area
    information through. This change matches the groundtruth format expected by the
    COCO keypoint evaluation.

--
308256924  by rathodv:

    Update core/keypoints_ops_test.py to be TF1/TF2 compatible.

--
308256826  by rathodv:

    Update class_agnostic_nms_test.py to be TF1/TF2 compatible.

--
308256112  by rathodv:

    Update box_list_ops_test.py to be TF1/TF2 compatible.

--
308159360  by Zhichao Lu:

    Internal change

--
308145008  by Zhichao Lu:

    Added 'image/class/confidence' field in the TFExample decoder.

--
307651875  by rathodv:

    Refactor core/box_list.py to support TF1/TF2.

--
307651798  by rathodv:

    Modify box_coder.py base class to work with TF1/TF2.

--
307651652  by rathodv:

    Refactor core/balanced_positive_negative_sampler.py to support TF1/TF2.

--
307651571  by rathodv:

    Modify BoxCoders tests to use test_case:execute method to allow testing with TF1.X and TF2.X

--
307651480  by rathodv:

    Modify Matcher tests to use test_case:execute method to allow testing with TF1.X and TF2.X

--
307651409  by rathodv:

    Modify AnchorGenerator tests to use test_case:execute method to allow testing with TF1.X and TF2.X

--
307651314  by rathodv:

    Refactor model_builder to support TF1 or TF2 models based on TensorFlow version.

--
307092053  by Zhichao Lu:

    Use manager to save checkpoint.

--
307071352  by ronnyvotel:

    Fixing keypoint visibilities. Now by default, the visibility is marked True if the keypoint is labeled (regardless of whether it is visible or not).
    Also, if visibilities are not present in the dataset, they will be created based on whether the keypoint coordinates are finite (vis = True) or NaN (vis = False).
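    A small illustrative sketch (assumed, not the exact library code) of deriving visibilities from coordinates:

        import numpy as np

        # Keypoints with NaN coordinates are treated as not visible; all
        # finite keypoints default to visible.
        keypoints = np.array([[[150., 160.], [np.nan, np.nan], [170., 180.]]])
        visibilities = np.all(np.isfinite(keypoints), axis=-1)  # -> [[True, False, True]]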

--
307069557  by Zhichao Lu:

    Internal change to add few fields related to postprocessing parameters in
    center_net.proto and populate those parameters to the keypoint postprocessing
    functions.

--
307012091  by Zhichao Lu:

    Make Adam Optimizer's epsilon proto configurable.

    Potential issue: tf.compat.v1's AdamOptimizer has a default epsilon of 1e-08 ([doc-link](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/AdamOptimizer)) whereas tf.keras's Adam optimizer has a default epsilon of 1e-07 ([doc-link](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam)).
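    For illustration only (defaults taken from the linked docs), setting epsilon explicitly keeps the two optimizers numerically consistent:

        import tensorflow as tf

        # tf.compat.v1 Adam defaults to epsilon=1e-08; tf.keras Adam defaults to 1e-07.
        v1_adam = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3, epsilon=1e-8)
        keras_adam = tf.keras.optimizers.Adam(learning_rate=1e-3, epsilon=1e-8)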

--
306858598  by Zhichao Lu:

    Internal changes to update the CenterNet model:
    1) Modified eval job loss computation to avoid averaging over batches with zero loss.
    2) Updated CenterNet keypoint heatmap target assigner to apply box size to heatmap Gaussian standard deviation.
    3) Updated the CenterNet meta arch keypoint losses computation to apply weights outside of loss function.

--
306731223  by jonathanhuang:

    Internal change.

--
306549183  by rathodv:

    Internal Update.

--
306542930  by rathodv:

    Internal Update

--
306322697  by rathodv:

    Internal.

--
305345036  by Zhichao Lu:

    Adding COCO Camera Traps Json to tf.Example beam code

--
304104869  by lzc:

    Internal changes.

--
304068971  by jonathanhuang:

    Internal change.

--
304050469  by Zhichao Lu:

    Internal change.

--
303880642  by huizhongc:

    Support parsing partially labeled groundtruth.

--
303841743  by Zhichao Lu:

    Deprecate nms_on_host in SSDMetaArch.

--
303803204  by rathodv:

    Internal change.

--
303793895  by jonathanhuang:

    Internal change.

--
303467631  by rathodv:

    Py3 update for detection inference test.

--
303444542  by rathodv:

    Py3 update to metrics module

--
303421960  by rathodv:

    Update json_utils to python3.

--
302787583  by ronnyvotel:

    Coco results generator for submission to the coco test server.

--
302719091  by Zhichao Lu:

    Internal change to add the ResNet50 image feature extractor for CenterNet model.

--
302116230  by Zhichao Lu:

    Added the functions to overlay the heatmaps with images in visualization util
    library.

--
301888316  by Zhichao Lu:

    Fix checkpoint_filepath not defined error.

--
301840312  by ronnyvotel:

    Adding keypoint_scores to visualizations.

--
301683475  by ronnyvotel:

    Introducing the ability to preprocess `keypoint_visibilities`.

    Some data augmentation ops such as random crop can filter instances and keypoints. It's important to also filter keypoint visibilities, so that the groundtruth tensors are always in alignment.
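    An illustrative sketch (names assumed) of keeping visibilities aligned when a crop drops instances:

        import numpy as np

        keypoints = np.zeros([3, 4, 2])                            # [num_instances, num_keypoints, 2]
        keypoint_visibilities = np.full([3, 4], 2, dtype=np.int32)
        kept = np.array([0, 2])                                    # instances surviving the crop
        # Apply the same index selection to both tensors so they stay aligned.
        keypoints = keypoints[kept]
        keypoint_visibilities = keypoint_visibilities[kept]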

--
301532344  by Zhichao Lu:

    Don't use tf.divide since "Quantization not yet supported for op: DIV"

--
301480348  by ronnyvotel:

    Introducing keypoint evaluation into model lib v2.
    Also, making some fixes to coco keypoint evaluation.

--
301454018  by Zhichao Lu:

    Added the image summary to visualize the train/eval input images and eval's
    prediction/groundtruth side-by-side image.

--
301317527  by Zhichao Lu:

    Updated the random_absolute_pad_image function in the preprocessor library to
    support the keypoints argument.

--
301300324  by Zhichao Lu:

    Apply name change(experimental_run_v2 -> run) for all callers in Tensorflow.

--
301297115  by ronnyvotel:

    Utility function for setting keypoint visibilities based on keypoint coordinates.

--
301248885  by Zhichao Lu:

    Allow MultiWorkerMirroredStrategy (MWMS) use by adding checkpoint handling with temporary directories in model_lib_v2. Added missing WeakKeyDictionary cfer_fn_cache field in CollectiveAllReduceStrategyExtended.

--
301224559  by Zhichao Lu:

    1) Fixes model_lib to also use keypoints while preparing model groundtruth.
    2) Tests model_lib with newly added keypoint metrics config.

--
300836556  by Zhichao Lu:

    Internal changes to add keypoint estimation parameters in CenterNet proto.

--
300795208  by Zhichao Lu:

    Updated the eval_util library to populate the keypoint groundtruth to
    eval_dict.

--
299474766  by Zhichao Lu:

    Modifies eval_util to create Keypoint Evaluator objects when configured in eval config.

--
299453920  by Zhichao Lu:

    Add swish activation as a hyperparams option.

--
299240093  by ronnyvotel:

    Keypoint postprocessing for CenterNetMetaArch.

--
299176395  by Zhichao Lu:

    Internal change.

--
299135608  by Zhichao Lu:

    Internal changes to refactor the CenterNet model in preparation for keypoint estimation tasks.

--
298915482  by Zhichao Lu:

    Make dataset_builder aware of input_context for distributed training.

--
298713595  by Zhichao Lu:

    Handling data with negative size boxes.

--
298695964  by Zhichao Lu:

    Expose change_coordinate_frame as a config parameter; fix multiclass_scores optional field.

--
298492150  by Zhichao Lu:

    Rename optimizer_builder_test_v2.py -> optimizer_builder_v2_test.py

--
298476471  by Zhichao Lu:

    Internal changes to support CenterNet keypoint estimation.

--
298365851  by ronnyvotel:

    Fixing a bug where groundtruth_keypoint_weights were being padded with a dynamic dimension.

--
297843700  by Zhichao Lu:

    Internal change.

--
297706988  by lzc:

    Internal change.

--
297705287  by ronnyvotel:

    Creating the "snapping" behavior in CenterNet, where regressed keypoints are refined with updated candidate keypoints from a heatmap.
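    A simplified illustrative sketch of the snapping idea (not the CenterNet implementation):

        import numpy as np

        # Replace each regressed keypoint with the nearest candidate keypoint
        # decoded from the heatmap.
        regressed = np.array([[10.0, 12.0], [40.0, 41.0]])    # [num_keypoints, 2]
        candidates = np.array([[11.0, 11.0], [60.0, 60.0]])   # heatmap peaks
        dists = np.linalg.norm(regressed[:, None, :] - candidates[None, :, :], axis=-1)
        snapped = candidates[np.argmin(dists, axis=1)]        # nearest candidate per keypoint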

--
297700447  by Zhichao Lu:

    Improve checkpoint checking logic with TF2 loop.

--
297686094  by Zhichao Lu:

    Convert "import tensorflow as tf" to "import tensorflow.compat.v1".

--
297670468  by lzc:

    Internal change.

--
297241327  by Zhichao Lu:

    Convert "import tensorflow as tf" to "import tensorflow.compat.v1".

--
297205959  by Zhichao Lu:

    Internal changes to refactor the CenterNet object detection target assigner into a separate library.

--
297143806  by Zhichao Lu:

    Convert "import tensorflow as tf" to "import tensorflow.compat.v1".

--
297129625  by Zhichao Lu:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297117070  by Zhichao Lu:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297030190  by Zhichao Lu:

    Add configuration options for visualizing keypoint edges

--
296359649  by Zhichao Lu:

    Support DepthwiseConv2dNative (used by separable conv) in weight equalization loss.

--
296290582  by Zhichao Lu:

    Internal change.

--
296093857  by Zhichao Lu:

    Internal changes to add general target assigner utilities.

--
295975116  by Zhichao Lu:

    Fix visualize_boxes_and_labels_on_image_array to show max_boxes_to_draw correctly.

--
295819711  by Zhichao Lu:

    Adds a flag to visualize_boxes_and_labels_on_image_array to skip the drawing of axis aligned bounding boxes.

--
295811929  by Zhichao Lu:

    Keypoint support in random_square_crop_by_scale.

--
295788458  by rathodv:

    Remove unused checkpoint to reduce repo size on GitHub.

--
295787184  by Zhichao Lu:

    Enable visualization of edges between keypoints

--
295763508  by Zhichao Lu:

    [Context RCNN] Add an option to enable/disable the cropping feature in the post-process step of the meta architecture.

--
295605344  by Zhichao Lu:

    internal change.

--
294926050  by ronnyvotel:

    Adding per-keypoint groundtruth weights. These weights are intended to be used as multipliers in a keypoint loss function.

    Groundtruth keypoint weights are constructed as follows (see the sketch after this list):
    - Initialize the weight for each keypoint type based on user-specified weights in the input_reader proto
    - Mask out (i.e. make zero) all keypoint weights that are not visible.
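    A hypothetical sketch of those two steps (names are illustrative, not the library's API):

        import numpy as np

        per_type_weights = np.array([1.0, 1.0, 0.5, 0.5])     # from the input_reader proto
        visible = np.array([[True, False, True, True]])       # per-instance keypoint visibility
        keypoint_weights = per_type_weights * visible          # zero where not visible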

--
294829061  by lzc:

    Internal change.

--
294566503  by Zhichao Lu:

    Changed internal CenterNet Model configuration.

--
294346662  by ronnyvotel:

    Using NaN values in keypoint coordinates that are not visible.

--
294333339  by Zhichao Lu:

    Change experimental_distribute_dataset -> experimental_distribute_datasets_from_function

--
293928752  by Zhichao Lu:

    Internal change

--
293909384  by Zhichao Lu:

    Add capabilities to train 1024x1024 CenterNet models.

--
293637554  by ronnyvotel:

    Adding keypoint visibilities to TfExampleDecoder.

--
293501558  by lzc:

    Internal change.

--
293252851  by Zhichao Lu:

    Change tf.gfile.GFile to tf.io.gfile.GFile.

--
292730217  by Zhichao Lu:

    Internal change.

--
292456563  by lzc:

    Internal changes.

--
292355612  by Zhichao Lu:

    Use tf.gather and tf.scatter_nd instead of matrix ops.

--
292245265  by rathodv:

    Internal

--
291989323  by richardmunoz:

    Refactor out building a DataDecoder from building a tf.data.Dataset.

--
291950147  by Zhichao Lu:

    Flip bounding boxes in arbitrary shaped tensors.

--
291401052  by huizhongc:

    Fix multiscale grid anchor generator to allow fully convolutional inference. When exporting a model with identity_resizer as the image_resizer, there is an incorrect box offset in the detection results. We add the anchor offset to address this problem.

--
291298871  by Zhichao Lu:

    Py3 compatibility changes.

--
290957957  by Zhichao Lu:

    Hourglass feature extractor for CenterNet.

--
290564372  by Zhichao Lu:

    Internal change.

--
290155278  by rathodv:

    Remove Dataset Explorer.

--
290155153  by Zhichao Lu:

    Internal change

--
290122054  by Zhichao Lu:

    Unify the format in the faster_rcnn.proto

--
290116084  by Zhichao Lu:

    Deprecate tensorflow.contrib.

--
290100672  by Zhichao Lu:

    Update MobilenetV3 SSD candidates

--
289926392  by Zhichao Lu:

    Internal change

--
289553440  by Zhichao Lu:

    [Object Detection API] Fix the comments about the dimension of the rpn_box_encodings from 4-D to 3-D.

--
288994128  by lzc:

    Internal changes.

--
288942194  by lzc:

    Internal change.

--
288746124  by Zhichao Lu:

    Configurable channel mean/std. dev in CenterNet feature extractors.

--
288552509  by rathodv:

    Internal.

--
288541285  by rathodv:

    Internal update.

--
288396396  by Zhichao Lu:

    Make object detection import contrib explicitly

--
288255791  by rathodv:

    Internal

--
288078600  by Zhichao Lu:

    Fix model_lib_v2 test

--
287952244  by rathodv:

    Internal

--
287921774  by Zhichao Lu:

    internal change

--
287906173  by Zhichao Lu:

    internal change

--
287889407  by jonathanhuang:

    PY3 compatibility

--
287889042  by rathodv:

    Internal

--
287876178  by Zhichao Lu:

    Internal change.

--
287770490  by Zhichao Lu:

    Add CenterNet proto and builder

--
287694213  by Zhichao Lu:

    Support for running multiple steps per tf.function call.

--
287377183  by jonathanhuang:

    PY3 compatibility

--
287371344  by rathodv:

    Support loading keypoint labels and ids.

--
287368213  by rathodv:

    Add protos supporting keypoint evaluation.

--
286673200  by rathodv:

    dataset_tools PY3 migration

--
286635106  by Zhichao Lu:

    Update code for upcoming tf.contrib removal

--
286479439  by Zhichao Lu:

    Internal change

--
286311711  by Zhichao Lu:

    Skeleton of context model within TFODAPI

--
286005546  by Zhichao Lu:

    Fix Faster-RCNN training when using keep_aspect_ratio_resizer with pad_to_max_dimension

--
285906400  by derekjchow:

    Internal change

--
285822795  by Zhichao Lu:

    Add CenterNet meta arch target assigners.

--
285447238  by Zhichao Lu:

    Internal changes.

--
285016927  by Zhichao Lu:

    Make _dummy_computation a tf.function. This fixes breakage caused by
    cl/284256438

--
284827274  by Zhichao Lu:

    Convert to python 3.

--
284645593  by rathodv:

    Internal change

--
284639893  by rathodv:

    Add missing documentation for keypoints in eval_util.py.

--
284323712  by Zhichao Lu:

    Internal changes.

--
284295290  by Zhichao Lu:

    Updating input config proto and dataset builder to include context fields

    Updating standard_fields and tf_example_decoder to include context features

--
284226821  by derekjchow:

    Update exporter.

--
284211030  by Zhichao Lu:

    API changes in CenterNet informed by the experiments with the hourglass network.

--
284190451  by Zhichao Lu:

    Add support for CenterNet losses in protos and builders.

--
284093961  by lzc:

    Internal changes.

--
284028174  by Zhichao Lu:

    Internal change

--
284014719  by derekjchow:

    Do not pad top_down feature maps unnecessarily.

--
284005765  by Zhichao Lu:

    Add new pad_to_multiple_resizer

--
283858233  by Zhichao Lu:

    Make target assigner work when under tf.function.

--
283836611  by Zhichao Lu:

    Make config getters more general.

--
283808990  by Zhichao Lu:

    Internal change

--
283754588  by Zhichao Lu:

    Internal changes.

--
282460301  by Zhichao Lu:

    Add ability to restore v2 style checkpoints.

--
281605842  by lzc:

    Add option to disable loss computation in OD API eval job.

--
280298212  by Zhichao Lu:

    Add backwards compatible change

--
280237857  by Zhichao Lu:

    internal change

--

PiperOrigin-RevId: 310447280
parent ac5fff19
@@ -17,7 +17,6 @@
import functools
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.contrib import slim as contrib_slim
from object_detection.builders import post_processing_builder
from object_detection.core import anchor_generator
@@ -34,7 +33,14 @@ from object_detection.utils import ops
from object_detection.utils import test_case
from object_detection.utils import test_utils
slim = contrib_slim
# pylint: disable=g-import-not-at-top
try:
from tensorflow.contrib import slim as contrib_slim
except ImportError:
# TF 2.0 doesn't ship with contrib.
pass
# pylint: enable=g-import-not-at-top
keras = tf.keras.layers
@@ -54,7 +60,7 @@ class FakeSSDFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
def extract_features(self, preprocessed_inputs):
with tf.variable_scope('mock_model'):
features = slim.conv2d(
features = contrib_slim.conv2d(
inputs=preprocessed_inputs,
num_outputs=32,
kernel_size=1,
...
@@ -82,17 +82,39 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
[num_boxes] containing 1-indexed groundtruth classes for the boxes.
InputDataFields.groundtruth_is_crowd (optional): integer numpy array of
shape [num_boxes] containing iscrowd flag for groundtruth boxes.
InputDataFields.groundtruth_area (optional): float numpy array of
shape [num_boxes] containing the area (in the original absolute
coordinates) of the annotated object.
InputDataFields.groundtruth_keypoints (optional): float numpy array of
keypoints with shape [num_boxes, num_keypoints, 2].
InputDataFields.groundtruth_keypoint_visibilities (optional): integer
numpy array of keypoint visibilities with shape [num_gt_boxes,
num_keypoints]. Integer is treated as an enum with 0=not labeled,
1=labeled but not visible and 2=labeled and visible.
"""
if image_id in self._image_ids:
tf.logging.warning('Ignoring ground truth with image id %s since it was '
'previously added', image_id)
return
# Drop optional fields if empty tensor.
groundtruth_is_crowd = groundtruth_dict.get(
standard_fields.InputDataFields.groundtruth_is_crowd)
# Drop groundtruth_is_crowd if empty tensor.
groundtruth_area = groundtruth_dict.get(
standard_fields.InputDataFields.groundtruth_area)
groundtruth_keypoints = groundtruth_dict.get(
standard_fields.InputDataFields.groundtruth_keypoints)
groundtruth_keypoint_visibilities = groundtruth_dict.get(
standard_fields.InputDataFields.groundtruth_keypoint_visibilities)
if groundtruth_is_crowd is not None and not groundtruth_is_crowd.shape[0]:
groundtruth_is_crowd = None
if groundtruth_area is not None and not groundtruth_area.shape[0]:
groundtruth_area = None
if groundtruth_keypoints is not None and not groundtruth_keypoints.shape[0]:
groundtruth_keypoints = None
if groundtruth_keypoint_visibilities is not None and not groundtruth_keypoint_visibilities.shape[
0]:
groundtruth_keypoint_visibilities = None
self._groundtruth_list.extend(
coco_tools.ExportSingleImageGroundtruthToCoco(
@@ -103,7 +125,12 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
standard_fields.InputDataFields.groundtruth_boxes],
groundtruth_classes=groundtruth_dict[
standard_fields.InputDataFields.groundtruth_classes],
groundtruth_is_crowd=groundtruth_is_crowd))
groundtruth_is_crowd=groundtruth_is_crowd,
groundtruth_area=groundtruth_area,
groundtruth_keypoints=groundtruth_keypoints,
groundtruth_keypoint_visibilities=groundtruth_keypoint_visibilities)
)
self._annotation_id += groundtruth_dict[standard_fields.InputDataFields.
groundtruth_boxes].shape[0]
# Boolean to indicate whether a detection has been added for this image.
@@ -127,7 +154,8 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
[num_boxes] containing detection scores for the boxes.
DetectionResultFields.detection_classes: integer numpy array of shape
[num_boxes] containing 1-indexed detection classes for the boxes.
DetectionResultFields.detection_keypoints (optional): float numpy array
of keypoints with shape [num_boxes, num_keypoints, 2].
Raises:
ValueError: If groundtruth for the image_id is not available.
"""
@@ -139,19 +167,22 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
'previously added', image_id)
return
# Drop optional fields if empty tensor.
detection_keypoints = detections_dict.get(
standard_fields.DetectionResultFields.detection_keypoints)
if detection_keypoints is not None and not detection_keypoints.shape[0]:
detection_keypoints = None
self._detection_boxes_list.extend(
coco_tools.ExportSingleImageDetectionBoxesToCoco(
image_id=image_id,
category_id_set=self._category_id_set,
detection_boxes=detections_dict[standard_fields.
DetectionResultFields
.detection_boxes],
detection_scores=detections_dict[standard_fields.
DetectionResultFields.
detection_scores],
detection_classes=detections_dict[standard_fields.
DetectionResultFields.
detection_classes]))
detection_boxes=detections_dict[
standard_fields.DetectionResultFields.detection_boxes],
detection_scores=detections_dict[
standard_fields.DetectionResultFields.detection_scores],
detection_classes=detections_dict[
standard_fields.DetectionResultFields.detection_classes],
detection_keypoints=detection_keypoints))
self._image_ids[image_id] = True
def dump_detections_to_json_file(self, json_output_path):
@@ -410,6 +441,460 @@ def _check_mask_type_and_value(array_name, masks):
array_name))
class CocoKeypointEvaluator(CocoDetectionEvaluator):
"""Class to evaluate COCO keypoint metrics."""
def __init__(self,
category_id,
category_keypoints,
class_text,
oks_sigmas=None):
"""Constructor.
Args:
category_id: An integer id uniquely identifying this category.
category_keypoints: A list specifying keypoint mappings, with items:
'id': (required) an integer id identifying the keypoint.
'name': (required) a string representing the keypoint name.
class_text: A string representing the category name for which keypoint
metrics are to be computed.
oks_sigmas: A dict of keypoint name to standard deviation values for OKS
metrics. If not provided, default value of 0.05 will be used.
"""
self._category_id = category_id
self._category_name = class_text
self._keypoint_ids = sorted(
[keypoint['id'] for keypoint in category_keypoints])
kpt_id_to_name = {kpt['id']: kpt['name'] for kpt in category_keypoints}
if oks_sigmas:
self._oks_sigmas = np.array([
oks_sigmas[kpt_id_to_name[idx]] for idx in self._keypoint_ids
])
else:
# Default all per-keypoint sigmas to 0.
self._oks_sigmas = np.full((len(self._keypoint_ids)), 0.05)
tf.logging.warning('No default keypoint OKS sigmas provided. Will use '
'0.05')
tf.logging.info('Using the following keypoint OKS sigmas: {}'.format(
self._oks_sigmas))
self._metrics = None
super(CocoKeypointEvaluator, self).__init__([{
'id': self._category_id,
'name': class_text
}])
def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
"""Adds groundtruth for a single image with keypoints.
If the image has already been added, a warning is logged, and groundtruth
is ignored.
Args:
image_id: A unique string/integer identifier for the image.
groundtruth_dict: A dictionary containing -
InputDataFields.groundtruth_boxes: float32 numpy array of shape
[num_boxes, 4] containing `num_boxes` groundtruth boxes of the format
[ymin, xmin, ymax, xmax] in absolute image coordinates.
InputDataFields.groundtruth_classes: integer numpy array of shape
[num_boxes] containing 1-indexed groundtruth classes for the boxes.
InputDataFields.groundtruth_is_crowd (optional): integer numpy array of
shape [num_boxes] containing iscrowd flag for groundtruth boxes.
InputDataFields.groundtruth_area (optional): float numpy array of
shape [num_boxes] containing the area (in the original absolute
coordinates) of the annotated object.
InputDataFields.groundtruth_keypoints: float numpy array of
keypoints with shape [num_boxes, num_keypoints, 2].
InputDataFields.groundtruth_keypoint_visibilities (optional): integer
numpy array of keypoint visibilities with shape [num_gt_boxes,
num_keypoints]. Integer is treated as an enum with 0=not labeled,
1=labeled but not visible and 2=labeled and visible.
"""
# Keep only the groundtruth for our category and its keypoints.
groundtruth_classes = groundtruth_dict[
standard_fields.InputDataFields.groundtruth_classes]
groundtruth_boxes = groundtruth_dict[
standard_fields.InputDataFields.groundtruth_boxes]
groundtruth_keypoints = groundtruth_dict[
standard_fields.InputDataFields.groundtruth_keypoints]
class_indices = [
idx for idx, gt_class_id in enumerate(groundtruth_classes)
if gt_class_id == self._category_id
]
filtered_groundtruth_classes = np.take(
groundtruth_classes, class_indices, axis=0)
filtered_groundtruth_boxes = np.take(
groundtruth_boxes, class_indices, axis=0)
filtered_groundtruth_keypoints = np.take(
groundtruth_keypoints, class_indices, axis=0)
filtered_groundtruth_keypoints = np.take(
filtered_groundtruth_keypoints, self._keypoint_ids, axis=1)
filtered_groundtruth_dict = {}
filtered_groundtruth_dict[
standard_fields.InputDataFields
.groundtruth_classes] = filtered_groundtruth_classes
filtered_groundtruth_dict[standard_fields.InputDataFields
.groundtruth_boxes] = filtered_groundtruth_boxes
filtered_groundtruth_dict[
standard_fields.InputDataFields
.groundtruth_keypoints] = filtered_groundtruth_keypoints
if (standard_fields.InputDataFields.groundtruth_is_crowd in
groundtruth_dict.keys()):
groundtruth_is_crowd = groundtruth_dict[
standard_fields.InputDataFields.groundtruth_is_crowd]
filtered_groundtruth_is_crowd = np.take(groundtruth_is_crowd,
class_indices, 0)
filtered_groundtruth_dict[
standard_fields.InputDataFields
.groundtruth_is_crowd] = filtered_groundtruth_is_crowd
if (standard_fields.InputDataFields.groundtruth_area in
groundtruth_dict.keys()):
groundtruth_area = groundtruth_dict[
standard_fields.InputDataFields.groundtruth_area]
filtered_groundtruth_area = np.take(groundtruth_area, class_indices, 0)
filtered_groundtruth_dict[
standard_fields.InputDataFields
.groundtruth_area] = filtered_groundtruth_area
if (standard_fields.InputDataFields.groundtruth_keypoint_visibilities in
groundtruth_dict.keys()):
groundtruth_keypoint_visibilities = groundtruth_dict[
standard_fields.InputDataFields.groundtruth_keypoint_visibilities]
filtered_groundtruth_keypoint_visibilities = np.take(
groundtruth_keypoint_visibilities, class_indices, axis=0)
filtered_groundtruth_keypoint_visibilities = np.take(
filtered_groundtruth_keypoint_visibilities,
self._keypoint_ids,
axis=1)
filtered_groundtruth_dict[
standard_fields.InputDataFields.
groundtruth_keypoint_visibilities] = filtered_groundtruth_keypoint_visibilities
super(CocoKeypointEvaluator,
self).add_single_ground_truth_image_info(image_id,
filtered_groundtruth_dict)
def add_single_detected_image_info(self, image_id, detections_dict):
"""Adds detections for a single image and the specific category for which keypoints are evaluated.
If a detection has already been added for this image id, a warning is
logged, and the detection is skipped.
Args:
image_id: A unique string/integer identifier for the image.
detections_dict: A dictionary containing -
DetectionResultFields.detection_boxes: float32 numpy array of shape
[num_boxes, 4] containing `num_boxes` detection boxes of the format
[ymin, xmin, ymax, xmax] in absolute image coordinates.
DetectionResultFields.detection_scores: float32 numpy array of shape
[num_boxes] containing detection scores for the boxes.
DetectionResultFields.detection_classes: integer numpy array of shape
[num_boxes] containing 1-indexed detection classes for the boxes.
DetectionResultFields.detection_keypoints: float numpy array of
keypoints with shape [num_boxes, num_keypoints, 2].
Raises:
ValueError: If groundtruth for the image_id is not available.
"""
# Keep only the detections for our category and its keypoints.
detection_classes = detections_dict[
standard_fields.DetectionResultFields.detection_classes]
detection_boxes = detections_dict[
standard_fields.DetectionResultFields.detection_boxes]
detection_scores = detections_dict[
standard_fields.DetectionResultFields.detection_scores]
detection_keypoints = detections_dict[
standard_fields.DetectionResultFields.detection_keypoints]
class_indices = [
idx for idx, class_id in enumerate(detection_classes)
if class_id == self._category_id
]
filtered_detection_classes = np.take(
detection_classes, class_indices, axis=0)
filtered_detection_boxes = np.take(detection_boxes, class_indices, axis=0)
filtered_detection_scores = np.take(detection_scores, class_indices, axis=0)
filtered_detection_keypoints = np.take(
detection_keypoints, class_indices, axis=0)
filtered_detection_keypoints = np.take(
filtered_detection_keypoints, self._keypoint_ids, axis=1)
filtered_detections_dict = {}
filtered_detections_dict[standard_fields.DetectionResultFields
.detection_classes] = filtered_detection_classes
filtered_detections_dict[standard_fields.DetectionResultFields
.detection_boxes] = filtered_detection_boxes
filtered_detections_dict[standard_fields.DetectionResultFields
.detection_scores] = filtered_detection_scores
filtered_detections_dict[standard_fields.DetectionResultFields.
detection_keypoints] = filtered_detection_keypoints
super(CocoKeypointEvaluator,
self).add_single_detected_image_info(image_id,
filtered_detections_dict)
def evaluate(self):
"""Evaluates the keypoints and returns a dictionary of coco metrics.
Returns:
A dictionary holding -
1. summary_metrics:
'Keypoints_Precision/mAP': mean average precision over classes
averaged over OKS thresholds ranging from .5 to .95 with .05
increments.
'Keypoints_Precision/mAP@.50IOU': mean average precision at 50% OKS
'Keypoints_Precision/mAP@.75IOU': mean average precision at 75% OKS
'Keypoints_Precision/mAP (medium)': mean average precision for medium
sized objects (32^2 pixels < area < 96^2 pixels).
'Keypoints_Precision/mAP (large)': mean average precision for large
objects (96^2 pixels < area < 10000^2 pixels).
'Keypoints_Recall/AR@1': average recall with 1 detection.
'Keypoints_Recall/AR@10': average recall with 10 detections.
'Keypoints_Recall/AR@100': average recall with 100 detections.
'Keypoints_Recall/AR@100 (medium)': average recall for medium objects with
100 detections.
'Keypoints_Recall/AR@100 (large)': average recall for large objects with
100 detections.
"""
tf.logging.info('Performing evaluation on %d images.', len(self._image_ids))
groundtruth_dict = {
'annotations': self._groundtruth_list,
'images': [{'id': image_id} for image_id in self._image_ids],
'categories': self._categories
}
coco_wrapped_groundtruth = coco_tools.COCOWrapper(
groundtruth_dict, detection_type='bbox')
coco_wrapped_detections = coco_wrapped_groundtruth.LoadAnnotations(
self._detection_boxes_list)
keypoint_evaluator = coco_tools.COCOEvalWrapper(
coco_wrapped_groundtruth,
coco_wrapped_detections,
agnostic_mode=False,
iou_type='keypoints',
oks_sigmas=self._oks_sigmas)
keypoint_metrics, _ = keypoint_evaluator.ComputeMetrics(
include_metrics_per_category=False, all_metrics_per_category=False)
keypoint_metrics = {
'Keypoints_' + key: value
for key, value in iter(keypoint_metrics.items())
}
return keypoint_metrics
def add_eval_dict(self, eval_dict):
"""Observes an evaluation result dict for a single example.
When executing eagerly, once all observations have been observed by this
method you can use `.evaluate()` to get the final metrics.
When using `tf.estimator.Estimator` for evaluation this function is used by
`get_estimator_eval_metric_ops()` to construct the metric update op.
Args:
eval_dict: A dictionary that holds tensors for evaluating an object
detection model, returned from
eval_util.result_dict_for_single_example().
Returns:
None when executing eagerly, or an update_op that can be used to update
the eval metrics in `tf.estimator.EstimatorSpec`.
"""
def update_op(
image_id_batched,
groundtruth_boxes_batched,
groundtruth_classes_batched,
groundtruth_is_crowd_batched,
groundtruth_area_batched,
groundtruth_keypoints_batched,
groundtruth_keypoint_visibilities_batched,
num_gt_boxes_per_image,
detection_boxes_batched,
detection_scores_batched,
detection_classes_batched,
detection_keypoints_batched,
num_det_boxes_per_image,
is_annotated_batched):
"""Update operation for adding batch of images to Coco evaluator."""
for (image_id, gt_box, gt_class, gt_is_crowd, gt_area, gt_keyp,
gt_keyp_vis, num_gt_box, det_box, det_score, det_class, det_keyp,
num_det_box, is_annotated) in zip(
image_id_batched, groundtruth_boxes_batched,
groundtruth_classes_batched, groundtruth_is_crowd_batched,
groundtruth_area_batched, groundtruth_keypoints_batched,
groundtruth_keypoint_visibilities_batched,
num_gt_boxes_per_image, detection_boxes_batched,
detection_scores_batched, detection_classes_batched,
detection_keypoints_batched, num_det_boxes_per_image,
is_annotated_batched):
if is_annotated:
self.add_single_ground_truth_image_info(
image_id, {
'groundtruth_boxes': gt_box[:num_gt_box],
'groundtruth_classes': gt_class[:num_gt_box],
'groundtruth_is_crowd': gt_is_crowd[:num_gt_box],
'groundtruth_area': gt_area[:num_gt_box],
'groundtruth_keypoints': gt_keyp[:num_gt_box],
'groundtruth_keypoint_visibilities': gt_keyp_vis[:num_gt_box]
})
self.add_single_detected_image_info(
image_id, {
'detection_boxes': det_box[:num_det_box],
'detection_scores': det_score[:num_det_box],
'detection_classes': det_class[:num_det_box],
'detection_keypoints': det_keyp[:num_det_box],
})
# Unpack items from the evaluation dictionary.
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
image_id = eval_dict[input_data_fields.key]
groundtruth_boxes = eval_dict[input_data_fields.groundtruth_boxes]
groundtruth_classes = eval_dict[input_data_fields.groundtruth_classes]
groundtruth_is_crowd = eval_dict.get(input_data_fields.groundtruth_is_crowd,
None)
groundtruth_area = eval_dict.get(input_data_fields.groundtruth_area, None)
groundtruth_keypoints = eval_dict[input_data_fields.groundtruth_keypoints]
groundtruth_keypoint_visibilities = eval_dict.get(
input_data_fields.groundtruth_keypoint_visibilities, None)
detection_boxes = eval_dict[detection_fields.detection_boxes]
detection_scores = eval_dict[detection_fields.detection_scores]
detection_classes = eval_dict[detection_fields.detection_classes]
detection_keypoints = eval_dict[detection_fields.detection_keypoints]
num_gt_boxes_per_image = eval_dict.get(
'num_groundtruth_boxes_per_image', None)
num_det_boxes_per_image = eval_dict.get('num_det_boxes_per_image', None)
is_annotated = eval_dict.get('is_annotated', None)
if groundtruth_is_crowd is None:
groundtruth_is_crowd = tf.zeros_like(groundtruth_classes, dtype=tf.bool)
if groundtruth_area is None:
groundtruth_area = tf.zeros_like(groundtruth_classes, dtype=tf.float32)
if not image_id.shape.as_list():
# Apply a batch dimension to all tensors.
image_id = tf.expand_dims(image_id, 0)
groundtruth_boxes = tf.expand_dims(groundtruth_boxes, 0)
groundtruth_classes = tf.expand_dims(groundtruth_classes, 0)
groundtruth_is_crowd = tf.expand_dims(groundtruth_is_crowd, 0)
groundtruth_area = tf.expand_dims(groundtruth_area, 0)
groundtruth_keypoints = tf.expand_dims(groundtruth_keypoints, 0)
detection_boxes = tf.expand_dims(detection_boxes, 0)
detection_scores = tf.expand_dims(detection_scores, 0)
detection_classes = tf.expand_dims(detection_classes, 0)
detection_keypoints = tf.expand_dims(detection_keypoints, 0)
if num_gt_boxes_per_image is None:
num_gt_boxes_per_image = tf.shape(groundtruth_boxes)[1:2]
else:
num_gt_boxes_per_image = tf.expand_dims(num_gt_boxes_per_image, 0)
if num_det_boxes_per_image is None:
num_det_boxes_per_image = tf.shape(detection_boxes)[1:2]
else:
num_det_boxes_per_image = tf.expand_dims(num_det_boxes_per_image, 0)
if is_annotated is None:
is_annotated = tf.constant([True])
else:
is_annotated = tf.expand_dims(is_annotated, 0)
if groundtruth_keypoint_visibilities is None:
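# Default missing visibilities to 2 (labeled and visible), matching the docstring enum above.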
groundtruth_keypoint_visibilities = tf.fill([
tf.shape(groundtruth_boxes)[1],
tf.shape(groundtruth_keypoints)[2]
], tf.constant(2, dtype=tf.int32))
groundtruth_keypoint_visibilities = tf.expand_dims(
groundtruth_keypoint_visibilities, 0)
else:
if num_gt_boxes_per_image is None:
num_gt_boxes_per_image = tf.tile(
tf.shape(groundtruth_boxes)[1:2],
multiples=tf.shape(groundtruth_boxes)[0:1])
if num_det_boxes_per_image is None:
num_det_boxes_per_image = tf.tile(
tf.shape(detection_boxes)[1:2],
multiples=tf.shape(detection_boxes)[0:1])
if is_annotated is None:
is_annotated = tf.ones_like(image_id, dtype=tf.bool)
if groundtruth_keypoint_visibilities is None:
groundtruth_keypoint_visibilities = tf.fill([
tf.shape(groundtruth_keypoints)[1],
tf.shape(groundtruth_keypoints)[2]
], tf.constant(2, dtype=tf.int32))
groundtruth_keypoint_visibilities = tf.tile(
tf.expand_dims(groundtruth_keypoint_visibilities, 0),
multiples=[tf.shape(groundtruth_keypoints)[0], 1, 1])
return tf.py_func(update_op, [
image_id, groundtruth_boxes, groundtruth_classes, groundtruth_is_crowd,
groundtruth_area, groundtruth_keypoints,
groundtruth_keypoint_visibilities, num_gt_boxes_per_image,
detection_boxes, detection_scores, detection_classes,
detection_keypoints, num_det_boxes_per_image, is_annotated
], [])
def get_estimator_eval_metric_ops(self, eval_dict):
"""Returns a dictionary of eval metric ops.
Note that once value_op is called, the detections and groundtruth added via
update_op are cleared.
This function can take in groundtruth and detections for a batch of images,
or for a single image. For the latter case, the batch dimension for input
tensors need not be present.
Args:
eval_dict: A dictionary that holds tensors for evaluating object detection
performance. For single-image evaluation, this dictionary may be
produced from eval_util.result_dict_for_single_example(). For multi-image
evaluation, `eval_dict` should contain the fields
'num_groundtruth_boxes_per_image' and 'num_det_boxes_per_image' to
properly unpad the tensors from the batch.
Returns:
a dictionary of metric names to tuple of value_op and update_op that can
be used as eval metric ops in tf.estimator.EstimatorSpec. Note that all
update ops must be run together and similarly all value ops must be run
together to guarantee correct behaviour.
"""
update_op = self.add_eval_dict(eval_dict)
category = self._category_name
metric_names = [
'Keypoints_Precision/mAP ByCategory/{}'.format(category),
'Keypoints_Precision/mAP@.50IOU ByCategory/{}'.format(category),
'Keypoints_Precision/mAP@.75IOU ByCategory/{}'.format(category),
'Keypoints_Precision/mAP (large) ByCategory/{}'.format(category),
'Keypoints_Precision/mAP (medium) ByCategory/{}'.format(category),
'Keypoints_Recall/AR@1 ByCategory/{}'.format(category),
'Keypoints_Recall/AR@10 ByCategory/{}'.format(category),
'Keypoints_Recall/AR@100 ByCategory/{}'.format(category),
'Keypoints_Recall/AR@100 (large) ByCategory/{}'.format(category),
'Keypoints_Recall/AR@100 (medium) ByCategory/{}'.format(category)
]
def first_value_func():
self._metrics = self.evaluate()
self.clear()
return np.float32(self._metrics[metric_names[0]])
def value_func_factory(metric_name):
def value_func():
return np.float32(self._metrics[metric_name])
return value_func
# Ensure that the metrics are only evaluated once.
first_value_op = tf.py_func(first_value_func, [], tf.float32)
eval_metric_ops = {metric_names[0]: (first_value_op, update_op)}
with tf.control_dependencies([first_value_op]):
for metric_name in metric_names[1:]:
eval_metric_ops[metric_name] = (tf.py_func(
value_func_factory(metric_name), [], np.float32), update_op)
return eval_metric_ops
class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
"""Class to evaluate COCO detection metrics."""
...
@@ -37,6 +37,25 @@ def _get_categories_list():
}]
def _get_category_keypoints_dict():
return {
'person': [{
'id': 0,
'name': 'left_eye'
}, {
'id': 3,
'name': 'right_eye'
}],
'dog': [{
'id': 1,
'name': 'tail_start'
}, {
'id': 2,
'name': 'mouth'
}]
}
class CocoDetectionEvaluationTest(tf.test.TestCase):
def testGetOneMAPWithMatchingGroundtruthAndDetections(self):
@@ -287,7 +306,7 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
detection_classes: np.array([2])
})
metrics = {}
for key, (value_op, _) in eval_metric_ops.iteritems():
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP'], 1.0)
@@ -380,7 +399,7 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
detection_classes: np.array([1, 2, 2, 3])
})
metrics = {}
for key, (value_op, _) in eval_metric_ops.iteritems():
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP'], 1.0)
@@ -476,7 +495,7 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
np.array([2, 2])
})
metrics = {}
for key, (value_op, _) in eval_metric_ops.iteritems():
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP'], 1.0)
@@ -538,7 +557,7 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
detection_classes: np.array([[1], [3], [2]])
})
metrics = {}
for key, (value_op, _) in eval_metric_ops.iteritems():
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP'], 1.0)
@@ -625,7 +644,7 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
self.assertEqual(len(coco_evaluator._detection_boxes_list), 5)
metrics = {}
for key, (value_op, _) in eval_metric_ops.iteritems():
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP'], 1.0)
@@ -647,6 +666,696 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
self.assertFalse(coco_evaluator._image_ids)
class CocoKeypointEvaluationTest(tf.test.TestCase):
def testGetOneMAPWithMatchingKeypoints(self):
"""Tests that correct mAP for keypoints is calculated."""
category_keypoint_dict = _get_category_keypoints_dict()
coco_evaluator = coco_evaluation.CocoKeypointEvaluator(
category_id=1, category_keypoints=category_keypoint_dict['person'],
class_text='person')
coco_evaluator.add_single_ground_truth_image_info(
image_id='image1',
groundtruth_dict={
standard_fields.InputDataFields.groundtruth_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.InputDataFields.groundtruth_classes:
np.array([1]),
standard_fields.InputDataFields.groundtruth_keypoints:
np.array([[[150., 160.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [170., 180.]]]),
standard_fields.InputDataFields.groundtruth_keypoint_visibilities:
np.array([[2, 0, 0, 2]])
})
coco_evaluator.add_single_detected_image_info(
image_id='image1',
detections_dict={
standard_fields.DetectionResultFields.detection_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.DetectionResultFields.detection_scores:
np.array([.8]),
standard_fields.DetectionResultFields.detection_classes:
np.array([1]),
standard_fields.DetectionResultFields.detection_keypoints:
np.array([[[150., 160.], [1., 2.], [3., 4.], [170., 180.]]])
})
coco_evaluator.add_single_ground_truth_image_info(
image_id='image2',
groundtruth_dict={
standard_fields.InputDataFields.groundtruth_boxes:
np.array([[50., 50., 100., 100.]]),
standard_fields.InputDataFields.groundtruth_classes:
np.array([1]),
standard_fields.InputDataFields.groundtruth_keypoints:
np.array([[[75., 76.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [77., 78.]]]),
standard_fields.InputDataFields.groundtruth_keypoint_visibilities:
np.array([[2, 0, 0, 2]])
})
coco_evaluator.add_single_detected_image_info(
image_id='image2',
detections_dict={
standard_fields.DetectionResultFields.detection_boxes:
np.array([[50., 50., 100., 100.]]),
standard_fields.DetectionResultFields.detection_scores:
np.array([.8]),
standard_fields.DetectionResultFields.detection_classes:
np.array([1]),
standard_fields.DetectionResultFields.detection_keypoints:
np.array([[[75., 76.], [5., 6.], [7., 8.], [77., 78.]]])
})
metrics = coco_evaluator.evaluate()
self.assertAlmostEqual(metrics['Keypoints_Precision/mAP ByCategory/person'],
1.0)
def testGroundtruthListValues(self):
category_keypoint_dict = _get_category_keypoints_dict()
coco_evaluator = coco_evaluation.CocoKeypointEvaluator(
category_id=1, category_keypoints=category_keypoint_dict['person'],
class_text='person')
coco_evaluator.add_single_ground_truth_image_info(
image_id='image1',
groundtruth_dict={
standard_fields.InputDataFields.groundtruth_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.InputDataFields.groundtruth_classes:
np.array([1]),
standard_fields.InputDataFields.groundtruth_keypoints:
np.array([[[150., 160.], [float('nan'), float('nan')],
[float('nan'), float('nan')], [170., 180.]]]),
standard_fields.InputDataFields.groundtruth_keypoint_visibilities:
np.array([[2, 0, 0, 2]]),
standard_fields.InputDataFields.groundtruth_area: np.array([15.])
})
gt_dict = coco_evaluator._groundtruth_list[0]
self.assertEqual(gt_dict['id'], 1)
self.assertAlmostEqual(gt_dict['bbox'], [100.0, 100.0, 100.0, 100.0])
self.assertAlmostEqual(
gt_dict['keypoints'], [160.0, 150.0, 2, 180.0, 170.0, 2])
self.assertEqual(gt_dict['num_keypoints'], 2)
self.assertAlmostEqual(gt_dict['area'], 15.0)
def testKeypointVisibilitiesAreOptional(self):
"""Tests that evaluator works when visibilities aren't provided."""
category_keypoint_dict = _get_category_keypoints_dict()
coco_evaluator = coco_evaluation.CocoKeypointEvaluator(
category_id=1, category_keypoints=category_keypoint_dict['person'],
class_text='person')
coco_evaluator.add_single_ground_truth_image_info(
image_id='image1',
groundtruth_dict={
standard_fields.InputDataFields.groundtruth_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.InputDataFields.groundtruth_classes:
np.array([1]),
standard_fields.InputDataFields.groundtruth_keypoints:
np.array([[[150., 160.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [170., 180.]]])
})
coco_evaluator.add_single_detected_image_info(
image_id='image1',
detections_dict={
standard_fields.DetectionResultFields.detection_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.DetectionResultFields.detection_scores:
np.array([.8]),
standard_fields.DetectionResultFields.detection_classes:
np.array([1]),
standard_fields.DetectionResultFields.detection_keypoints:
np.array([[[150., 160.], [1., 2.], [3., 4.], [170., 180.]]])
})
coco_evaluator.add_single_ground_truth_image_info(
image_id='image2',
groundtruth_dict={
standard_fields.InputDataFields.groundtruth_boxes:
np.array([[50., 50., 100., 100.]]),
standard_fields.InputDataFields.groundtruth_classes:
np.array([1]),
standard_fields.InputDataFields.groundtruth_keypoints:
np.array([[[75., 76.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [77., 78.]]])
})
coco_evaluator.add_single_detected_image_info(
image_id='image2',
detections_dict={
standard_fields.DetectionResultFields.detection_boxes:
np.array([[50., 50., 100., 100.]]),
standard_fields.DetectionResultFields.detection_scores:
np.array([.8]),
standard_fields.DetectionResultFields.detection_classes:
np.array([1]),
standard_fields.DetectionResultFields.detection_keypoints:
np.array([[[75., 76.], [5., 6.], [7., 8.], [77., 78.]]])
})
metrics = coco_evaluator.evaluate()
self.assertAlmostEqual(metrics['Keypoints_Precision/mAP ByCategory/person'],
1.0)
def testFiltersDetectionsFromOtherCategories(self):
"""Tests that the evaluator ignores detections from other categories."""
category_keypoint_dict = _get_category_keypoints_dict()
coco_evaluator = coco_evaluation.CocoKeypointEvaluator(
category_id=2, category_keypoints=category_keypoint_dict['person'],
class_text='dog')
coco_evaluator.add_single_ground_truth_image_info(
image_id='image1',
groundtruth_dict={
standard_fields.InputDataFields.groundtruth_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.InputDataFields.groundtruth_classes:
np.array([1]),
standard_fields.InputDataFields.groundtruth_keypoints:
np.array([[[150., 160.], [170., 180.], [110., 120.],
[130., 140.]]]),
standard_fields.InputDataFields.groundtruth_keypoint_visibilities:
np.array([[2, 2, 2, 2]])
})
coco_evaluator.add_single_detected_image_info(
image_id='image1',
detections_dict={
standard_fields.DetectionResultFields.detection_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.DetectionResultFields.detection_scores:
np.array([.9]),
standard_fields.DetectionResultFields.detection_classes:
np.array([1]),
standard_fields.DetectionResultFields.detection_keypoints:
np.array([[[150., 160.], [170., 180.], [110., 120.],
[130., 140.]]])
})
metrics = coco_evaluator.evaluate()
self.assertAlmostEqual(metrics['Keypoints_Precision/mAP ByCategory/dog'],
-1.0)
def testHandlesUnlabeledKeypointData(self):
"""Tests that the evaluator handles missing keypoints GT."""
category_keypoint_dict = _get_category_keypoints_dict()
coco_evaluator = coco_evaluation.CocoKeypointEvaluator(
category_id=1, category_keypoints=category_keypoint_dict['person'],
class_text='person')
coco_evaluator.add_single_ground_truth_image_info(
image_id='image1',
groundtruth_dict={
standard_fields.InputDataFields.groundtruth_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.InputDataFields.groundtruth_classes:
np.array([1]),
standard_fields.InputDataFields.groundtruth_keypoints:
np.array([[[150., 160.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [170., 180.]]]),
standard_fields.InputDataFields.groundtruth_keypoint_visibilities:
np.array([[0, 0, 0, 2]])
})
coco_evaluator.add_single_detected_image_info(
image_id='image1',
detections_dict={
standard_fields.DetectionResultFields.detection_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.DetectionResultFields.detection_scores:
np.array([.8]),
standard_fields.DetectionResultFields.detection_classes:
np.array([1]),
standard_fields.DetectionResultFields.detection_keypoints:
np.array([[[50., 60.], [1., 2.], [3., 4.], [170., 180.]]])
})
metrics = coco_evaluator.evaluate()
self.assertAlmostEqual(metrics['Keypoints_Precision/mAP ByCategory/person'],
1.0)
def testIgnoresCrowdAnnotations(self):
"""Tests that the evaluator ignores GT marked as crowd."""
category_keypoint_dict = _get_category_keypoints_dict()
coco_evaluator = coco_evaluation.CocoKeypointEvaluator(
category_id=1, category_keypoints=category_keypoint_dict['person'],
class_text='person')
coco_evaluator.add_single_ground_truth_image_info(
image_id='image1',
groundtruth_dict={
standard_fields.InputDataFields.groundtruth_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.InputDataFields.groundtruth_classes:
np.array([1]),
standard_fields.InputDataFields.groundtruth_is_crowd:
np.array([1]),
standard_fields.InputDataFields.groundtruth_keypoints:
np.array([[[150., 160.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [170., 180.]]]),
standard_fields.InputDataFields.groundtruth_keypoint_visibilities:
np.array([[2, 0, 0, 2]])
})
coco_evaluator.add_single_detected_image_info(
image_id='image1',
detections_dict={
standard_fields.DetectionResultFields.detection_boxes:
np.array([[100., 100., 200., 200.]]),
standard_fields.DetectionResultFields.detection_scores:
np.array([.8]),
standard_fields.DetectionResultFields.detection_classes:
np.array([1]),
standard_fields.DetectionResultFields.detection_keypoints:
np.array([[[150., 160.], [1., 2.], [3., 4.], [170., 180.]]])
})
metrics = coco_evaluator.evaluate()
self.assertAlmostEqual(metrics['Keypoints_Precision/mAP ByCategory/person'],
-1.0)
class CocoKeypointEvaluationPyFuncTest(tf.test.TestCase):
def testGetOneMAPWithMatchingKeypoints(self):
category_keypoint_dict = _get_category_keypoints_dict()
coco_keypoint_evaluator = coco_evaluation.CocoKeypointEvaluator(
category_id=1, category_keypoints=category_keypoint_dict['person'],
class_text='person')
image_id = tf.placeholder(tf.string, shape=())
groundtruth_boxes = tf.placeholder(tf.float32, shape=(None, 4))
groundtruth_classes = tf.placeholder(tf.float32, shape=(None))
groundtruth_keypoints = tf.placeholder(tf.float32, shape=(None, 4, 2))
detection_boxes = tf.placeholder(tf.float32, shape=(None, 4))
detection_scores = tf.placeholder(tf.float32, shape=(None))
detection_classes = tf.placeholder(tf.float32, shape=(None))
detection_keypoints = tf.placeholder(tf.float32, shape=(None, 4, 2))
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
eval_dict = {
input_data_fields.key: image_id,
input_data_fields.groundtruth_boxes: groundtruth_boxes,
input_data_fields.groundtruth_classes: groundtruth_classes,
input_data_fields.groundtruth_keypoints: groundtruth_keypoints,
detection_fields.detection_boxes: detection_boxes,
detection_fields.detection_scores: detection_scores,
detection_fields.detection_classes: detection_classes,
detection_fields.detection_keypoints: detection_keypoints,
}
eval_metric_ops = coco_keypoint_evaluator.get_estimator_eval_metric_ops(
eval_dict)
_, update_op = eval_metric_ops['Keypoints_Precision/mAP ByCategory/person']
with self.test_session() as sess:
sess.run(
update_op,
feed_dict={
image_id:
'image1',
groundtruth_boxes:
np.array([[100., 100., 200., 200.]]),
groundtruth_classes:
np.array([1]),
groundtruth_keypoints:
np.array([[[150., 160.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [170., 180.]]]),
detection_boxes:
np.array([[100., 100., 200., 200.]]),
detection_scores:
np.array([.8]),
detection_classes:
np.array([1]),
detection_keypoints:
np.array([[[150., 160.], [1., 2.], [3., 4.], [170., 180.]]])
})
sess.run(
update_op,
feed_dict={
image_id:
'image2',
groundtruth_boxes:
np.array([[50., 50., 100., 100.]]),
groundtruth_classes:
np.array([1]),
groundtruth_keypoints:
np.array([[[75., 76.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [77., 78.]]]),
detection_boxes:
np.array([[50., 50., 100., 100.]]),
detection_scores:
np.array([.7]),
detection_classes:
np.array([1]),
detection_keypoints:
np.array([[[75., 76.], [5., 6.], [7., 8.], [77., 78.]]])
})
metrics = {}
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['Keypoints_Precision/mAP ByCategory/person'],
1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP@.50IOU ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP@.75IOU ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP (large) ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP (medium) ByCategory/person'], 1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@1 ByCategory/person'],
1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@10 ByCategory/person'],
1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@100 ByCategory/person'],
1.0)
self.assertAlmostEqual(
metrics['Keypoints_Recall/AR@100 (large) ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Recall/AR@100 (medium) ByCategory/person'], 1.0)
self.assertFalse(coco_keypoint_evaluator._groundtruth_list)
self.assertFalse(coco_keypoint_evaluator._detection_boxes_list)
self.assertFalse(coco_keypoint_evaluator._image_ids)
def testGetOneMAPWithMatchingKeypointsAndVisibilities(self):
category_keypoint_dict = _get_category_keypoints_dict()
coco_keypoint_evaluator = coco_evaluation.CocoKeypointEvaluator(
category_id=1, category_keypoints=category_keypoint_dict['person'],
class_text='person')
image_id = tf.placeholder(tf.string, shape=())
groundtruth_boxes = tf.placeholder(tf.float32, shape=(None, 4))
groundtruth_classes = tf.placeholder(tf.float32, shape=(None))
groundtruth_keypoints = tf.placeholder(tf.float32, shape=(None, 4, 2))
groundtruth_keypoint_visibilities = tf.placeholder(
tf.float32, shape=(None, 4))
detection_boxes = tf.placeholder(tf.float32, shape=(None, 4))
detection_scores = tf.placeholder(tf.float32, shape=(None))
detection_classes = tf.placeholder(tf.float32, shape=(None))
detection_keypoints = tf.placeholder(tf.float32, shape=(None, 4, 2))
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
eval_dict = {
input_data_fields.key:
image_id,
input_data_fields.groundtruth_boxes:
groundtruth_boxes,
input_data_fields.groundtruth_classes:
groundtruth_classes,
input_data_fields.groundtruth_keypoints:
groundtruth_keypoints,
input_data_fields.groundtruth_keypoint_visibilities:
groundtruth_keypoint_visibilities,
detection_fields.detection_boxes:
detection_boxes,
detection_fields.detection_scores:
detection_scores,
detection_fields.detection_classes:
detection_classes,
detection_fields.detection_keypoints:
detection_keypoints,
}
eval_metric_ops = coco_keypoint_evaluator.get_estimator_eval_metric_ops(
eval_dict)
_, update_op = eval_metric_ops['Keypoints_Precision/mAP ByCategory/person']
with self.test_session() as sess:
sess.run(
update_op,
feed_dict={
image_id:
'image1',
groundtruth_boxes:
np.array([[100., 100., 200., 200.]]),
groundtruth_classes:
np.array([1]),
groundtruth_keypoints:
np.array([[[150., 160.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [170., 180.]]]),
groundtruth_keypoint_visibilities:
np.array([[0, 0, 0, 2]]),
detection_boxes:
np.array([[100., 100., 200., 200.]]),
detection_scores:
np.array([.8]),
detection_classes:
np.array([1]),
detection_keypoints:
np.array([[[50., 60.], [1., 2.], [3., 4.], [170., 180.]]])
})
metrics = {}
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['Keypoints_Precision/mAP ByCategory/person'],
1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP@.50IOU ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP@.75IOU ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP (large) ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP (medium) ByCategory/person'], -1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@1 ByCategory/person'],
1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@10 ByCategory/person'],
1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@100 ByCategory/person'],
1.0)
self.assertAlmostEqual(
metrics['Keypoints_Recall/AR@100 (large) ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Recall/AR@100 (medium) ByCategory/person'], -1.0)
self.assertFalse(coco_keypoint_evaluator._groundtruth_list)
self.assertFalse(coco_keypoint_evaluator._detection_boxes_list)
self.assertFalse(coco_keypoint_evaluator._image_ids)
def testGetOneMAPWithMatchingKeypointsIsAnnotated(self):
category_keypoint_dict = _get_category_keypoints_dict()
coco_keypoint_evaluator = coco_evaluation.CocoKeypointEvaluator(
category_id=1, category_keypoints=category_keypoint_dict['person'],
class_text='person')
image_id = tf.placeholder(tf.string, shape=())
groundtruth_boxes = tf.placeholder(tf.float32, shape=(None, 4))
groundtruth_classes = tf.placeholder(tf.float32, shape=(None))
groundtruth_keypoints = tf.placeholder(tf.float32, shape=(None, 4, 2))
is_annotated = tf.placeholder(tf.bool, shape=())
detection_boxes = tf.placeholder(tf.float32, shape=(None, 4))
detection_scores = tf.placeholder(tf.float32, shape=(None))
detection_classes = tf.placeholder(tf.float32, shape=(None))
detection_keypoints = tf.placeholder(tf.float32, shape=(None, 4, 2))
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
eval_dict = {
input_data_fields.key: image_id,
input_data_fields.groundtruth_boxes: groundtruth_boxes,
input_data_fields.groundtruth_classes: groundtruth_classes,
input_data_fields.groundtruth_keypoints: groundtruth_keypoints,
'is_annotated': is_annotated,
detection_fields.detection_boxes: detection_boxes,
detection_fields.detection_scores: detection_scores,
detection_fields.detection_classes: detection_classes,
detection_fields.detection_keypoints: detection_keypoints,
}
eval_metric_ops = coco_keypoint_evaluator.get_estimator_eval_metric_ops(
eval_dict)
_, update_op = eval_metric_ops['Keypoints_Precision/mAP ByCategory/person']
with self.test_session() as sess:
sess.run(
update_op,
feed_dict={
image_id:
'image1',
groundtruth_boxes:
np.array([[100., 100., 200., 200.]]),
groundtruth_classes:
np.array([1]),
groundtruth_keypoints:
np.array([[[150., 160.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [170., 180.]]]),
is_annotated:
True,
detection_boxes:
np.array([[100., 100., 200., 200.]]),
detection_scores:
np.array([.8]),
detection_classes:
np.array([1]),
detection_keypoints:
np.array([[[150., 160.], [1., 2.], [3., 4.], [170., 180.]]])
})
sess.run(
update_op,
feed_dict={
image_id:
'image2',
groundtruth_boxes:
np.array([[50., 50., 100., 100.]]),
groundtruth_classes:
np.array([1]),
groundtruth_keypoints:
np.array([[[75., 76.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [77., 78.]]]),
is_annotated:
True,
detection_boxes:
np.array([[50., 50., 100., 100.]]),
detection_scores:
np.array([.7]),
detection_classes:
np.array([1]),
detection_keypoints:
np.array([[[75., 76.], [5., 6.], [7., 8.], [77., 78.]]])
})
sess.run(
update_op,
feed_dict={
image_id:
'image3',
groundtruth_boxes:
np.zeros((0, 4)),
groundtruth_classes:
np.zeros((0)),
groundtruth_keypoints:
np.zeros((0, 4, 2)),
is_annotated:
False, # Note that this image isn't annotated.
detection_boxes:
np.array([[25., 25., 50., 50.], [25., 25., 70., 50.],
[25., 25., 80., 50.], [25., 25., 90., 50.]]),
detection_scores:
np.array([0.6, 0.7, 0.8, 0.9]),
detection_classes:
np.array([1, 2, 2, 3]),
detection_keypoints:
np.array([[[0., 0.], [0., 0.], [0., 0.], [0., 0.]]])
})
metrics = {}
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['Keypoints_Precision/mAP ByCategory/person'],
1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP@.50IOU ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP@.75IOU ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP (large) ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP (medium) ByCategory/person'], 1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@1 ByCategory/person'],
1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@10 ByCategory/person'],
1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@100 ByCategory/person'],
1.0)
self.assertAlmostEqual(
metrics['Keypoints_Recall/AR@100 (large) ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Recall/AR@100 (medium) ByCategory/person'], 1.0)
self.assertFalse(coco_keypoint_evaluator._groundtruth_list)
self.assertFalse(coco_keypoint_evaluator._detection_boxes_list)
self.assertFalse(coco_keypoint_evaluator._image_ids)
def testGetOneMAPWithMatchingKeypointsBatched(self):
category_keypoint_dict = _get_category_keypoints_dict()
coco_keypoint_evaluator = coco_evaluation.CocoKeypointEvaluator(
category_id=1, category_keypoints=category_keypoint_dict['person'],
class_text='person')
batch_size = 2
image_id = tf.placeholder(tf.string, shape=(batch_size))
groundtruth_boxes = tf.placeholder(tf.float32, shape=(batch_size, None, 4))
groundtruth_classes = tf.placeholder(tf.float32, shape=(batch_size, None))
groundtruth_keypoints = tf.placeholder(
tf.float32, shape=(batch_size, None, 4, 2))
detection_boxes = tf.placeholder(tf.float32, shape=(batch_size, None, 4))
detection_scores = tf.placeholder(tf.float32, shape=(batch_size, None))
detection_classes = tf.placeholder(tf.float32, shape=(batch_size, None))
detection_keypoints = tf.placeholder(
tf.float32, shape=(batch_size, None, 4, 2))
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
eval_dict = {
input_data_fields.key: image_id,
input_data_fields.groundtruth_boxes: groundtruth_boxes,
input_data_fields.groundtruth_classes: groundtruth_classes,
input_data_fields.groundtruth_keypoints: groundtruth_keypoints,
detection_fields.detection_boxes: detection_boxes,
detection_fields.detection_scores: detection_scores,
detection_fields.detection_classes: detection_classes,
detection_fields.detection_keypoints: detection_keypoints
}
eval_metric_ops = coco_keypoint_evaluator.get_estimator_eval_metric_ops(
eval_dict)
_, update_op = eval_metric_ops['Keypoints_Precision/mAP ByCategory/person']
with self.test_session() as sess:
sess.run(
update_op,
feed_dict={
image_id: ['image1', 'image2'],
groundtruth_boxes:
np.array([[[100., 100., 200., 200.]], [[50., 50., 100.,
100.]]]),
groundtruth_classes:
np.array([[1], [3]]),
groundtruth_keypoints:
np.array([[[[150., 160.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [170., 180.]]],
[[[75., 76.], [float('nan'),
float('nan')],
[float('nan'), float('nan')], [77., 78.]]]]),
detection_boxes:
np.array([[[100., 100., 200., 200.]], [[50., 50., 100.,
100.]]]),
detection_scores:
np.array([[.8], [.7]]),
detection_classes:
np.array([[1], [3]]),
detection_keypoints:
np.array([[[[150., 160.], [1., 2.], [3., 4.], [170., 180.]]],
[[[75., 76.], [5., 6.], [7., 8.], [77., 78.]]]])
})
metrics = {}
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['Keypoints_Precision/mAP ByCategory/person'],
1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP@.50IOU ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP@.75IOU ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP (large) ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Precision/mAP (medium) ByCategory/person'], -1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@1 ByCategory/person'],
1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@10 ByCategory/person'],
1.0)
self.assertAlmostEqual(metrics['Keypoints_Recall/AR@100 ByCategory/person'],
1.0)
self.assertAlmostEqual(
metrics['Keypoints_Recall/AR@100 (large) ByCategory/person'], 1.0)
self.assertAlmostEqual(
metrics['Keypoints_Recall/AR@100 (medium) ByCategory/person'], -1.0)
self.assertFalse(coco_keypoint_evaluator._groundtruth_list)
self.assertFalse(coco_keypoint_evaluator._detection_boxes_list)
self.assertFalse(coco_keypoint_evaluator._image_ids)
class CocoMaskEvaluationTest(tf.test.TestCase):
def testGetOneMAPWithMatchingGroundtruthAndDetections(self):
......@@ -824,7 +1533,7 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
mode='constant')
})
metrics = {}
for key, (value_op, _) in eval_metric_ops.iteritems():
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['DetectionMasks_Precision/mAP'], 1.0)
......@@ -924,7 +1633,7 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
axis=0)
})
metrics = {}
for key, (value_op, _) in eval_metric_ops.iteritems():
for key, (value_op, _) in eval_metric_ops.items():
metrics[key] = value_op
metrics = sess.run(metrics)
self.assertAlmostEqual(metrics['DetectionMasks_Precision/mAP'], 1.0)
......
......@@ -157,7 +157,7 @@ class COCOEvalWrapper(cocoeval.COCOeval):
"""
def __init__(self, groundtruth=None, detections=None, agnostic_mode=False,
iou_type='bbox'):
iou_type='bbox', oks_sigmas=None):
"""COCOEvalWrapper constructor.
Note that for the area-based metrics to be meaningful, detection and
......@@ -170,12 +170,16 @@ class COCOEvalWrapper(cocoeval.COCOeval):
detections
agnostic_mode: boolean (default: False). If True, evaluation ignores
class labels, treating all detections as proposals.
iou_type: IOU type to use for evaluation. Supports `bbox` or `segm`.
iou_type: IOU type to use for evaluation. Supports `bbox`, `segm`, or
`keypoints`.
oks_sigmas: Float numpy array holding the OKS variances for keypoints.
"""
cocoeval.COCOeval.__init__(self, groundtruth, detections,
iouType=iou_type)
cocoeval.COCOeval.__init__(self, groundtruth, detections, iouType=iou_type)
if oks_sigmas is not None:
self.params.kpt_oks_sigmas = oks_sigmas
if agnostic_mode:
self.params.useCats = 0
self._iou_type = iou_type
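For orientation, a minimal sketch of how the extended constructor might be used for keypoint evaluation; `coco_gt` and `coco_dt` are assumed to be COCO-format groundtruth and detection objects built elsewhere (e.g. via the helpers in coco_tools), and the sigma values are illustrative rather than the official COCO constants:
# Sketch only: assumes `coco_gt` / `coco_dt` were prepared elsewhere; `np` is
# numpy as imported at the top of this file. The four sigmas are illustrative.
oks_sigmas = np.array([0.026, 0.025, 0.025, 0.035], dtype=np.float32)
keypoint_evaluator = COCOEvalWrapper(
    groundtruth=coco_gt,
    detections=coco_dt,
    iou_type='keypoints',
    oks_sigmas=oks_sigmas)
summary_metrics, _ = keypoint_evaluator.ComputeMetrics()
# With iou_type='keypoints', the summary keys are per-category, e.g.
# 'Precision/mAP ByCategory/person'.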
def GetCategory(self, category_id):
"""Fetches dictionary holding category information given category id.
......@@ -198,7 +202,7 @@ class COCOEvalWrapper(cocoeval.COCOeval):
def ComputeMetrics(self,
include_metrics_per_category=False,
all_metrics_per_category=False):
"""Computes detection metrics.
"""Computes detection/keypoint metrics.
Args:
include_metrics_per_category: If True, will include metrics per category.
......@@ -214,7 +218,7 @@ class COCOEvalWrapper(cocoeval.COCOeval):
'Precision/mAP@.50IOU': mean average precision at 50% IOU
'Precision/mAP@.75IOU': mean average precision at 75% IOU
'Precision/mAP (small)': mean average precision for small objects
(area < 32^2 pixels)
(area < 32^2 pixels). NOTE: not present for 'keypoints'
'Precision/mAP (medium)': mean average precision for medium sized
objects (32^2 pixels < area < 96^2 pixels)
'Precision/mAP (large)': mean average precision for large objects
......@@ -223,7 +227,7 @@ class COCOEvalWrapper(cocoeval.COCOeval):
'Recall/AR@10': average recall with 10 detections
'Recall/AR@100': average recall with 100 detections
'Recall/AR@100 (small)': average recall for small objects with 100
detections
detections. NOTE: not present for 'keypoints'
'Recall/AR@100 (medium)': average recall for medium objects with 100
detections
'Recall/AR@100 (large)': average recall for large objects with 100
......@@ -243,8 +247,9 @@ class COCOEvalWrapper(cocoeval.COCOeval):
self.accumulate()
self.summarize()
summary_metrics = OrderedDict([
('Precision/mAP', self.stats[0]),
summary_metrics = {}
if self._iou_type in ['bbox', 'segm']:
summary_metrics = OrderedDict([('Precision/mAP', self.stats[0]),
('Precision/mAP@.50IOU', self.stats[1]),
('Precision/mAP@.75IOU', self.stats[2]),
('Precision/mAP (small)', self.stats[3]),
......@@ -255,8 +260,31 @@ class COCOEvalWrapper(cocoeval.COCOeval):
('Recall/AR@100', self.stats[8]),
('Recall/AR@100 (small)', self.stats[9]),
('Recall/AR@100 (medium)', self.stats[10]),
('Recall/AR@100 (large)', self.stats[11])
])
('Recall/AR@100 (large)', self.stats[11])])
elif self._iou_type == 'keypoints':
category_id = self.GetCategoryIdList()[0]
category_name = self.GetCategory(category_id)['name']
summary_metrics = OrderedDict([])
summary_metrics['Precision/mAP ByCategory/{}'.format(
category_name)] = self.stats[0]
summary_metrics['Precision/mAP@.50IOU ByCategory/{}'.format(
category_name)] = self.stats[1]
summary_metrics['Precision/mAP@.75IOU ByCategory/{}'.format(
category_name)] = self.stats[2]
summary_metrics['Precision/mAP (medium) ByCategory/{}'.format(
category_name)] = self.stats[3]
summary_metrics['Precision/mAP (large) ByCategory/{}'.format(
category_name)] = self.stats[4]
summary_metrics['Recall/AR@1 ByCategory/{}'.format(
category_name)] = self.stats[5]
summary_metrics['Recall/AR@10 ByCategory/{}'.format(
category_name)] = self.stats[6]
summary_metrics['Recall/AR@100 ByCategory/{}'.format(
category_name)] = self.stats[7]
summary_metrics['Recall/AR@100 (medium) ByCategory/{}'.format(
category_name)] = self.stats[8]
summary_metrics['Recall/AR@100 (large) ByCategory/{}'.format(
category_name)] = self.stats[9]
if not include_metrics_per_category:
return summary_metrics, {}
if not hasattr(self, 'category_stats'):
......@@ -333,8 +361,11 @@ def ExportSingleImageGroundtruthToCoco(image_id,
category_id_set,
groundtruth_boxes,
groundtruth_classes,
groundtruth_keypoints=None,
groundtruth_keypoint_visibilities=None,
groundtruth_masks=None,
groundtruth_is_crowd=None):
groundtruth_is_crowd=None,
groundtruth_area=None):
"""Export groundtruth of a single image to COCO format.
This function converts groundtruth detection annotations represented as numpy
......@@ -356,10 +387,19 @@ def ExportSingleImageGroundtruthToCoco(image_id,
category_id_set are dropped.
groundtruth_boxes: numpy array (float32) with shape [num_gt_boxes, 4]
groundtruth_classes: numpy array (int) with shape [num_gt_boxes]
groundtruth_keypoints: optional float numpy array of keypoints
with shape [num_gt_boxes, num_keypoints, 2].
groundtruth_keypoint_visibilities: optional integer numpy array of keypoint
visibilities with shape [num_gt_boxes, num_keypoints]. Integer is treated
as an enum with 0=not labeled, 1=labeled but not visible and 2=labeled and
visible.
groundtruth_masks: optional uint8 numpy array of shape [num_gt_boxes,
image_height, image_width] containing groundtruth masks.
groundtruth_is_crowd: optional numpy array (int) with shape [num_gt_boxes]
indicating whether groundtruth boxes are crowd.
groundtruth_area: numpy array (float32) with shape [num_gt_boxes]. If
provided, then the area values (in the original absolute coordinates) will
be populated instead of calculated from bounding box coordinates.
Returns:
a list of groundtruth annotations for a single image in the COCO format.
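The groundtruth_keypoint_visibilities argument described above uses the COCO visibility enum; a tiny illustrative example (values are made up):
# Illustrative only: one groundtruth box with four keypoints, where the first
# keypoint is unlabeled (0), the second is labeled but not visible (1) and the
# remaining two are labeled and visible (2).
groundtruth_keypoint_visibilities = np.array([[0, 1, 2, 2]], dtype=np.int32)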
......@@ -390,10 +430,20 @@ def ExportSingleImageGroundtruthToCoco(image_id,
has_is_crowd = groundtruth_is_crowd is not None
if has_is_crowd and len(groundtruth_is_crowd.shape) != 1:
raise ValueError('groundtruth_is_crowd is expected to be of rank 1.')
has_keypoints = groundtruth_keypoints is not None
has_keypoint_visibilities = groundtruth_keypoint_visibilities is not None
if has_keypoints and not has_keypoint_visibilities:
groundtruth_keypoint_visibilities = np.full(
(num_boxes, groundtruth_keypoints.shape[1]), 2)
groundtruth_list = []
for i in range(num_boxes):
if groundtruth_classes[i] in category_id_set:
iscrowd = groundtruth_is_crowd[i] if has_is_crowd else 0
if groundtruth_area is not None and groundtruth_area[i] > 0:
area = float(groundtruth_area[i])
else:
area = float((groundtruth_boxes[i, 2] - groundtruth_boxes[i, 0]) *
(groundtruth_boxes[i, 3] - groundtruth_boxes[i, 1]))
export_dict = {
'id':
next_annotation_id + i,
......@@ -403,14 +453,27 @@ def ExportSingleImageGroundtruthToCoco(image_id,
int(groundtruth_classes[i]),
'bbox':
list(_ConvertBoxToCOCOFormat(groundtruth_boxes[i, :])),
'area':
float((groundtruth_boxes[i, 2] - groundtruth_boxes[i, 0]) *
(groundtruth_boxes[i, 3] - groundtruth_boxes[i, 1])),
'area': area,
'iscrowd':
iscrowd
}
if groundtruth_masks is not None:
export_dict['segmentation'] = _RleCompress(groundtruth_masks[i])
if has_keypoints:
keypoints = groundtruth_keypoints[i]
visibilities = np.reshape(groundtruth_keypoint_visibilities[i], [-1])
coco_keypoints = []
num_valid_keypoints = 0
for keypoint, visibility in zip(keypoints, visibilities):
# Convert from [y, x] to [x, y] as mandated by COCO.
coco_keypoints.append(float(keypoint[1]))
coco_keypoints.append(float(keypoint[0]))
coco_keypoints.append(int(visibility))
if int(visibility) > 0:
num_valid_keypoints = num_valid_keypoints + 1
export_dict['keypoints'] = coco_keypoints
export_dict['num_keypoints'] = num_valid_keypoints
groundtruth_list.append(export_dict)
return groundtruth_list
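For reference, a minimal sketch of the flattened COCO keypoint layout that the loop above produces (coordinates and visibilities are illustrative):
# Illustrative only: one box with two keypoints at (y, x) = (150., 160.) and
# (170., 180.), both labeled and visible. The exported layout interleaves
# [x, y, visibility] per keypoint.
keypoints = [[150., 160.], [170., 180.]]
visibilities = [2, 2]
coco_keypoints = []
for (y, x), v in zip(keypoints, visibilities):
  coco_keypoints.extend([float(x), float(y), int(v)])
assert coco_keypoints == [160., 150., 2, 180., 170., 2]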
......@@ -494,7 +557,9 @@ def ExportSingleImageDetectionBoxesToCoco(image_id,
category_id_set,
detection_boxes,
detection_scores,
detection_classes):
detection_classes,
detection_keypoints=None,
detection_keypoint_visibilities=None):
"""Export detections of a single image to COCO format.
This function converts detections represented as numpy arrays to dictionaries
......@@ -514,6 +579,12 @@ def ExportSingleImageDetectionBoxesToCoco(image_id,
scores for the detection boxes.
detection_classes: integer numpy array of shape [num_detections] containing
the classes for detection boxes.
detection_keypoints: optional float numpy array of keypoints
with shape [num_detections, num_keypoints, 2].
detection_keypoint_visibilities: optional integer numpy array of keypoint
visibilities with shape [num_detections, num_keypoints]. Integer is
treated as an enum with 0=not labeled, 1=labeled but not visible and
2=labeled and visible.
Returns:
a list of detection annotations for a single image in the COCO format.
......@@ -546,12 +617,33 @@ def ExportSingleImageDetectionBoxesToCoco(image_id,
detections_list = []
for i in range(num_boxes):
if detection_classes[i] in category_id_set:
detections_list.append({
'image_id': image_id,
'category_id': int(detection_classes[i]),
'bbox': list(_ConvertBoxToCOCOFormat(detection_boxes[i, :])),
'score': float(detection_scores[i])
})
export_dict = {
'image_id':
image_id,
'category_id':
int(detection_classes[i]),
'bbox':
list(_ConvertBoxToCOCOFormat(detection_boxes[i, :])),
'score':
float(detection_scores[i]),
}
if detection_keypoints is not None:
keypoints = detection_keypoints[i]
num_keypoints = keypoints.shape[0]
if detection_keypoint_visibilities is None:
detection_keypoint_visibilities = np.full((num_boxes, num_keypoints),
2)
visibilities = np.reshape(detection_keypoint_visibilities[i], [-1])
coco_keypoints = []
for keypoint, visibility in zip(keypoints, visibilities):
# Convert from [y, x] to [x, y] as mandated by COCO.
coco_keypoints.append(float(keypoint[1]))
coco_keypoints.append(float(keypoint[0]))
coco_keypoints.append(int(visibility))
export_dict['keypoints'] = coco_keypoints
export_dict['num_keypoints'] = num_keypoints
detections_list.append(export_dict)
return detections_list
......
......@@ -290,6 +290,116 @@ class CocoToolsTest(tf.test.TestCase):
self.assertEqual(annotation['iscrowd'], is_crowd[i])
self.assertEqual(annotation['id'], i + next_annotation_id)
def testSingleImageGroundtruthExportWithKeypoints(self):
boxes = np.array([[0, 0, 1, 1],
[0, 0, .5, .5],
[.5, .5, 1, 1]], dtype=np.float32)
coco_boxes = np.array([[0, 0, 1, 1],
[0, 0, .5, .5],
[.5, .5, .5, .5]], dtype=np.float32)
keypoints = np.array([[[0, 0], [0.25, 0.25], [0.75, 0.75]],
[[0, 0], [0.125, 0.125], [0.375, 0.375]],
[[0.5, 0.5], [0.75, 0.75], [1.0, 1.0]]],
dtype=np.float32)
visibilities = np.array([[2, 2, 2],
[2, 2, 0],
[2, 0, 0]], dtype=np.int32)
areas = np.array([15., 16., 17.])
classes = np.array([1, 2, 3], dtype=np.int32)
is_crowd = np.array([0, 1, 0], dtype=np.int32)
next_annotation_id = 1
# Tests exporting without passing in is_crowd (for backward compatibility).
coco_annotations = coco_tools.ExportSingleImageGroundtruthToCoco(
image_id='first_image',
category_id_set=set([1, 2, 3]),
next_annotation_id=next_annotation_id,
groundtruth_boxes=boxes,
groundtruth_classes=classes,
groundtruth_keypoints=keypoints,
groundtruth_keypoint_visibilities=visibilities,
groundtruth_area=areas)
for i, annotation in enumerate(coco_annotations):
self.assertTrue(np.all(np.isclose(annotation['bbox'], coco_boxes[i])))
self.assertEqual(annotation['image_id'], 'first_image')
self.assertEqual(annotation['category_id'], classes[i])
self.assertEqual(annotation['id'], i + next_annotation_id)
self.assertEqual(annotation['num_keypoints'], 3 - i)
self.assertEqual(annotation['area'], 15.0 + i)
self.assertTrue(
np.all(np.isclose(annotation['keypoints'][0::3], keypoints[i, :, 1])))
self.assertTrue(
np.all(np.isclose(annotation['keypoints'][1::3], keypoints[i, :, 0])))
self.assertTrue(
np.all(np.equal(annotation['keypoints'][2::3], visibilities[i])))
# Tests exporting with is_crowd.
coco_annotations = coco_tools.ExportSingleImageGroundtruthToCoco(
image_id='first_image',
category_id_set=set([1, 2, 3]),
next_annotation_id=next_annotation_id,
groundtruth_boxes=boxes,
groundtruth_classes=classes,
groundtruth_keypoints=keypoints,
groundtruth_keypoint_visibilities=visibilities,
groundtruth_is_crowd=is_crowd)
for i, annotation in enumerate(coco_annotations):
self.assertTrue(np.all(np.isclose(annotation['bbox'], coco_boxes[i])))
self.assertEqual(annotation['image_id'], 'first_image')
self.assertEqual(annotation['category_id'], classes[i])
self.assertEqual(annotation['iscrowd'], is_crowd[i])
self.assertEqual(annotation['id'], i + next_annotation_id)
self.assertEqual(annotation['num_keypoints'], 3 - i)
self.assertTrue(
np.all(np.isclose(annotation['keypoints'][0::3], keypoints[i, :, 1])))
self.assertTrue(
np.all(np.isclose(annotation['keypoints'][1::3], keypoints[i, :, 0])))
self.assertTrue(
np.all(np.equal(annotation['keypoints'][2::3], visibilities[i])))
# Test that the area values are derived from the bounding boxes.
if i == 0:
self.assertAlmostEqual(annotation['area'], 1.0)
else:
self.assertAlmostEqual(annotation['area'], 0.25)
def testSingleImageDetectionBoxesExportWithKeypoints(self):
boxes = np.array([[0, 0, 1, 1], [0, 0, .5, .5], [.5, .5, 1, 1]],
dtype=np.float32)
coco_boxes = np.array([[0, 0, 1, 1], [0, 0, .5, .5], [.5, .5, .5, .5]],
dtype=np.float32)
keypoints = np.array([[[0, 0], [0.25, 0.25], [0.75, 0.75]],
[[0, 0], [0.125, 0.125], [0.375, 0.375]],
[[0.5, 0.5], [0.75, 0.75], [1.0, 1.0]]],
dtype=np.float32)
visibilities = np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]], dtype=np.int32)
classes = np.array([1, 2, 3], dtype=np.int32)
scores = np.array([0.8, 0.2, 0.7], dtype=np.float32)
# Tests exporting detection boxes together with keypoints and visibilities.
coco_annotations = coco_tools.ExportSingleImageDetectionBoxesToCoco(
image_id='first_image',
category_id_set=set([1, 2, 3]),
detection_boxes=boxes,
detection_scores=scores,
detection_classes=classes,
detection_keypoints=keypoints,
detection_keypoint_visibilities=visibilities)
for i, annotation in enumerate(coco_annotations):
self.assertTrue(np.all(np.isclose(annotation['bbox'], coco_boxes[i])))
self.assertEqual(annotation['image_id'], 'first_image')
self.assertEqual(annotation['category_id'], classes[i])
self.assertTrue(np.all(np.isclose(annotation['bbox'], coco_boxes[i])))
self.assertEqual(annotation['score'], scores[i])
self.assertEqual(annotation['num_keypoints'], 3)
self.assertTrue(
np.all(np.isclose(annotation['keypoints'][0::3], keypoints[i, :, 1])))
self.assertTrue(
np.all(np.isclose(annotation['keypoints'][1::3], keypoints[i, :, 0])))
self.assertTrue(
np.all(np.equal(annotation['keypoints'][2::3], visibilities[i])))
if __name__ == '__main__':
tf.test.main()
......@@ -24,6 +24,7 @@ import zlib
import numpy as np
import pandas as pd
from pycocotools import mask as coco_mask
import six
import tensorflow as tf
from object_detection.core import standard_fields
......@@ -50,7 +51,8 @@ def encode_mask(mask_to_encode):
mask_to_encode = mask_to_encode.astype(np.uint8)
mask_to_encode = np.asfortranarray(mask_to_encode)
encoded_mask = coco_mask.encode(mask_to_encode)[0]['counts']
compressed_mask = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
compressed_mask = zlib.compress(six.ensure_binary(encoded_mask),
zlib.Z_BEST_COMPRESSION)
base64_mask = base64.b64encode(compressed_mask)
return base64_mask
......
......@@ -44,8 +44,8 @@ class StringParser(data_parser.DataToNumpyParser):
self.field_name = field_name
def parse(self, tf_example):
return "".join(tf_example.features.feature[self.field_name]
.bytes_list.value) if tf_example.features.feature[
return b"".join(tf_example.features.feature[
self.field_name].bytes_list.value) if tf_example.features.feature[
self.field_name].HasField("bytes_list") else None
......
......@@ -34,7 +34,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def testParseDetectionsAndGT(self):
source_id = 'abc.jpg'
source_id = b'abc.jpg'
# y_min, x_min, y_max, x_max
object_bb = np.array([[0.0, 0.5, 0.3], [0.0, 0.1, 0.6], [1.0, 0.6, 0.8],
[1.0, 0.6, 0.7]]).transpose()
......@@ -129,7 +129,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
results_dict[fields.InputDataFields.groundtruth_image_classes])
def testParseString(self):
string_val = 'abc'
string_val = b'abc'
features = {'string': self._BytesFeature(string_val)}
example = tf.train.Example(features=tf.train.Features(feature=features))
......
......@@ -21,7 +21,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
# pylint: disable=g-import-not-at-top
try:
from tensorflow.contrib import training as contrib_training
except ImportError:
# TF 2.0 doesn't ship with contrib.
pass
# pylint: enable=g-import-not-at-top
def create_hparams(hparams_overrides=None):
......@@ -34,7 +40,7 @@ def create_hparams(hparams_overrides=None):
Returns:
The hyperparameters as a tf.HParams object.
"""
hparams = tf.contrib.training.HParams(
hparams = contrib_training.HParams(
# Whether a fine tuning checkpoint (provided in the pipeline config)
# should be loaded for training.
load_pretrained=True)
......
......@@ -38,6 +38,18 @@ from object_detection.utils import shape_utils
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vis_utils
# pylint: disable=g-import-not-at-top
try:
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import learn as contrib_learn
from tensorflow.contrib import tpu as contrib_tpu
from tensorflow.contrib import training as contrib_training
except ImportError:
# TF 2.0 doesn't ship with contrib.
pass
# pylint: enable=g-import-not-at-top
# A map of names to methods that help build the model.
MODEL_BUILD_UTIL_MAP = {
'get_configs_from_pipeline_file':
......@@ -76,8 +88,13 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
groundtruth)
'groundtruth_is_crowd': [batch_size, num_boxes] bool tensor indicating
is_crowd annotations (if provided in groundtruth).
'groundtruth_area': [batch_size, num_boxes] float32 tensor indicating
the area (in the original absolute coordinates) of annotations (if
provided in groundtruth).
'num_groundtruth_boxes': [batch_size] tensor containing the maximum number
of groundtruth boxes per image.
'groundtruth_keypoints': [batch_size, num_boxes, num_keypoints, 2] float32
tensor of keypoints (if provided in groundtruth).
class_agnostic: Boolean indicating whether detections are class agnostic.
"""
input_data_fields = fields.InputDataFields()
......@@ -107,6 +124,20 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
groundtruth[input_data_fields.groundtruth_is_crowd] = tf.stack(
detection_model.groundtruth_lists(fields.BoxListFields.is_crowd))
if detection_model.groundtruth_has_field(input_data_fields.groundtruth_area):
groundtruth[input_data_fields.groundtruth_area] = tf.stack(
detection_model.groundtruth_lists(input_data_fields.groundtruth_area))
if detection_model.groundtruth_has_field(fields.BoxListFields.keypoints):
groundtruth[input_data_fields.groundtruth_keypoints] = tf.stack(
detection_model.groundtruth_lists(fields.BoxListFields.keypoints))
if detection_model.groundtruth_has_field(
fields.BoxListFields.keypoint_visibilities):
groundtruth[input_data_fields.groundtruth_keypoint_visibilities] = tf.stack(
detection_model.groundtruth_lists(
fields.BoxListFields.keypoint_visibilities))
groundtruth[input_data_fields.num_groundtruth_boxes] = (
tf.tile([max_number_of_boxes], multiples=[groundtruth_boxes_shape[0]]))
return groundtruth
......@@ -161,6 +192,7 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
fields.InputDataFields.groundtruth_classes,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_keypoints,
fields.InputDataFields.groundtruth_keypoint_visibilities,
fields.InputDataFields.groundtruth_group_of,
fields.InputDataFields.groundtruth_difficult,
fields.InputDataFields.groundtruth_is_crowd,
......@@ -206,6 +238,10 @@ def provide_groundtruth(model, labels):
gt_keypoints_list = None
if fields.InputDataFields.groundtruth_keypoints in labels:
gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
gt_keypoint_visibilities_list = None
if fields.InputDataFields.groundtruth_keypoint_visibilities in labels:
gt_keypoint_visibilities_list = labels[
fields.InputDataFields.groundtruth_keypoint_visibilities]
gt_weights_list = None
if fields.InputDataFields.groundtruth_weights in labels:
gt_weights_list = labels[fields.InputDataFields.groundtruth_weights]
......@@ -216,14 +252,24 @@ def provide_groundtruth(model, labels):
gt_is_crowd_list = None
if fields.InputDataFields.groundtruth_is_crowd in labels:
gt_is_crowd_list = labels[fields.InputDataFields.groundtruth_is_crowd]
gt_area_list = None
if fields.InputDataFields.groundtruth_area in labels:
gt_area_list = labels[fields.InputDataFields.groundtruth_area]
gt_labeled_classes = None
if fields.InputDataFields.groundtruth_labeled_classes in labels:
gt_labeled_classes = labels[
fields.InputDataFields.groundtruth_labeled_classes]
model.provide_groundtruth(
groundtruth_boxes_list=gt_boxes_list,
groundtruth_classes_list=gt_classes_list,
groundtruth_confidences_list=gt_confidences_list,
groundtruth_labeled_classes=gt_labeled_classes,
groundtruth_masks_list=gt_masks_list,
groundtruth_keypoints_list=gt_keypoints_list,
groundtruth_keypoint_visibilities_list=gt_keypoint_visibilities_list,
groundtruth_weights_list=gt_weights_list,
groundtruth_is_crowd_list=gt_is_crowd_list)
groundtruth_is_crowd_list=gt_is_crowd_list,
groundtruth_area_list=gt_area_list)
def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
......@@ -296,23 +342,26 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
provide_groundtruth(detection_model, labels)
preprocessed_images = features[fields.InputDataFields.image]
side_inputs = detection_model.get_side_inputs(features)
if use_tpu and train_config.use_bfloat16:
with tf.contrib.tpu.bfloat16_scope():
with contrib_tpu.bfloat16_scope():
prediction_dict = detection_model.predict(
preprocessed_images,
features[fields.InputDataFields.true_image_shape])
features[fields.InputDataFields.true_image_shape], **side_inputs)
prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)
else:
prediction_dict = detection_model.predict(
preprocessed_images,
features[fields.InputDataFields.true_image_shape])
features[fields.InputDataFields.true_image_shape], **side_inputs)
def postprocess_wrapper(args):
return detection_model.postprocess(args[0], args[1])
if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
if use_tpu and postprocess_on_cpu:
detections = tf.contrib.tpu.outside_compilation(
detections = contrib_tpu.outside_compilation(
postprocess_wrapper,
(prediction_dict,
features[fields.InputDataFields.true_image_shape]))
......@@ -354,6 +403,11 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
available_var_map)
if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
if (mode == tf.estimator.ModeKeys.EVAL and
eval_config.use_dummy_loss_in_eval):
total_loss = tf.constant(1.0)
losses_dict = {'Loss/total_loss': total_loss}
else:
losses_dict = detection_model.loss(
prediction_dict, features[fields.InputDataFields.true_image_shape])
losses = [loss_tensor for loss_tensor in losses_dict.values()]
......@@ -383,8 +437,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
if mode == tf.estimator.ModeKeys.TRAIN:
if use_tpu:
training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
training_optimizer)
training_optimizer = contrib_tpu.CrossShardOptimizer(training_optimizer)
# Optionally freeze some layers by setting their gradients to be zero.
trainable_variables = None
......@@ -394,7 +447,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
exclude_variables = (
train_config.freeze_variables
if train_config.freeze_variables else None)
trainable_variables = tf.contrib.framework.filter_variables(
trainable_variables = contrib_framework.filter_variables(
tf.trainable_variables(),
include_patterns=include_variables,
exclude_patterns=exclude_variables)
......@@ -409,7 +462,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
summaries = [] if use_tpu else None
if train_config.summarize_gradients:
summaries = ['gradients', 'gradient_norm', 'global_gradient_norm']
train_op = tf.contrib.layers.optimize_loss(
train_op = contrib_layers.optimize_loss(
loss=total_loss,
global_step=global_step,
learning_rate=None,
......@@ -468,12 +521,16 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
eval_input_config.label_map_path)
vis_metric_ops = None
if not use_tpu and use_original_images:
keypoint_edges = [
(kp.start, kp.end) for kp in eval_config.keypoint_edge]
eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
category_index,
max_examples_to_draw=eval_config.num_visualizations,
max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
min_score_thresh=eval_config.min_score_threshold,
use_normalized_coordinates=False)
use_normalized_coordinates=False,
keypoint_edges=keypoint_edges or None)
vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
eval_dict)
......@@ -500,7 +557,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
# EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
return tf.contrib.tpu.TPUEstimatorSpec(
return contrib_tpu.TPUEstimatorSpec(
mode=mode,
scaffold_fn=scaffold_fn,
predictions=detections,
......@@ -535,7 +592,7 @@ def create_estimator_and_inputs(run_config,
pipeline_config_path,
config_override=None,
train_steps=None,
sample_1_of_n_eval_examples=None,
sample_1_of_n_eval_examples=1,
sample_1_of_n_eval_on_train_examples=1,
model_fn_creator=create_model_fn,
use_tpu_estimator=False,
......@@ -681,7 +738,7 @@ def create_estimator_and_inputs(run_config,
model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu,
postprocess_on_cpu)
if use_tpu_estimator:
estimator = tf.contrib.tpu.TPUEstimator(
estimator = contrib_tpu.TPUEstimator(
model_fn=model_fn,
train_batch_size=train_config.batch_size,
# For each core, only batch size 1 is supported for eval.
......@@ -785,7 +842,7 @@ def continuous_eval(estimator, model_dir, input_fn, train_steps, name):
tf.logging.info('Terminating eval after 180 seconds of no checkpoints')
return True
for ckpt in tf.contrib.training.checkpoints_iterator(
for ckpt in contrib_training.checkpoints_iterator(
model_dir, min_interval_secs=180, timeout=None,
timeout_fn=terminate_eval):
......@@ -862,11 +919,11 @@ def populate_experiment(run_config,
train_steps = train_and_eval_dict['train_steps']
export_strategies = [
tf.contrib.learn.utils.saved_model_export_utils.make_export_strategy(
contrib_learn.utils.saved_model_export_utils.make_export_strategy(
serving_input_fn=predict_input_fn)
]
return tf.contrib.learn.Experiment(
return contrib_learn.Experiment(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fns[0],
......
......@@ -39,6 +39,9 @@ from object_detection.utils import config_util
# 'ssd_inception_v2_pets', 'faster_rcnn_resnet50_pets'
MODEL_NAME_FOR_TEST = 'ssd_inception_v2_pets'
# Model for testing keypoints.
MODEL_NAME_FOR_KEYPOINTS_TEST = 'ssd_mobilenet_v1_fpp'
def _get_data_path():
"""Returns an absolute path to TFRecord file."""
......@@ -48,6 +51,10 @@ def _get_data_path():
def get_pipeline_config_path(model_name):
"""Returns path to the local pipeline config file."""
if model_name == MODEL_NAME_FOR_KEYPOINTS_TEST:
return os.path.join(tf.resource_loader.get_data_files_path(), 'test_data',
model_name + '.config')
else:
return os.path.join(tf.resource_loader.get_data_files_path(), 'samples',
'configs', model_name + '.config')
......@@ -58,10 +65,19 @@ def _get_labelmap_path():
'pet_label_map.pbtxt')
def _get_keypoints_labelmap_path():
"""Returns an absolute path to label map file."""
return os.path.join(tf.resource_loader.get_data_files_path(), 'data',
'face_person_with_keypoints_label_map.pbtxt')
def _get_configs_for_model(model_name):
"""Returns configurations for model."""
filename = get_pipeline_config_path(model_name)
data_path = _get_data_path()
if model_name == MODEL_NAME_FOR_KEYPOINTS_TEST:
label_map_path = _get_keypoints_labelmap_path()
else:
label_map_path = _get_labelmap_path()
configs = config_util.get_configs_from_pipeline_file(filename)
override_dict = {
......@@ -213,6 +229,17 @@ class ModelLibTest(tf.test.TestCase):
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, 'eval')
def test_model_fn_in_keypoints_eval_mode(self):
"""Tests the model function in EVAL mode with keypoints config."""
configs = _get_configs_for_model(MODEL_NAME_FOR_KEYPOINTS_TEST)
estimator_spec = self._assert_model_fn_for_train_eval(configs, 'eval')
metric_ops = estimator_spec.eval_metric_ops
self.assertIn('Keypoints_Precision/mAP ByCategory/face', metric_ops)
self.assertIn('Keypoints_Precision/mAP ByCategory/PERSON', metric_ops)
detection_keypoints = estimator_spec.predictions['detection_keypoints']
self.assertEqual(1, detection_keypoints.shape.as_list()[0])
self.assertEqual(tf.float32, detection_keypoints.dtype)
def test_model_fn_in_eval_on_train_mode(self):
"""Tests the model function in EVAL mode with train data."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
......
......@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import copy
import os
import time
import tensorflow as tf
......@@ -29,10 +30,20 @@ from object_detection import model_lib
from object_detection.builders import model_builder
from object_detection.builders import optimizer_builder
from object_detection.core import standard_fields as fields
from object_detection.protos import train_pb2
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import ops
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vutils
# pylint: disable=g-import-not-at-top
try:
from tensorflow.contrib import tpu as contrib_tpu
except ImportError:
# TF 2.0 doesn't ship with contrib.
pass
# pylint: enable=g-import-not-at-top
MODEL_BUILD_UTIL_MAP = model_lib.MODEL_BUILD_UTIL_MAP
......@@ -44,6 +55,12 @@ MODEL_BUILD_UTIL_MAP = model_lib.MODEL_BUILD_UTIL_MAP
#### & verify the loss output from the eval_loop method.
### TODO(kaftan): Make sure the unit tests run in TAP presubmits or Kokoro
RESTORE_MAP_ERROR_TEMPLATE = (
'Since we are restoring a v2 style checkpoint'
' restore_map was expected to return a (str -> Model) mapping,'
' but we received a ({} -> {}) mapping instead.'
)
def _compute_losses_and_predictions_dicts(
model, features, labels,
......@@ -233,12 +250,34 @@ def eager_train_step(detection_model,
gradients, _ = tf.clip_by_global_norm(gradients, clip_gradients_value)
optimizer.apply_gradients(zip(gradients, trainable_variables))
tf.compat.v2.summary.scalar('learning_rate', learning_rate, step=global_step)
tf.compat.v2.summary.image(
name='train_input_images',
step=global_step,
data=features[fields.InputDataFields.image],
max_outputs=3)
return total_loss
def validate_tf_v2_checkpoint_restore_map(checkpoint_restore_map):
"""Ensure that given dict is a valid TF v2 style restore map.
Args:
checkpoint_restore_map: A dict mapping strings to tf.keras.Model objects.
Raises:
TypeError: If the keys in checkpoint_restore_map are not strings or if
the values are not keras Model objects.
"""
for key, value in checkpoint_restore_map.items():
if not (isinstance(key, str) and isinstance(value, tf.Module)):
raise TypeError(RESTORE_MAP_ERROR_TEMPLATE.format(
key.__class__.__name__, value.__class__.__name__))
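A minimal sketch of a map that passes the check above; the key name and module are hypothetical, not the restore_map of any particular detection model:
# Hypothetical example: keys must be strings and values must be tf.Module
# instances (Keras models and layers are tf.Modules).
feature_extractor = tf.keras.Sequential(
    [tf.keras.layers.Conv2D(8, 3, input_shape=(64, 64, 3))])
restore_map = {'feature_extractor': feature_extractor}
validate_tf_v2_checkpoint_restore_map(restore_map)  # No error raised.
ckpt = tf.train.Checkpoint(**restore_map)
# ckpt.restore(checkpoint_path) would then restore the mapped objects.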
def load_fine_tune_checkpoint(
model, checkpoint_path, checkpoint_type,
model, checkpoint_path, checkpoint_type, checkpoint_version,
load_all_detection_checkpoint_vars, input_dataset,
unpad_groundtruth_tensors):
"""Load a fine tuning classification or detection checkpoint.
......@@ -260,6 +299,8 @@ def load_fine_tune_checkpoint(
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Valid values: `detection`, `classification`.
checkpoint_version: train_pb2.CheckpointVersion.V1 or V2 enum indicating
whether to load checkpoints in V1 style or V2 style.
load_all_detection_checkpoint_vars: whether to load all variables (when
`fine_tune_checkpoint_type` is `detection`). If False, only variables
within the feature extractor scopes are included. Default False.
......@@ -269,6 +310,7 @@ def load_fine_tune_checkpoint(
"""
features, labels = iter(input_dataset).next()
@tf.function
def _dummy_computation_fn(features, labels):
model._is_training = False # pylint: disable=protected-access
tf.keras.backend.set_learning_phase(False)
......@@ -282,11 +324,13 @@ def load_fine_tune_checkpoint(
labels)
strategy = tf.compat.v2.distribute.get_strategy()
strategy.experimental_run_v2(
strategy.run(
_dummy_computation_fn, args=(
features,
labels,
))
if checkpoint_version == train_pb2.CheckpointVersion.V1:
var_map = model.restore_map(
fine_tune_checkpoint_type=checkpoint_type,
load_all_detection_checkpoint_vars=(
......@@ -297,6 +341,48 @@ def load_fine_tune_checkpoint(
include_global_step=False)
tf.train.init_from_checkpoint(checkpoint_path,
available_var_map)
elif checkpoint_version == train_pb2.CheckpointVersion.V2:
restore_map = model.restore_map(
fine_tune_checkpoint_type=checkpoint_type,
load_all_detection_checkpoint_vars=(
load_all_detection_checkpoint_vars))
validate_tf_v2_checkpoint_restore_map(restore_map)
ckpt = tf.train.Checkpoint(**restore_map)
ckpt.restore(checkpoint_path).assert_existing_objects_matched()
def _get_filepath(strategy, filepath):
"""Get appropriate filepath for worker.
Args:
strategy: A tf.distribute.Strategy object.
filepath: A path to where the Checkpoint object is stored.
Returns:
A temporary filepath for non-chief workers to use or the original filepath
for the chief.
"""
if strategy.extended.should_checkpoint:
return filepath
else:
# TODO(vighneshb) Replace with the public API when TF exposes it.
task_id = strategy.extended._task_id # pylint:disable=protected-access
return os.path.join(filepath, 'temp_worker_{:03d}'.format(task_id))
def _clean_temporary_directories(strategy, filepath):
"""Temporary directory clean up for MultiWorker Mirrored Strategy.
This is needed for all non-chief workers.
Args:
strategy: A tf.distribute.Strategy object.
filepath: The filepath for the temporary directory.
"""
if not strategy.extended.should_checkpoint:
if tf.io.gfile.exists(filepath) and tf.io.gfile.isdir(filepath):
tf.io.gfile.rmtree(filepath)
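For orientation, a minimal sketch of how these two helpers are meant to bracket checkpoint writes under a multi-worker strategy; `ckpt` and `model_dir` are assumed to exist, mirroring how train_loop below uses them:
# Sketch only: `ckpt` is assumed to be a tf.train.Checkpoint created elsewhere
# and `model_dir` a hypothetical output directory.
strategy = tf.compat.v2.distribute.get_strategy()
manager_dir = _get_filepath(strategy, model_dir)  # Temp dir for non-chief workers.
manager = tf.compat.v2.train.CheckpointManager(ckpt, manager_dir, max_to_keep=7)
manager.save()
# Non-chief workers remove the temporary directories they were forced to write.
_clean_temporary_directories(strategy, manager_dir)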
def train_loop(
......@@ -308,7 +394,9 @@ def train_loop(
use_tpu=False,
save_final_config=False,
export_to_tpu=None,
checkpoint_every_n=1000, **kwargs):
checkpoint_every_n=1000,
checkpoint_max_to_keep=7,
**kwargs):
"""Trains a model using eager + functions.
This method:
......@@ -340,6 +428,8 @@ def train_loop(
hparams too.
checkpoint_every_n:
Checkpoint every n training steps.
checkpoint_max_to_keep:
int, the number of most recent checkpoints to keep in the model directory.
**kwargs: Additional keyword arguments for configuration override.
"""
## Parse the configs
......@@ -400,6 +490,7 @@ def train_loop(
else:
train_config.fine_tune_checkpoint_type = 'classification'
fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version
# Write the as-run pipeline config to disk.
if save_final_config:
......@@ -412,18 +503,25 @@ def train_loop(
detection_model = model_builder.build(
model_config=model_config, is_training=True)
def train_dataset_fn(input_context):
"""Callable to create train input."""
# Create the inputs.
train_input = inputs.train_input(
train_config=train_config,
train_input_config=train_input_config,
model_config=model_config,
model=detection_model)
model=detection_model,
input_context=input_context)
train_input = train_input.repeat()
return train_input
train_input = strategy.experimental_distribute_dataset(
train_input.repeat())
train_input = strategy.experimental_distribute_datasets_from_function(
train_dataset_fn)
global_step = tf.compat.v2.Variable(
0, trainable=False, dtype=tf.compat.v2.dtypes.int64, name='global_step')
global_step = tf.Variable(
0, trainable=False, dtype=tf.compat.v2.dtypes.int64, name='global_step',
aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
optimizer, (learning_rate,) = optimizer_builder.build(
train_config.optimizer, global_step=global_step)
......@@ -433,25 +531,51 @@ def train_loop(
learning_rate_fn = lambda: learning_rate
## Train the model
summary_writer = tf.compat.v2.summary.create_file_writer(model_dir + '/train')
# Get the appropriate filepath (temporary or not) based on whether the worker
# is the chief.
summary_writer_filepath = _get_filepath(strategy,
os.path.join(model_dir, 'train'))
summary_writer = tf.compat.v2.summary.create_file_writer(
summary_writer_filepath)
if use_tpu:
num_steps_per_iteration = 100
else:
# TODO(b/135933080) Explore setting to 100 when GPU performance issues
# are fixed.
num_steps_per_iteration = 1
with summary_writer.as_default():
with strategy.scope():
with tf.compat.v2.summary.record_if(
lambda: global_step % num_steps_per_iteration == 0):
# Load a fine-tuning checkpoint.
if fine_tune_checkpoint_path:
load_fine_tune_checkpoint(detection_model, fine_tune_checkpoint_path,
fine_tune_checkpoint_type,
fine_tune_checkpoint_version,
load_all_detection_checkpoint_vars,
train_input,
unpad_groundtruth_tensors)
ckpt = tf.compat.v2.train.Checkpoint(
step=global_step, model=detection_model, optimizer=optimizer)
manager_dir = _get_filepath(strategy, model_dir)
if not strategy.extended.should_checkpoint:
checkpoint_max_to_keep = 1
manager = tf.compat.v2.train.CheckpointManager(
ckpt, model_dir, max_to_keep=7)
ckpt.restore(manager.latest_checkpoint)
ckpt, manager_dir, max_to_keep=checkpoint_max_to_keep)
# We use the following instead of manager.latest_checkpoint because
# manager_dir does not point to the model directory when we are running
# in a worker.
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
ckpt.restore(latest_checkpoint)
def train_step_fn(features, labels):
return eager_train_step(
"""Single train step."""
loss = eager_train_step(
detection_model,
features,
labels,
......@@ -462,40 +586,62 @@ def train_loop(
clip_gradients_value=clip_gradients_value,
global_step=global_step,
num_replicas=strategy.num_replicas_in_sync)
global_step.assign_add(1)
return loss
@tf.function
def _dist_train_step(data_iterator):
"""A distributed train step."""
def _sample_and_train(strategy, train_step_fn, data_iterator):
features, labels = data_iterator.next()
per_replica_losses = strategy.experimental_run_v2(
train_step_fn, args=(
features,
labels,
))
per_replica_losses = strategy.run(
train_step_fn, args=(features, labels))
# TODO(anjalisridhar): explore if it is safe to remove the
## num_replicas scaling of the loss and switch this to a ReduceOp.Mean
mean_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
return mean_loss
return strategy.reduce(tf.distribute.ReduceOp.SUM,
per_replica_losses, axis=None)
@tf.function
def _dist_train_step(data_iterator):
"""A distributed train step."""
if num_steps_per_iteration > 1:
for _ in tf.range(num_steps_per_iteration - 1):
_sample_and_train(strategy, train_step_fn, data_iterator)
return _sample_and_train(strategy, train_step_fn, data_iterator)
train_input_iter = iter(train_input)
for _ in range(train_steps - global_step.value()):
start_time = time.time()
checkpointed_step = int(global_step.value())
logged_step = global_step.value()
last_step_time = time.time()
for _ in range(global_step.value(), train_steps,
num_steps_per_iteration):
loss = _dist_train_step(train_input_iter)
global_step.assign_add(1)
end_time = time.time()
time_taken = time.time() - last_step_time
last_step_time = time.time()
tf.compat.v2.summary.scalar(
'steps_per_sec', 1.0 / (end_time - start_time),
'steps_per_sec', num_steps_per_iteration * 1.0 / time_taken,
step=global_step)
if (int(global_step.value()) % 100) == 0:
if global_step.value() - logged_step >= 100:
tf.logging.info(
'Step {} time taken {:.3f}s loss={:.3f}'.format(
global_step.value(), end_time - start_time, loss))
'Step {} per-step time {:.3f}s loss={:.3f}'.format(
global_step.value(), time_taken / num_steps_per_iteration,
loss))
logged_step = global_step.value()
if int(global_step.value()) % checkpoint_every_n == 0:
if ((int(global_step.value()) - checkpointed_step) >=
checkpoint_every_n):
manager.save()
checkpointed_step = int(global_step.value())
# Remove the checkpoint directories of the non-chief workers that
# MultiWorkerMirroredStrategy forces us to save during sync distributed
# training.
_clean_temporary_directories(strategy, manager_dir)
_clean_temporary_directories(strategy, summary_writer_filepath)
def eager_eval_loop(
......@@ -509,7 +655,7 @@ def eager_eval_loop(
This method will compute the evaluation metrics specified in the configs on
the entire evaluation dataset, then return the metrics. It will also log
the metrics to TensorBoard
the metrics to TensorBoard.
Args:
detection_model: A DetectionModel (based on Keras) to evaluate.
......@@ -578,7 +724,7 @@ def eager_eval_loop(
# TODO(kaftan): Depending on how postprocessing will work for TPUS w/
## TPUStrategy, may be good to move wrapping to a utility method
if use_tpu and postprocess_on_cpu:
detections = tf.contrib.tpu.outside_compilation(
detections = contrib_tpu.outside_compilation(
postprocess_wrapper,
(prediction_dict, features[fields.InputDataFields.true_image_shape]))
else:
......@@ -621,6 +767,36 @@ def eager_eval_loop(
if i % 100 == 0:
tf.logging.info('Finished eval step %d', i)
use_original_images = fields.InputDataFields.original_image in features
if not use_tpu and use_original_images:
# Summary for input images.
tf.compat.v2.summary.image(
name='eval_input_images',
step=global_step,
data=eval_dict['original_image'],
max_outputs=1)
# Summary for prediction/groundtruth side-by-side images.
if class_agnostic:
category_index = label_map_util.create_class_agnostic_category_index()
else:
category_index = label_map_util.create_category_index_from_labelmap(
eval_input_config.label_map_path)
keypoint_edges = [
(kp.start, kp.end) for kp in eval_config.keypoint_edge]
sbys_image_list = vutils.draw_side_by_side_evaluation_image(
eval_dict,
category_index=category_index,
max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
min_score_thresh=eval_config.min_score_threshold,
use_normalized_coordinates=False,
keypoint_edges=keypoint_edges or None)
sbys_images = tf.concat(sbys_image_list, axis=0)
tf.compat.v2.summary.image(
name='eval_side_by_side',
step=global_step,
data=sbys_images,
max_outputs=eval_config.num_visualizations)
if evaluators is None:
if class_agnostic:
evaluators = class_agnostic_evaluators
......@@ -633,6 +809,11 @@ def eager_eval_loop(
for loss_key, loss_tensor in iter(losses_dict.items()):
if loss_key not in loss_metrics:
loss_metrics[loss_key] = tf.keras.metrics.Mean()
# Skip losses with values less than or equal to 0.0 when computing the
# average loss, since such values do not reflect normal loss behavior and
# would skew the average.
if loss_tensor <= 0.0:
continue
loss_metrics[loss_key].update_state(loss_tensor)
eval_metrics = {}
......@@ -663,6 +844,7 @@ def eval_continuously(
model_dir=None,
checkpoint_dir=None,
wait_interval=180,
timeout=3600,
**kwargs):
"""Run continuous evaluation of a detection model eagerly.
......@@ -691,13 +873,13 @@ def eval_continuously(
`export_savedmodel()` exports a metagraph for serving on TPU besides the
one on CPU. If export_to_tpu is not provided, we will look for it in
hparams too.
model_dir:
Directory to output resulting evaluation summaries to.
checkpoint_dir:
Directory that contains the training checkpoints.
wait_interval:
Terminate evaluation in no new checkpoints arrive within this wait
interval (in seconds).
model_dir: Directory to output resulting evaluation summaries to.
checkpoint_dir: Directory that contains the training checkpoints.
wait_interval: The minimum number of seconds to wait before checking for a
new checkpoint.
timeout: The maximum number of seconds to wait for a checkpoint. Execution
will terminate if no new checkpoints are found after these many seconds.
**kwargs: Additional keyword arguments for configuration override.
"""
get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
......@@ -759,36 +941,12 @@ def eval_continuously(
global_step = tf.compat.v2.Variable(
0, trainable=False, dtype=tf.compat.v2.dtypes.int64)
prev_checkpoint = None
waiting = False
while True:
for latest_checkpoint in tf.train.checkpoints_iterator(
checkpoint_dir, timeout=timeout, min_interval_secs=wait_interval):
ckpt = tf.compat.v2.train.Checkpoint(
step=global_step, model=detection_model)
manager = tf.compat.v2.train.CheckpointManager(
ckpt, checkpoint_dir, max_to_keep=3)
latest_checkpoint = manager.latest_checkpoint
if prev_checkpoint == latest_checkpoint:
if prev_checkpoint is None:
tf.logging.info('No checkpoints found yet. Trying again in %s seconds.'
% wait_interval)
time.sleep(wait_interval)
else:
if waiting:
tf.logging.info('Terminating eval after %s seconds of no new '
'checkpoints.' % wait_interval)
break
else:
tf.logging.info('No new checkpoint found. Will try again '
'in %s seconds and terminate if no checkpoint '
'appears.' % wait_interval)
waiting = True
time.sleep(wait_interval)
else:
tf.logging.info('New checkpoint found. Starting evaluation.')
waiting = False
prev_checkpoint = latest_checkpoint
ckpt.restore(latest_checkpoint)
ckpt.restore(latest_checkpoint).expect_partial()
for eval_name, eval_input in eval_inputs:
summary_writer = tf.compat.v2.summary.create_file_writer(
......
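The rewritten loop above delegates the waiting logic to tf.train.checkpoints_iterator, with wait_interval feeding min_interval_secs and timeout bounding the total wait. A small standalone sketch of that iterator, using a placeholder directory:

import tensorflow as tf

# Yields each new checkpoint path as it appears, waiting at least
# min_interval_secs between checks and returning once `timeout` seconds
# pass with no new checkpoint.
for ckpt_path in tf.train.checkpoints_iterator(
    '/tmp/train_dir', min_interval_secs=1, timeout=10):
  print('found checkpoint:', ckpt_path)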
......@@ -19,13 +19,24 @@ from __future__ import division
from __future__ import print_function
import os
import tempfile
import numpy as np
import six
import tensorflow as tf
from object_detection import inputs
from object_detection import model_hparams
from object_detection import model_lib_v2
from object_detection.builders import model_builder
from object_detection.core import model
from object_detection.protos import train_pb2
from object_detection.utils import config_util
if six.PY2:
import mock # pylint: disable=g-importing-member,g-import-not-at-top
else:
from unittest import mock # pylint: disable=g-importing-member,g-import-not-at-top
# Model for test. Current options are:
# 'ssd_mobilenet_v2_pets_keras'
......@@ -61,19 +72,10 @@ def _get_config_kwarg_overrides():
}
def _get_configs_for_model(model_name):
"""Returns configurations for model."""
filename = get_pipeline_config_path(model_name)
configs = config_util.get_configs_from_pipeline_file(filename)
configs = config_util.merge_external_params_with_configs(
configs, kwargs_dict=_get_config_kwarg_overrides())
return configs
class ModelLibTest(tf.test.TestCase):
@classmethod
def setUpClass(cls):
def setUpClass(cls): # pylint:disable=g-missing-super-call
tf.keras.backend.clear_session()
def test_train_loop_then_eval_loop(self):
......@@ -99,6 +101,119 @@ class ModelLibTest(tf.test.TestCase):
model_dir=model_dir,
checkpoint_dir=model_dir,
train_steps=train_steps,
wait_interval=10,
wait_interval=1,
timeout=10,
**config_kwarg_overrides)
class SimpleModel(model.DetectionModel):
"""A model with a single weight vector."""
def __init__(self, num_classes=1):
super(SimpleModel, self).__init__(num_classes)
self.weight = tf.keras.backend.variable(np.ones(10), name='weight')
def postprocess(self, prediction_dict, true_image_shapes):
return {}
def updates(self):
return []
def restore_map(self, *args, **kwargs):
return {'model': self}
def preprocess(self, _):
return tf.zeros((1, 128, 128, 3)), tf.constant([[128, 128, 3]])
def provide_groundtruth(self, *args, **kwargs):
pass
def predict(self, pred_inputs, true_image_shapes):
return {'prediction':
tf.abs(tf.reduce_sum(self.weight) * tf.reduce_sum(pred_inputs))}
def loss(self, prediction_dict, _):
return {'loss': tf.reduce_sum(prediction_dict['prediction'])}
def regularization_losses(self):
return []
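A short sketch (not part of the tests) of why SimpleModel makes weight restoration observable: its prediction and loss are proportional to the sum of the weight vector, so restoring a weight of 42s yields a loss 42x larger than the default ones. The input tensor below is made up.

model = SimpleModel()
fake_inputs = tf.ones((1, 128, 128, 3))
prediction_dict = model.predict(fake_inputs, true_image_shapes=None)
loss = model.loss(prediction_dict, None)['loss']  # == sum(weight) * sum(fake_inputs)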
class ModelCheckpointTest(tf.test.TestCase):
"""Test for model checkpoint related functionality."""
def test_checkpoint_max_to_keep(self):
"""Test that only the most recent checkpoints are kept."""
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = SimpleModel()
hparams = model_hparams.create_hparams(
hparams_overrides='load_pretrained=false')
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
config_kwarg_overrides = _get_config_kwarg_overrides()
model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
model_lib_v2.train_loop(
hparams, pipeline_config_path, model_dir=model_dir,
train_steps=20, checkpoint_every_n=2, checkpoint_max_to_keep=3,
**config_kwarg_overrides
)
ckpt_files = tf.io.gfile.glob(os.path.join(model_dir, 'ckpt-*.index'))
self.assertEqual(len(ckpt_files), 3,
'{} not of length 3.'.format(ckpt_files))
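The max_to_keep behaviour asserted above comes from tf.train.CheckpointManager; a small sketch of that mechanism in isolation, with a placeholder directory:

import tensorflow as tf

ckpt = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64))
manager = tf.train.CheckpointManager(ckpt, '/tmp/ckpt_demo', max_to_keep=3)
for _ in range(10):
  ckpt.step.assign_add(1)
  manager.save()
# Only the three most recent checkpoints are retained.
print(manager.checkpoints)  # e.g. paths ending in ckpt-8, ckpt-9, ckpt-10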
class IncompatibleModel(SimpleModel):
def restore_map(self, *args, **kwargs):
return {'weight': self.weight}
class CheckpointV2Test(tf.test.TestCase):
def setUp(self):
super(CheckpointV2Test, self).setUp()
self._model = SimpleModel()
tf.keras.backend.set_value(self._model.weight, np.ones(10) * 42)
ckpt = tf.train.Checkpoint(model=self._model)
self._test_dir = tf.test.get_temp_dir()
self._ckpt_path = ckpt.save(os.path.join(self._test_dir, 'ckpt'))
tf.keras.backend.set_value(self._model.weight, np.ones(10))
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
configs = config_util.merge_external_params_with_configs(
configs, kwargs_dict=_get_config_kwarg_overrides())
self._train_input_fn = inputs.create_train_input_fn(
configs['train_config'],
configs['train_input_config'],
configs['model'])
def test_restore_v2(self):
"""Test that restoring a v2 style checkpoint works."""
model_lib_v2.load_fine_tune_checkpoint(
self._model, self._ckpt_path, checkpoint_type='',
checkpoint_version=train_pb2.CheckpointVersion.V2,
load_all_detection_checkpoint_vars=True,
input_dataset=self._train_input_fn(),
unpad_groundtruth_tensors=True)
np.testing.assert_allclose(self._model.weight.numpy(), 42)
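A minimal standalone sketch (not part of the tests) of the V2 checkpoint round trip that setUp and this test rely on; the save path is a placeholder.

import numpy as np
import tensorflow as tf

var = tf.Variable(np.ones(10) * 42.0)
save_path = tf.train.Checkpoint(weight=var).save('/tmp/ckpt_roundtrip/ckpt')
var.assign(np.ones(10))                             # overwrite the saved value
tf.train.Checkpoint(weight=var).restore(save_path)  # eager restore happens immediately
np.testing.assert_allclose(var.numpy(), 42.0)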
def test_restore_map_incompatible_error(self):
"""Test that restoring an incompatible restore map causes an error."""
with self.assertRaisesRegex(TypeError,
r'.*received a \(str -> ResourceVariable\).*'):
model_lib_v2.load_fine_tune_checkpoint(
IncompatibleModel(), self._ckpt_path, checkpoint_type='',
checkpoint_version=train_pb2.CheckpointVersion.V2,
load_all_detection_checkpoint_vars=True,
input_dataset=self._train_input_fn(),
unpad_groundtruth_tensors=True)
......@@ -29,6 +29,15 @@ import tensorflow as tf
from object_detection import model_hparams
from object_detection import model_lib
# pylint: disable=g-import-not-at-top
try:
from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import tpu as contrib_tpu
except ImportError:
# TF 2.0 doesn't ship with contrib.
pass
# pylint: enable=g-import-not-at-top
tf.flags.DEFINE_bool('use_tpu', True, 'Use TPUs rather than plain CPUs')
# Cloud TPU Cluster Resolvers
......@@ -85,17 +94,15 @@ def main(unused_argv):
flags.mark_flag_as_required('pipeline_config_path')
tpu_cluster_resolver = (
tf.contrib.cluster_resolver.TPUClusterResolver(
tpu=[FLAGS.tpu_name],
zone=FLAGS.tpu_zone,
project=FLAGS.gcp_project))
contrib_cluster_resolver.TPUClusterResolver(
tpu=[FLAGS.tpu_name], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
tpu_grpc_url = tpu_cluster_resolver.get_master()
config = tf.contrib.tpu.RunConfig(
config = contrib_tpu.RunConfig(
master=tpu_grpc_url,
evaluation_master=tpu_grpc_url,
model_dir=FLAGS.model_dir,
tpu_config=tf.contrib.tpu.TPUConfig(
tpu_config=contrib_tpu.TPUConfig(
iterations_per_loop=FLAGS.iterations_per_loop,
num_shards=FLAGS.num_shards))
......
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -163,4 +164,4 @@ class EmbeddedSSDMobileNetV1FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
insert_1x1_conv=True,
image_features=image_features)
return feature_maps.values()
return list(feature_maps.values())
......@@ -58,6 +58,7 @@ class FasterRCNNInceptionResnetV2KerasFeatureExtractor(
super(FasterRCNNInceptionResnetV2KerasFeatureExtractor, self).__init__(
is_training, first_stage_features_stride, batch_norm_trainable,
weight_decay)
self._variable_dict = {}
def preprocess(self, resized_inputs):
"""Faster R-CNN with Inception Resnet v2 preprocessing.
......@@ -105,9 +106,12 @@ class FasterRCNNInceptionResnetV2KerasFeatureExtractor(
include_top=False)
proposal_features = model.get_layer(
name='block17_20_ac').output
return tf.keras.Model(
keras_model = tf.keras.Model(
inputs=model.inputs,
outputs=proposal_features)
for variable in keras_model.variables:
self._variable_dict[variable.name[:-2]] = variable
return keras_model
def get_box_classifier_feature_extractor_model(self, name=None):
"""Returns a model that extracts second stage box classifier features.
......@@ -143,10 +147,13 @@ class FasterRCNNInceptionResnetV2KerasFeatureExtractor(
proposal_classifier_features = model.get_layer(
name='conv_7b_ac').output
return model_util.extract_submodel(
keras_model = model_util.extract_submodel(
model=model,
inputs=proposal_feature_maps,
outputs=proposal_classifier_features)
for variable in keras_model.variables:
self._variable_dict[variable.name[:-2]] = variable
return keras_model
def restore_from_classification_checkpoint_fn(
self,
......@@ -1071,9 +1078,16 @@ class FasterRCNNInceptionResnetV2KerasFeatureExtractor(
}
variables_to_restore = {}
if tf.executing_eagerly():
for key in self._variable_dict:
# variable.name carries a ":0" suffix that checkpoint names do not, so the
# keys of self._variable_dict were stored with that suffix stripped.
var_name = keras_to_slim_name_mapping.get(key)
if var_name:
variables_to_restore[var_name] = self._variable_dict[key]
else:
for variable in variables_helper.get_global_variables_safely():
var_name = keras_to_slim_name_mapping.get(variable.op.name)
if var_name:
variables_to_restore[var_name] = variable
return variables_to_restore
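The ':0' handling above hinges on how TensorFlow names variables; a tiny illustration with a made-up variable name:

import tensorflow as tf

v = tf.Variable(tf.zeros([3]), name='conv1/kernel')
print(v.name)       # 'conv1/kernel:0' -- variable names carry an output suffix
print(v.name[:-2])  # 'conv1/kernel'   -- matches the name stored in a checkpoint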
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -20,6 +21,11 @@ Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le
https://arxiv.org/abs/1707.07012
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import slim as contrib_slim
......
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -18,6 +19,11 @@
Based on PNASNet model: https://arxiv.org/abs/1712.00559
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import slim as contrib_slim
......
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -23,8 +24,13 @@ Object detection feature extractors usually are built by stacking two components
Feature map generators build on the base feature extractors and produce a list
of final feature maps.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
from six.moves import range
from six.moves import zip
import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim
from object_detection.utils import ops
......@@ -221,8 +227,8 @@ class KerasMultiResolutionFeatureMaps(tf.keras.Model):
else:
if insert_1x1_conv:
layer_name = '{}_1_Conv2d_{}_1x1_{}'.format(
base_from_layer, index, depth_fn(layer_depth / 2))
net.append(tf.keras.layers.Conv2D(depth_fn(layer_depth / 2),
base_from_layer, index, depth_fn(layer_depth // 2))
net.append(tf.keras.layers.Conv2D(depth_fn(layer_depth // 2),
[1, 1],
padding='SAME',
strides=1,
......@@ -431,10 +437,10 @@ def multi_resolution_feature_maps(feature_map_layout, depth_multiplier,
intermediate_layer = pre_layer
if insert_1x1_conv:
layer_name = '{}_1_Conv2d_{}_1x1_{}'.format(
base_from_layer, index, depth_fn(layer_depth / 2))
base_from_layer, index, depth_fn(layer_depth // 2))
intermediate_layer = slim.conv2d(
pre_layer,
depth_fn(layer_depth / 2), [1, 1],
depth_fn(layer_depth // 2), [1, 1],
padding='SAME',
stride=1,
scope=layer_name)
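The switch to // matters under Python 3, where / is true division and returns a float even when it divides evenly; floor division keeps the computed depth an integer before it is formatted into the layer name or passed to the conv op. A tiny illustration with made-up values:

layer_depth = 256
print(layer_depth / 2)   # 128.0 -- a float under Python 3
print(layer_depth // 2)  # 128   -- an integer
print('{}_1_Conv2d_{}_1x1_{}'.format('Mixed_7c', 2, layer_depth // 2))
# 'Mixed_7c_1_Conv2d_2_1x1_128'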
......@@ -547,7 +553,7 @@ class KerasFpnTopDownFeatureMaps(tf.keras.Model):
self.top_layers.append(tf.keras.layers.Lambda(
clip_by_value, name='clip_by_value'))
for level in reversed(range(num_levels - 1)):
for level in reversed(list(range(num_levels - 1))):
# to generate residual from image features
residual_net = []
# to preprocess top_down (the image feature map from last layer)
......@@ -636,7 +642,7 @@ class KerasFpnTopDownFeatureMaps(tf.keras.Model):
output_feature_map_keys.append('top_down_%s' % image_features[-1][0])
num_levels = len(image_features)
for index, level in enumerate(reversed(range(num_levels - 1))):
for index, level in enumerate(reversed(list(range(num_levels - 1)))):
residual = image_features[level][1]
top_down = output_feature_maps_list[-1]
for layer in self.residual_blocks[index]:
......@@ -703,7 +709,7 @@ def fpn_top_down_feature_maps(image_features,
output_feature_map_keys.append(
'top_down_%s' % image_features[-1][0])
for level in reversed(range(num_levels - 1)):
for level in reversed(list(range(num_levels - 1))):
if use_native_resize_op:
with tf.name_scope('nearest_neighbor_upsampling'):
top_down_shape = shape_utils.combined_static_and_dynamic_shape(
......@@ -731,10 +737,11 @@ def fpn_top_down_feature_maps(image_features,
conv_op = functools.partial(slim.separable_conv2d, depth_multiplier=1)
else:
conv_op = slim.conv2d
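# Smooth a padded copy (pre_output) so that top_down itself stays unpadded
# for the next level of the top-down pathway.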
pre_output = top_down
if use_explicit_padding:
top_down = ops.fixed_padding(top_down, kernel_size)
pre_output = ops.fixed_padding(pre_output, kernel_size)
output_feature_maps_list.append(conv_op(
top_down,
pre_output,
depth, [kernel_size, kernel_size],
scope='smoothing_%d' % (level + 1)))
output_feature_map_keys.append('top_down_%s' % image_features[level][0])
......@@ -778,7 +785,7 @@ def pooling_pyramid_feature_maps(base_feature_map_depth, num_layers,
"""
if len(image_features) != 1:
raise ValueError('image_features should be a dictionary of length 1.')
image_features = image_features[image_features.keys()[0]]
image_features = image_features[list(image_features.keys())[0]]
feature_map_keys = []
feature_maps = []
......
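The list(...) wrappers added throughout this file follow the standard Python 2/3 compatibility pattern: under Python 3, dict.keys() returns a view object that cannot be indexed, so it is materialized into a list first (the list(range(...)) changes above apply the same conservative treatment). A tiny illustration with made-up values:

features = {'layer_4': 'tensor_a'}
# features[features.keys()[0]]          # TypeError under Python 3
first_key = list(features.keys())[0]    # works under both Python 2 and 3
print(features[first_key])              # 'tensor_a'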
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......