Unverified commit 420a7253, authored by pkulzc, committed by GitHub

Refactor tests for Object Detection API. (#8688)

Internal changes

--

PiperOrigin-RevId: 316837667
parent d0ef3913
syntax = "proto2";
package object_detection.protos;
import "object_detection/protos/image_resizer.proto";
import "object_detection/protos/losses.proto";
// Configuration for the CenterNet meta architecture from the "Objects as
// Points" paper [1]
// [1]: https://arxiv.org/abs/1904.07850
message CenterNet {
// Number of classes to predict.
optional int32 num_classes = 1;
// Feature extractor config.
optional CenterNetFeatureExtractor feature_extractor = 2;
// Image resizer for preprocessing the input image.
optional ImageResizer image_resizer = 3;
// Parameters which are related to the object detection task.
message ObjectDetection {
// The original fields are moved to ObjectCenterParams or deleted.
reserved 2, 5, 6, 7;
// Weight of the task loss. The total loss of the model will be the
// summation of task losses weighted by the weights.
optional float task_loss_weight = 1 [default = 1.0];
// Weight for the offset localization loss.
optional float offset_loss_weight = 3 [default = 1.0];
// Weight for the height/width localization loss.
optional float scale_loss_weight = 4 [default = 0.1];
// Localization loss configuration for object scale and offset losses.
optional LocalizationLoss localization_loss = 8;
}
optional ObjectDetection object_detection_task = 4;
// Parameters related to object center prediction. This is required for both
// object detection and keypoint estimation tasks.
message ObjectCenterParams {
// Weight for the object center loss.
optional float object_center_loss_weight = 1 [default = 1.0];
// Classification loss configuration for object center loss.
optional ClassificationLoss classification_loss = 2;
// The initial bias value of the convolution kernel of the class heatmap
// prediction head. -2.19 corresponds to predicting foreground with
// a probability of 0.1. See "Focal Loss for Dense Object Detection"
// at https://arxiv.org/abs/1708.02002.
optional float heatmap_bias_init = 3 [default = -2.19];
// The minimum IOU overlap boxes need to have to not be penalized.
optional float min_box_overlap_iou = 4 [default = 0.7];
// Maximum number of boxes to predict.
optional int32 max_box_predictions = 5 [default = 100];
// If set, loss is only computed for the labeled classes.
optional bool use_labeled_classes = 6 [default = false];
}
optional ObjectCenterParams object_center_params = 5;
// Path of the file that contains the label map along with the keypoint
// information, including the keypoint indices, corresponding labels, and the
// corresponding class. The file should be the same one as used in the input
// pipeline. Note that a plain-text StringIntLabelMap proto is expected in
// this file.
// It is required only if the keypoint estimation task is specified.
optional string keypoint_label_map_path = 6;
// Parameters which are related to the keypoint estimation task.
message KeypointEstimation {
// Name of the task, e.g. "human pose". Note that the task name should be
// unique to each keypoint task.
optional string task_name = 1;
// Weight of the task loss. The total loss of the model will be the
// summation of task losses weighted by the weights.
optional float task_loss_weight = 2 [default = 1.0];
// Loss configuration for the keypoint heatmap, offset, and regression
// losses. Note that the localization loss is used for the offset/regression
// losses and the classification loss is used for the heatmap loss.
optional Loss loss = 3;
// The name of the class that contains the keypoints for this task. This is
// used to retrieve the corresponding keypoint indices from the label map.
// Note that this corresponds to the "name" field, not "display_name".
optional string keypoint_class_name = 4;
// The standard deviation of the Gaussian kernel used to generate the
// keypoint heatmap, in units of output-image pixels. This provides the
// flexibility of using a different Gaussian kernel size for each keypoint
// class. If provided, the values specified here override the keypoint
// standard deviations; otherwise, the default value 5.0 is used.
// TODO(yuhuic): Update the default value once we find the best value.
map<string, float> keypoint_label_to_std = 5;
// Loss weights corresponding to different heads.
optional float keypoint_regression_loss_weight = 6 [default = 1.0];
optional float keypoint_heatmap_loss_weight = 7 [default = 1.0];
optional float keypoint_offset_loss_weight = 8 [default = 1.0];
// The initial bias value of the convolution kernel of the keypoint heatmap
// prediction head. -2.19 corresponds to predicting foreground with
// a probability of 0.1. See "Focal Loss for Dense Object Detection"
// at https://arxiv.org/abs/1708.02002.
optional float heatmap_bias_init = 9 [default = -2.19];
// The heatmap score threshold for a keypoint to become a valid candidate.
optional float keypoint_candidate_score_threshold = 10 [default = 0.1];
// The maximum number of candidates to retrieve for each keypoint.
optional int32 num_candidates_per_keypoint = 11 [default = 100];
// Max pool kernel size to use to pull off peak score locations in a
// neighborhood (independently for each keypoint type).
optional int32 peak_max_pool_kernel_size = 12 [default = 3];
// The default score to use for regressed keypoints that are not
// successfully snapped to a nearby candidate.
optional float unmatched_keypoint_score = 13 [default = 0.1];
// The multiplier to expand the bounding boxes (either the provided boxes or
// those which tightly cover the regressed keypoints). Note that the new
// expanded box for an instance becomes the feasible search window for all
// associated keypoints.
optional float box_scale = 14 [default = 1.2];
// The scale parameter that multiplies the largest dimension of a bounding
// box. The resulting distance becomes a search radius for candidates in the
// vicinity of each regressed keypoint.
optional float candidate_search_scale = 15 [default = 0.3];
// One of ['min_distance', 'score_distance_ratio'] indicating how to select
// the keypoint candidate.
optional string candidate_ranking_mode = 16 [default = "min_distance"];
// The radius (in units of output pixels) around the heatmap peak to assign
// the offset targets. If set to 0, then the offset target will only be
// assigned to the heatmap peak (same behavior as the original paper).
optional int32 offset_peak_radius = 17 [default = 0];
// Indicates whether to assign offsets for each keypoint channel
// separately. If set to False, the output offset target has the shape
// [batch_size, out_height, out_width, 2] (same behavior as the original
// paper). If set to True, the output offset target has the shape [batch_size,
// out_height, out_width, 2 * num_keypoints] (recommended when the
// offset_peak_radius is not zero).
optional bool per_keypoint_offset = 18 [default = false];
}
repeated KeypointEstimation keypoint_estimation_task = 7;
// Parameters which are related to the mask estimation task.
// Note: Currently, CenterNet supports a weak instance segmentation, where
// semantic segmentation masks are estimated, and then cropped based on
// bounding box detections. Therefore, it is possible for the same image
// pixel to be assigned to multiple instances.
message MaskEstimation {
// Weight of the task loss. The total loss of the model will be the
// summation of task losses weighted by the weights.
optional float task_loss_weight = 1 [default = 1.0];
// Classification loss configuration for segmentation loss.
optional ClassificationLoss classification_loss = 2;
// Each instance mask (one per detection) is cropped and resized (bilinear
// resampling) from the predicted segmentation feature map. After
// resampling, the masks are binarized with the provided score threshold.
optional int32 mask_height = 4 [default = 256];
optional int32 mask_width = 5 [default = 256];
optional float score_threshold = 6 [default = 0.5];
// The initial bias value of the convolution kernel of the class heatmap
// prediction head. -2.19 corresponds to predicting foreground with
// a probability of 0.1.
optional float heatmap_bias_init = 3 [default = -2.19];
}
optional MaskEstimation mask_estimation_task = 8;
}
message CenterNetFeatureExtractor {
optional string type = 1;
// Channel means to be subtracted from each image channel. If not specified,
// we use a default value of 0.
repeated float channel_means = 2;
// Channel standard deviations. Each channel will be normalized by dividing
// it by its standard deviation. If not specified, we use a default value
// of 1.
repeated float channel_stds = 3;
// If set, will change channel order to be [blue, green, red]. This can be
// useful to be compatible with some pre-trained feature extractors.
optional bool bgr_ordering = 4 [default = false];
}
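
To make the new schema concrete, here is a minimal sketch (not part of the commit) that parses a hand-written config against it. It assumes the protos have been compiled into Python bindings (`center_net_pb2`), and the `hourglass_104` feature-extractor type is used purely as an illustrative value.

```python
# A minimal sketch, assuming compiled proto bindings (center_net_pb2); field
# names come from the schema above, the extractor type is illustrative.
from google.protobuf import text_format

from object_detection.protos import center_net_pb2

CONFIG_TEXT = """
num_classes: 90
feature_extractor {
  type: 'hourglass_104'
}
object_center_params {
  object_center_loss_weight: 1.0
  min_box_overlap_iou: 0.7
  max_box_predictions: 100
}
object_detection_task {
  task_loss_weight: 1.0
  offset_loss_weight: 1.0
  scale_loss_weight: 0.1
}
"""

config = text_format.Parse(CONFIG_TEXT, center_net_pb2.CenterNet())
print(config.object_center_params.max_box_predictions)  # -> 100
print(config.object_center_params.heatmap_bias_init)    # -> -2.19 (default)
```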
@@ -188,7 +188,7 @@ message Context {
 // Next id: 4
 // The maximum number of contextual features per-image, used for padding
-optional int32 max_num_context_features = 1 [default = 8500];
+optional int32 max_num_context_features = 1 [default = 2000];
 // The bottleneck feature dimension of the attention block.
 optional int32 attention_bottleneck_dimension = 2 [default = 2048];
...
@@ -2,6 +2,7 @@ syntax = "proto2";
 package object_detection.protos;

+import "object_detection/protos/center_net.proto";
 import "object_detection/protos/faster_rcnn.proto";
 import "object_detection/protos/ssd.proto";
@@ -17,6 +18,7 @@ message DetectionModel {
     // value to a function that builds your model.
     ExperimentalModel experimental_model = 3;
+    CenterNet center_net = 4;
   }
 }
...
# SSDLite with MobileDet-GPU feature extractor.
# Reference: Xiong & Liu et al., https://arxiv.org/abs/2004.14525
# Trained on COCO, initialized from scratch.
#
# 5.07B MulAdds, 13.11M Parameters.
# Latencies are 11.0ms (fp32), 3.2ms (fp16) and 2.3ms (int8) on Jetson Xavier,
# optimized using TensorRT 7.1.
# Achieves 28.7 mAP on COCO14 minival dataset.
# Achieves 27.5 mAP on COCO17 val dataset.
#
# This config is TPU compatible.
model {
ssd {
inplace_batchnorm_update: true
freeze_batchnorm: false
num_classes: 90
box_coder {
faster_rcnn_box_coder {
y_scale: 10.0
x_scale: 10.0
height_scale: 5.0
width_scale: 5.0
}
}
matcher {
argmax_matcher {
matched_threshold: 0.5
unmatched_threshold: 0.5
ignore_thresholds: false
negatives_lower_than_unmatched: true
force_match_for_each_row: true
use_matmul_gather: true
}
}
similarity_calculator {
iou_similarity {
}
}
encode_background_as_zeros: true
anchor_generator {
ssd_anchor_generator {
num_layers: 6
min_scale: 0.2
max_scale: 0.95
aspect_ratios: 1.0
aspect_ratios: 2.0
aspect_ratios: 0.5
aspect_ratios: 3.0
aspect_ratios: 0.3333
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
convolutional_box_predictor {
min_depth: 0
max_depth: 0
num_layers_before_predictor: 0
use_dropout: false
dropout_keep_probability: 0.8
kernel_size: 3
use_depthwise: true
box_code_size: 4
apply_sigmoid_to_scores: false
class_prediction_bias_init: -4.6
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
random_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.97,
epsilon: 0.001,
}
}
}
}
feature_extractor {
type: 'ssd_mobiledet_gpu'
min_depth: 16
depth_multiplier: 1.0
use_depthwise: true
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.97,
epsilon: 0.001,
}
}
override_base_feature_extractor_hyperparams: false
}
loss {
classification_loss {
weighted_sigmoid_focal {
alpha: 0.75,
gamma: 2.0
}
}
localization_loss {
weighted_smooth_l1 {
delta: 1.0
}
}
classification_weight: 1.0
localization_weight: 1.0
}
normalize_loss_by_num_matches: true
normalize_loc_loss_by_codesize: true
post_processing {
batch_non_max_suppression {
score_threshold: 1e-8
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 100
use_static_shapes: true
}
score_converter: SIGMOID
}
}
}
train_config: {
batch_size: 512
sync_replicas: true
startup_delay_steps: 0
replicas_to_aggregate: 32
num_steps: 400000
data_augmentation_options {
random_horizontal_flip {
}
}
data_augmentation_options {
ssd_random_crop {
}
}
optimizer {
momentum_optimizer: {
learning_rate: {
cosine_decay_learning_rate {
learning_rate_base: 0.8
total_steps: 400000
warmup_learning_rate: 0.13333
warmup_steps: 2000
}
}
momentum_optimizer_value: 0.9
}
use_moving_average: false
}
max_number_of_boxes: 100
unpad_groundtruth_tensors: false
}
train_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record-?????-of-00100"
}
}
eval_config: {
metrics_set: "coco_detection_metrics"
use_moving_averages: false
num_examples: 8000
}
eval_input_reader: {
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
shuffle: false
num_epochs: 1
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record-?????-of-00010"
}
}
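
For readers who want to adapt this config programmatically rather than by hand, the snippet below is a hedged sketch using `object_detection.utils.config_util`; the config filename is a placeholder for wherever this file is saved.

```python
# A hedged sketch of loading the pipeline file above for programmatic edits;
# the filename below is a placeholder, not a path from the commit.
from object_detection.utils import config_util

configs = config_util.get_configs_from_pipeline_file(
    'ssdlite_mobiledet_gpu_320x320_coco.config')  # placeholder path
print(configs['model'].ssd.num_classes)    # -> 90, per the config above
print(configs['train_config'].batch_size)  # -> 512

# Adjust a value and write a merged pipeline proto back out.
configs['train_config'].batch_size = 64
pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(pipeline_proto, '/tmp/edited_config')
```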
# Citation and license
The images and metadata in this folder come from the Snapshot Serengeti dataset,
and were accessed via [LILA.science](http://lila.science/datasets/snapshot-serengeti).
The images and species-level labels are described in more detail in the
associated manuscript:
```
Swanson AB, Kosmala M, Lintott CJ, Simpson RJ, Smith A, Packer C (2015)
Snapshot Serengeti, high-frequency annotated camera trap images of 40 mammalian
species in an African savanna. Scientific Data 2: 150026.
```
Please cite this manuscript if you use this dataset.
This data set is released under the
[Community Data License Agreement (permissive variant)](https://cdla.io/permissive-1-0/).
{"images": [{"file_name": "models/research/object_detection/test_images/snapshot_serengeti/S1_E03_R3_PICT0038.jpeg", "frame_num": 0, "seq_num_frames": 2, "id": "S1/E03/E03_R3/S1_E03_R3_PICT0038", "height": 1536, "season": "S1", "date_captured": "2010-08-07 01:04:14", "width": 2048, "seq_id": "ASG0003041", "location": "E03"}, {"file_name": "models/research/object_detection/test_images/snapshot_serengeti/S1_E03_R3_PICT0039.jpeg", "frame_num": 1, "seq_num_frames": 2, "id": "S1/E03/E03_R3/S1_E03_R3_PICT0039", "height": 1536, "season": "S1", "date_captured": "2010-08-07 01:04:14", "width": 2048, "seq_id": "ASG0003041", "location": "E03"}, {"file_name": "models/research/object_detection/test_images/snapshot_serengeti/S1_E03_R3_PICT0040.jpeg", "frame_num": 0, "seq_num_frames": 2, "id": "S1/E03/E03_R3/S1_E03_R3_PICT0040", "height": 1536, "season": "S1", "date_captured": "2010-08-07 02:53:46", "width": 2048, "seq_id": "ASG0003042", "location": "E03"}, {"file_name": "models/research/object_detection/test_images/snapshot_serengeti/S1_E03_R3_PICT0041.jpeg", "frame_num": 1, "seq_num_frames": 2, "id": "S1/E03/E03_R3/S1_E03_R3_PICT0041", "height": 1536, "season": "S1", "date_captured": "2010-08-07 02:53:46", "width": 2048, "seq_id": "ASG0003042", "location": "E03"}], "categories": [{"name": "empty", "id": 0}, {"name": "human", "id": 1}, {"name": "gazelleGrants", "id": 2}, {"name": "reedbuck", "id": 3}, {"name": "dikDik", "id": 4}, {"name": "zebra", "id": 5}, {"name": "porcupine", "id": 6}, {"name": "gazelleThomsons", "id": 7}, {"name": "hyenaSpotted", "id": 8}, {"name": "warthog", "id": 9}, {"name": "impala", "id": 10}, {"name": "elephant", "id": 11}, {"name": "giraffe", "id": 12}, {"name": "mongoose", "id": 13}, {"name": "buffalo", "id": 14}, {"name": "hartebeest", "id": 15}, {"name": "guineaFowl", "id": 16}, {"name": "wildebeest", "id": 17}, {"name": "leopard", "id": 18}, {"name": "ostrich", "id": 19}, {"name": "lionFemale", "id": 20}, {"name": "koriBustard", "id": 21}, {"name": "otherBird", "id": 22}, {"name": "batEaredFox", "id": 23}, {"name": "bushbuck", "id": 24}, {"name": "jackal", "id": 25}, {"name": "cheetah", "id": 26}, {"name": "eland", "id": 27}, {"name": "aardwolf", "id": 28}, {"name": "hippopotamus", "id": 29}, {"name": "hyenaStriped", "id": 30}, {"name": "aardvark", "id": 31}, {"name": "hare", "id": 32}, {"name": "baboon", "id": 33}, {"name": "vervetMonkey", "id": 34}, {"name": "waterbuck", "id": 35}, {"name": "secretaryBird", "id": 36}, {"name": "serval", "id": 37}, {"name": "lionMale", "id": 38}, {"name": "topi", "id": 39}, {"name": "honeyBadger", "id": 40}, {"name": "rodents", "id": 41}, {"name": "wildcat", "id": 42}, {"name": "civet", "id": 43}, {"name": "genet", "id": 44}, {"name": "caracal", "id": 45}, {"name": "rhinoceros", "id": 46}, {"name": "reptiles", "id": 47}, {"name": "zorilla", "id": 48}], "annotations": [{"category_id": 29, "image_id": "S1/E03/E03_R3/S1_E03_R3_PICT0038", "bbox": [614.9233639240294, 476.2385201454182, 685.5741333961523, 374.18740868568574], "id": "0154T1541168895361"}, {"category_id": 29, "image_id": "S1/E03/E03_R3/S1_E03_R3_PICT0039", "bbox": [382.03749418258434, 471.005129814144, 756.2249028682752, 397.73766517639683], "id": "Lxtry1541168934504"}, {"category_id": 29, "image_id": "S1/E03/E03_R3/S1_E03_R3_PICT0040", "bbox": [786.9475708007834, 461.0229187011687, 749.0524291992166, 385.0301413536], "id": "Xmyih1541168739115"}, {"category_id": 29, "image_id": "S1/E03/E03_R3/S1_E03_R3_PICT0041", "bbox": [573.8866577148518, 453.0573425292903, 845.0, 
398.9770812988263], "id": "ZllAa1541168769217"}]}
\ No newline at end of file
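
The file above is a COCO-style camera-trap annotation index. Below is a quick consumption sketch (not from the commit); the filename is a placeholder, and the `bbox` fields appear to follow the COCO `[x, y, width, height]` convention in absolute pixels.

```python
# Sketch only: 'serengeti_demo.json' is a placeholder for wherever the JSON
# above is saved; bbox format is assumed to be COCO [x, y, width, height].
import json
from collections import defaultdict

with open('serengeti_demo.json') as f:
    data = json.load(f)

id_to_name = {c['id']: c['name'] for c in data['categories']}
boxes_by_image = defaultdict(list)
for ann in data['annotations']:
    boxes_by_image[ann['image_id']].append(
        (id_to_name[ann['category_id']], ann['bbox']))

for image in data['images']:
    print(image['file_name'], boxes_by_image[image['id']])
# All four annotations are category 29 ('hippopotamus').
```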
@@ -19,12 +19,14 @@ from __future__ import division
 from __future__ import print_function

 import os
+import unittest

 from absl.testing import parameterized
 import numpy as np
 import tensorflow.compat.v1 as tf

 from object_detection.tpu_exporters import export_saved_model_tpu_lib
+from object_detection.utils import tf_version

 flags = tf.app.flags
 FLAGS = flags.FLAGS
@@ -35,6 +37,7 @@ def get_path(path_suffix):
                       path_suffix)


+@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
 class ExportSavedModelTPUTest(tf.test.TestCase, parameterized.TestCase):

   @parameterized.named_parameters(
...
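
The change above is the pattern this commit applies across the test suite: TF1-only tests are skipped when the installed TensorFlow is 2.x, so the same files can be collected under either major version. Here is a self-contained sketch of the pattern; `SessionBasedTest` is a hypothetical class, not from the commit.

```python
# Sketch of the TF-version gating pattern; SessionBasedTest is hypothetical.
import unittest

import tensorflow.compat.v1 as tf

from object_detection.utils import tf_version


@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class SessionBasedTest(tf.test.TestCase):

  def test_add(self):
    # Session-based execution only exists under TF1 semantics.
    with self.session() as sess:
      self.assertEqual(3, sess.run(tf.constant(1) + tf.constant(2)))


if __name__ == '__main__':
  tf.test.main()
```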
@@ -19,7 +19,7 @@ from __future__ import division
 from __future__ import print_function

 import os
+import unittest

 from six.moves import range
 import tensorflow.compat.v1 as tf
@@ -32,6 +32,7 @@ from object_detection.protos import model_pb2
 from object_detection.protos import pipeline_pb2
 from object_detection.protos import train_pb2
 from object_detection.utils import config_util
+from object_detection.utils import tf_version

 # pylint: disable=g-import-not-at-top
 try:
@@ -282,18 +283,22 @@ class ConfigUtilTest(tf.test.TestCase):
     self.assertAlmostEqual(hparams.learning_rate * warmup_scale_factor,
                            cosine_lr.warmup_learning_rate)

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testRMSPropWithNewLearingRate(self):
     """Tests new learning rates for RMSProp Optimizer."""
     self._assertOptimizerWithNewLearningRate("rms_prop_optimizer")

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testMomentumOptimizerWithNewLearningRate(self):
     """Tests new learning rates for Momentum Optimizer."""
     self._assertOptimizerWithNewLearningRate("momentum_optimizer")

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testAdamOptimizerWithNewLearningRate(self):
     """Tests new learning rates for Adam Optimizer."""
     self._assertOptimizerWithNewLearningRate("adam_optimizer")

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testGenericConfigOverride(self):
     """Tests generic config overrides for all top-level configs."""
     # Set one parameter for each of the top-level pipeline configs:
@@ -329,6 +334,7 @@ class ConfigUtilTest(tf.test.TestCase):
     self.assertEqual(2,
                      configs["graph_rewriter_config"].quantization.weight_bits)

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testNewBatchSize(self):
     """Tests that batch size is updated appropriately."""
     original_batch_size = 2
@@ -344,6 +350,7 @@ class ConfigUtilTest(tf.test.TestCase):
     new_batch_size = configs["train_config"].batch_size
     self.assertEqual(16, new_batch_size)

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testNewBatchSizeWithClipping(self):
     """Tests that batch size is clipped to 1 from below."""
     original_batch_size = 2
@@ -359,6 +366,7 @@ class ConfigUtilTest(tf.test.TestCase):
     new_batch_size = configs["train_config"].batch_size
     self.assertEqual(1, new_batch_size)  # Clipped to 1.0.

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testOverwriteBatchSizeWithKeyValue(self):
     """Tests that batch size is overwritten based on key/value."""
     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
@@ -369,6 +377,7 @@ class ConfigUtilTest(tf.test.TestCase):
     new_batch_size = configs["train_config"].batch_size
     self.assertEqual(10, new_batch_size)

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testKeyValueOverrideBadKey(self):
     """Tests that overwriting with a bad key causes an exception."""
     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
@@ -377,6 +386,7 @@ class ConfigUtilTest(tf.test.TestCase):
     with self.assertRaises(ValueError):
       config_util.merge_external_params_with_configs(configs, hparams)

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testOverwriteBatchSizeWithBadValueType(self):
     """Tests that overwriting with a bad value type causes an exception."""
     pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
@@ -387,6 +397,7 @@ class ConfigUtilTest(tf.test.TestCase):
     with self.assertRaises(TypeError):
       config_util.merge_external_params_with_configs(configs, hparams)

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testNewMomentumOptimizerValue(self):
     """Tests that new momentum value is updated appropriately."""
     original_momentum_value = 0.4
@@ -404,6 +415,7 @@ class ConfigUtilTest(tf.test.TestCase):
     new_momentum_value = optimizer_config.momentum_optimizer_value
     self.assertAlmostEqual(1.0, new_momentum_value)  # Clipped to 1.0.

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testNewClassificationLocalizationWeightRatio(self):
     """Tests that the loss weight ratio is updated appropriately."""
     original_localization_weight = 0.1
@@ -426,6 +438,7 @@ class ConfigUtilTest(tf.test.TestCase):
     self.assertAlmostEqual(1.0, loss.localization_weight)
     self.assertAlmostEqual(new_weight_ratio, loss.classification_weight)

+  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
   def testNewFocalLossParameters(self):
     """Tests that the loss weight ratio is updated appropriately."""
     original_alpha = 1.0
...
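
The batch-size tests above exercise the key/value override path of `config_util.merge_external_params_with_configs`. A compact sketch of that path, with a placeholder pipeline path:

```python
# Sketch of the key/value override these tests exercise;
# 'pipeline.config' is a placeholder path.
from object_detection.utils import config_util

configs = config_util.get_configs_from_pipeline_file('pipeline.config')
configs = config_util.merge_external_params_with_configs(
    configs, kwargs_dict={'batch_size': 10})
# Mirrors testOverwriteBatchSizeWithKeyValue above.
assert configs['train_config'].batch_size == 10
```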
@@ -19,11 +19,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import unittest

 import tensorflow.compat.v1 as tf

 from object_detection.utils import model_util
+from object_detection.utils import tf_version


+@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
 class ExtractSubmodelUtilTest(tf.test.TestCase):

   def test_simple_model(self):
...
@@ -18,6 +18,8 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import unittest

 from absl.testing import parameterized
 import numpy as np
 import six
@@ -26,6 +28,7 @@ import tensorflow.compat.v1 as tf
 from object_detection import eval_util
 from object_detection.core import standard_fields
 from object_detection.utils import object_detection_evaluation
+from object_detection.utils import tf_version


 class OpenImagesV2EvaluationTest(tf.test.TestCase):
@@ -970,6 +973,8 @@ class ObjectDetectionEvaluationTest(tf.test.TestCase):
     self.assertAlmostEqual(copy_mean_corloc, mean_corloc)


+@unittest.skipIf(tf_version.is_tf2(), 'Eval Metrics ops are supported in TF1.X '
+                 'only.')
 class ObjectDetectionEvaluatorTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
...
@@ -268,7 +268,7 @@ def padded_one_hot_encoding(indices, depth, left_pad):
         on_value=1, off_value=0), tf.float32)
     return tf.pad(one_hot, [[0, 0], [left_pad, 0]], mode='CONSTANT')
   result = tf.cond(tf.greater(tf.size(indices), 0), one_hot_and_pad,
-                   lambda: tf.zeros((depth + left_pad, 0)))
+                   lambda: tf.zeros((tf.size(indices), depth + left_pad)))
   return tf.reshape(result, [-1, depth + left_pad])
...
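
The fix replaces the false branch's `tf.zeros((depth + left_pad, 0))`, whose dimensions were effectively transposed, with a tensor that already has the final `[0, depth + left_pad]` shape; the test changes below (from `execute_cpu` to `execute`) suggest this is what lets the op run on devices that require consistent shapes. A sketch of the resulting behavior on empty input:

```python
# Sketch of the empty-input behavior the fixed branch guarantees.
import tensorflow.compat.v1 as tf

from object_detection.utils import ops

indices = tf.constant([], dtype=tf.int32)
one_hot = ops.padded_one_hot_encoding(indices, depth=4, left_pad=2)
with tf.Session() as sess:
  # Zero rows, but the column count is still depth + left_pad.
  print(sess.run(one_hot).shape)  # -> (0, 6)
```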
@@ -196,8 +196,7 @@ class OpsTestPaddedOneHotEncoding(test_case.TestCase):
                                  [0, 0, 0, 1, 0, 0],
                                  [0, 0, 0, 0, 0, 1]], np.float32)
-    # Executing on CPU only because output shape is not constant.
-    out_one_hot_tensor = self.execute_cpu(graph_fn, [])
+    out_one_hot_tensor = self.execute(graph_fn, [])
     self.assertAllClose(out_one_hot_tensor, expected_tensor, rtol=1e-10,
                         atol=1e-10)
@@ -212,8 +211,7 @@ class OpsTestPaddedOneHotEncoding(test_case.TestCase):
                                  [0, 0, 0, 1, 0, 0, 0],
                                  [0, 0, 0, 0, 1, 0, 0],
                                  [0, 0, 0, 0, 0, 0, 1]], np.float32)
-    # Executing on CPU only because output shape is not constant.
-    out_one_hot_tensor = self.execute_cpu(graph_fn, [])
+    out_one_hot_tensor = self.execute(graph_fn, [])
     self.assertAllClose(out_one_hot_tensor, expected_tensor, rtol=1e-10,
                         atol=1e-10)
@@ -229,8 +227,7 @@ class OpsTestPaddedOneHotEncoding(test_case.TestCase):
                                  [0, 0, 0, 0, 0, 0, 1, 0, 0],
                                  [0, 0, 0, 0, 0, 0, 0, 0, 1]], np.float32)
-    # executing on CPU only because output shape is not constant.
-    out_one_hot_tensor = self.execute_cpu(graph_fn, [])
+    out_one_hot_tensor = self.execute(graph_fn, [])
     self.assertAllClose(out_one_hot_tensor, expected_tensor, rtol=1e-10,
                         atol=1e-10)
@@ -246,8 +243,7 @@ class OpsTestPaddedOneHotEncoding(test_case.TestCase):
       return one_hot_tensor
     expected_tensor = np.zeros((0, depth + pad))
-    # executing on CPU only because output shape is not constant.
-    out_one_hot_tensor = self.execute_cpu(graph_fn, [])
+    out_one_hot_tensor = self.execute(graph_fn, [])
     self.assertAllClose(out_one_hot_tensor, expected_tensor, rtol=1e-10,
                         atol=1e-10)
...
@@ -118,12 +118,17 @@ def compute_floor_offsets_with_indices(y_source,
   they were put on the grids) to target coordinates. Note that the input
   coordinates should be the "absolute" coordinates in terms of the output image
   dimensions as opposed to the normalized coordinates (i.e. values in [0, 1]).
+
+  If the input y and x source have the second dimension (representing the
+  neighboring pixels), then the offsets are computed from each of the
+  neighboring pixels to their corresponding target (first dimension).

   Args:
-    y_source: A tensor with shape [num_points] representing the absolute
-      y-coordinates (in the output image space) of the source points.
-    x_source: A tensor with shape [num_points] representing the absolute
-      x-coordinates (in the output image space) of the source points.
+    y_source: A tensor with shape [num_points] (or [num_points, num_neighbors])
+      representing the absolute y-coordinates (in the output image space) of
+      the source points.
+    x_source: A tensor with shape [num_points] (or [num_points, num_neighbors])
+      representing the absolute x-coordinates (in the output image space) of
+      the source points.
     y_target: A tensor with shape [num_points] representing the absolute
       y-coordinates (in the output image space) of the target points. If not
       provided, then y_source is used as the targets.
@@ -133,18 +138,33 @@ def compute_floor_offsets_with_indices(y_source,
   Returns:
     A tuple of two tensors:
-      offsets: A tensor with shape [num_points, 2] representing the offsets of
-        each input point.
-      indices: A tensor with shape [num_points, 2] representing the indices of
-        where the offsets should be retrieved in the output image dimension
-        space.
+      offsets: A tensor with shape [num_points, 2] (or
+        [num_points, num_neighbors, 2]) representing the offsets of each input
+        point.
+      indices: A tensor with shape [num_points, 2] (or
+        [num_points, num_neighbors, 2]) representing the indices of where the
+        offsets should be retrieved in the output image dimension space.
+
+  Raises:
+    ValueError: source and target shapes have unexpected values.
   """
   y_source_floored = tf.floor(y_source)
   x_source_floored = tf.floor(x_source)

-  if y_target is None:
+  source_shape = shape_utils.combined_static_and_dynamic_shape(y_source)
+  if y_target is None and x_target is None:
     y_target = y_source
-  if x_target is None:
     x_target = x_source
+  else:
+    target_shape = shape_utils.combined_static_and_dynamic_shape(y_target)
+    if len(source_shape) == 2 and len(target_shape) == 1:
+      _, num_neighbors = source_shape
+      y_target = tf.tile(
+          tf.expand_dims(y_target, -1), multiples=[1, num_neighbors])
+      x_target = tf.tile(
+          tf.expand_dims(x_target, -1), multiples=[1, num_neighbors])
+    elif source_shape != target_shape:
+      raise ValueError('Inconsistent source and target shape.')

   y_offset = y_target - y_source_floored
   x_offset = x_target - x_source_floored
@@ -152,9 +172,8 @@ def compute_floor_offsets_with_indices(y_source,
   y_source_indices = tf.cast(y_source_floored, tf.int32)
   x_source_indices = tf.cast(x_source_floored, tf.int32)

-  indices = tf.stack([y_source_indices, x_source_indices], axis=1)
-  offsets = tf.stack([y_offset, x_offset], axis=1)
+  indices = tf.stack([y_source_indices, x_source_indices], axis=-1)
+  offsets = tf.stack([y_offset, x_offset], axis=-1)

   return offsets, indices
@@ -231,6 +250,12 @@ def blackout_pixel_weights_by_box_regions(height, width, boxes, blackout):
     A float tensor with shape [height, width] where all values within the
     regions of the blackout boxes are 0.0 and 1.0 elsewhere.
   """
+  num_instances, _ = shape_utils.combined_static_and_dynamic_shape(boxes)
+  # If no annotation instance is provided, return all ones (instead of
+  # unexpected values) to avoid NaN loss value.
+  if num_instances == 0:
+    return tf.ones([height, width], dtype=tf.float32)
+
   (y_grid, x_grid) = image_shape_to_grids(height, width)
   y_grid = tf.expand_dims(y_grid, axis=0)
   x_grid = tf.expand_dims(x_grid, axis=0)
@@ -257,3 +282,72 @@ def blackout_pixel_weights_by_box_regions(height, width, boxes, blackout):
   out_boxes = tf.reduce_max(selected_in_boxes, axis=0)
   out_boxes = tf.ones_like(out_boxes) - out_boxes
   return out_boxes
+
+
+def _get_yx_indices_offset_by_radius(radius):
+  """Gets the y and x index offsets that are within the radius."""
+  y_offsets = []
+  x_offsets = []
+  for y_offset in range(-radius, radius + 1, 1):
+    for x_offset in range(-radius, radius + 1, 1):
+      if x_offset ** 2 + y_offset ** 2 <= radius ** 2:
+        y_offsets.append(y_offset)
+        x_offsets.append(x_offset)
+  return (tf.constant(y_offsets, dtype=tf.float32),
+          tf.constant(x_offsets, dtype=tf.float32))
+
+
+def get_surrounding_grids(height, width, y_coordinates, x_coordinates, radius):
+  """Gets the indices of the surrounding pixels of the input y, x coordinates.
+
+  This function returns the pixel indices corresponding to the (floor of the)
+  input coordinates and their surrounding pixels within the radius. If the
+  radius is set to 0, then only the pixels that correspond to the floor of the
+  coordinates will be returned. If the radius is larger than 0, then all of the
+  pixels within the radius of the "floor pixels" will also be returned. For
+  example, if the input coordinate is [2.1, 3.5] and radius is 1, then the five
+  pixel indices will be returned: [2, 3], [1, 3], [2, 2], [2, 4], [3, 3]. Also,
+  if the surrounding pixels are outside of the valid image region, then the
+  returned pixel indices will be [0, 0] and their corresponding "valid" value
+  will be False.
+
+  Args:
+    height: int, the height of the output image.
+    width: int, the width of the output image.
+    y_coordinates: A tensor with shape [num_points] representing the absolute
+      y-coordinates (in the output image space) of the points.
+    x_coordinates: A tensor with shape [num_points] representing the absolute
+      x-coordinates (in the output image space) of the points.
+    radius: int, the radius of the neighboring pixels to be considered and
+      returned. If set to 0, then only the pixel indices corresponding to the
+      floor of the input coordinates will be returned.
+
+  Returns:
+    A tuple of three tensors:
+      y_indices: A [num_points, num_neighbors] float tensor representing the
+        pixel y indices corresponding to the input points within radius. The
+        "num_neighbors" is determined by the size of the radius.
+      x_indices: A [num_points, num_neighbors] float tensor representing the
+        pixel x indices corresponding to the input points within radius. The
+        "num_neighbors" is determined by the size of the radius.
+      valid: A [num_points, num_neighbors] boolean tensor representing whether
+        each returned index is in the valid image region or not.
+  """
+  # Floored y, x: [num_points, 1].
+  y_center = tf.expand_dims(tf.math.floor(y_coordinates), axis=-1)
+  x_center = tf.expand_dims(tf.math.floor(x_coordinates), axis=-1)
+  y_offsets, x_offsets = _get_yx_indices_offset_by_radius(radius)
+  # Indices offsets: [1, num_neighbors].
+  y_offsets = tf.expand_dims(y_offsets, axis=0)
+  x_offsets = tf.expand_dims(x_offsets, axis=0)
+
+  # Floor + offsets: [num_points, num_neighbors].
+  y_output = y_center + y_offsets
+  x_output = x_center + x_offsets
+  default_output = tf.zeros_like(y_output)
+  valid = tf.logical_and(
+      tf.logical_and(x_output >= 0, x_output < width),
+      tf.logical_and(y_output >= 0, y_output < height))
+  y_output = tf.where(valid, y_output, default_output)
+  x_output = tf.where(valid, x_output, default_output)
+  return (y_output, x_output, valid)
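
A quick worked check (not from the commit) of the radius-1 neighborhood these helpers produce, matching the five-pixel cross in the `get_surrounding_grids` docstring example for coordinate [2.1, 3.5]:

```python
# Offsets with squared distance <= radius**2 form a cross around the
# floored coordinate (2, 3).
radius = 1
neighborhood = [(dy, dx)
                for dy in range(-radius, radius + 1)
                for dx in range(-radius, radius + 1)
                if dx ** 2 + dy ** 2 <= radius ** 2]
print(neighborhood)  # -> [(-1, 0), (0, -1), (0, 0), (0, 1), (1, 0)]
print([(2 + dy, 3 + dx) for dy, dx in neighborhood])
# -> [(1, 3), (2, 2), (2, 3), (2, 4), (3, 3)], matching the docstring.
```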
@@ -87,8 +87,32 @@ class TargetUtilTest(test_case.TestCase):
     np.testing.assert_array_almost_equal(offsets,
                                          np.array([[1.1, -0.8], [0.1, 0.5]]))
-    np.testing.assert_array_almost_equal(indices,
-                                         np.array([[1, 2], [0, 4]]))
+    np.testing.assert_array_almost_equal(indices, np.array([[1, 2], [0, 4]]))
+
+  def test_compute_floor_offsets_with_indices_multisources(self):
+    def graph_fn():
+      y_source = tf.constant([[1.0, 0.0], [2.0, 3.0]], dtype=tf.float32)
+      x_source = tf.constant([[2.0, 4.0], [3.0, 3.0]], dtype=tf.float32)
+      y_target = tf.constant([2.1, 0.1], dtype=tf.float32)
+      x_target = tf.constant([1.2, 4.5], dtype=tf.float32)
+      (offsets, indices) = ta_utils.compute_floor_offsets_with_indices(
+          y_source, x_source, y_target, x_target)
+      return offsets, indices
+
+    offsets, indices = self.execute(graph_fn, [])
+    # Offset from the first source to target.
+    np.testing.assert_array_almost_equal(offsets[:, 0, :],
+                                         np.array([[1.1, -0.8], [-1.9, 1.5]]))
+    # Offset from the second source to target.
+    np.testing.assert_array_almost_equal(offsets[:, 1, :],
+                                         np.array([[2.1, -2.8], [-2.9, 1.5]]))
+    # Indices from the first source to target.
+    np.testing.assert_array_almost_equal(indices[:, 0, :],
+                                         np.array([[1, 2], [2, 3]]))
+    # Indices from the second source to target.
+    np.testing.assert_array_almost_equal(indices[:, 1, :],
+                                         np.array([[0, 4], [3, 3]]))

   def test_get_valid_keypoints_mask(self):
@@ -174,6 +198,44 @@ class TargetUtilTest(test_case.TestCase):
     # 20 * 10 - 6 * 6 - 3 * 7 = 143.0
     self.assertAlmostEqual(np.sum(output), 143.0)
+
+  def test_blackout_pixel_weights_by_box_regions_zero_instance(self):
+    def graph_fn():
+      boxes = tf.zeros([0, 4], dtype=tf.float32)
+      blackout = tf.zeros([0], dtype=tf.bool)
+      blackout_pixel_weights_by_box_regions = tf.function(
+          ta_utils.blackout_pixel_weights_by_box_regions)
+      output = blackout_pixel_weights_by_box_regions(10, 20, boxes, blackout)
+      return output
+
+    output = self.execute(graph_fn, [])
+    # The output should be all 1s since there's no annotation provided.
+    np.testing.assert_array_equal(output, np.ones([10, 20], dtype=np.float32))
+
+  def test_get_surrounding_grids(self):
+    def graph_fn():
+      y_coordinates = tf.constant([0.5], dtype=tf.float32)
+      x_coordinates = tf.constant([4.5], dtype=tf.float32)
+      output = ta_utils.get_surrounding_grids(
+          height=3,
+          width=5,
+          y_coordinates=y_coordinates,
+          x_coordinates=x_coordinates,
+          radius=1)
+      return output
+
+    y_indices, x_indices, valid = self.execute(graph_fn, [])
+    # Five neighboring indices: [-1, 4] (out of bound), [0, 3], [0, 4],
+    # [0, 5] (out of bound), [1, 4].
+    np.testing.assert_array_almost_equal(
+        y_indices,
+        np.array([[0.0, 0.0, 0.0, 0.0, 1.0]]))
+    np.testing.assert_array_almost_equal(
+        x_indices,
+        np.array([[0.0, 3.0, 4.0, 0.0, 4.0]]))
+    self.assertAllEqual(valid, [[False, True, True, False, True]])


 if __name__ == '__main__':
   tf.test.main()
...
@@ -271,3 +271,19 @@ class GraphContextOrNone(object):
       return False
     else:
       return self.graph.__exit__(ttype, value, traceback)
+
+
+def image_with_dynamic_shape(height, width, channels):
+  """Returns a single image with dynamic shape."""
+  h = tf.random.uniform([], minval=height, maxval=height+1, dtype=tf.int32)
+  w = tf.random.uniform([], minval=width, maxval=width+1, dtype=tf.int32)
+  image = tf.random.uniform([h, w, channels])
+  return image
+
+
+def keypoints_with_dynamic_shape(num_instances, num_keypoints, num_coordinates):
+  """Returns keypoints with dynamic shape."""
+  n = tf.random.uniform([], minval=num_instances, maxval=num_instances+1,
+                        dtype=tf.int32)
+  keypoints = tf.random.uniform([n, num_keypoints, num_coordinates])
+  return keypoints
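
These helpers rely on a small trick: `tf.random.uniform([], minval=k, maxval=k+1, dtype=tf.int32)` always evaluates to `k`, but because the size is itself a tensor, the resulting static shape is unknown, which is exactly what dynamic-shape tests need. A hedged usage sketch, assuming the helpers live in `object_detection.utils.test_case` (the diff does not name the file):

```python
# Usage sketch; test_case is assumed to be the module that gained the helpers.
import tensorflow.compat.v1 as tf

from object_detection.utils import test_case

image = test_case.image_with_dynamic_shape(height=32, width=32, channels=3)
# In TF1 graph mode, height/width are dynamic: shape reads as (?, ?, 3).
print(image.shape)
```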
@@ -47,8 +47,6 @@ def filter_variables(variables, filter_regex_list, invert=False):
   Returns:
     a list of filtered variables.
   """
-  if tf.executing_eagerly():
-    raise ValueError('Accessing variables is not supported in eager mode.')
   kept_vars = []
   variables_to_ignore_patterns = list([fre for fre in filter_regex_list if fre])
   for var in variables:
@@ -74,8 +72,6 @@ def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
   Returns:
     grads_and_vars: A list of gradient to variable pairs (tuples).
   """
-  if tf.executing_eagerly():
-    raise ValueError('Accessing variables is not supported in eager mode.')
   variables = [pair[1] for pair in grads_and_vars]
   matching_vars = filter_variables(variables, regex_list, invert=True)
   for var in matching_vars:
@@ -97,8 +93,6 @@ def freeze_gradients_matching_regex(grads_and_vars, regex_list):
     grads_and_vars: A list of gradient to variable pairs (tuples) that do not
       contain the variables and gradients matching the regex.
   """
-  if tf.executing_eagerly():
-    raise ValueError('Accessing variables is not supported in eager mode.')
   variables = [pair[1] for pair in grads_and_vars]
   matching_vars = filter_variables(variables, regex_list, invert=True)
   kept_grads_and_vars = [pair for pair in grads_and_vars
@@ -129,8 +123,6 @@ def get_variables_available_in_checkpoint(variables,
   Raises:
     ValueError: if `variables` is not a list or dict.
   """
-  if tf.executing_eagerly():
-    raise ValueError('Accessing variables is not supported in eager mode.')
   if isinstance(variables, list):
     variable_names_map = {}
     for variable in variables:
@@ -178,8 +170,6 @@ def get_global_variables_safely():
   Returns:
     The result of tf.global_variables()
   """
-  if tf.executing_eagerly():
-    raise ValueError('Accessing variables is not supported in eager mode.')
   with tf.init_scope():
     if tf.executing_eagerly():
       raise ValueError("Global variables collection is not tracked when "
...