Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
213a9649
Commit
213a9649
authored
Jun 02, 2021
by
Ronny Votel
Committed by
TF Object Detection Team
Jun 02, 2021
Browse files
Introducing groundtruth instance mask weights.
PiperOrigin-RevId: 377096964
parent
0b9a2a74
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
182 additions
and
8 deletions
+182
-8
research/object_detection/core/preprocessor.py
research/object_detection/core/preprocessor.py
+29
-2
research/object_detection/core/preprocessor_test.py
research/object_detection/core/preprocessor_test.py
+40
-3
research/object_detection/core/standard_fields.py
research/object_detection/core/standard_fields.py
+2
-0
research/object_detection/data_decoders/tf_example_decoder.py
...arch/object_detection/data_decoders/tf_example_decoder.py
+24
-0
research/object_detection/data_decoders/tf_example_decoder_test.py
...object_detection/data_decoders/tf_example_decoder_test.py
+68
-0
research/object_detection/inputs.py
research/object_detection/inputs.py
+11
-0
research/object_detection/inputs_test.py
research/object_detection/inputs_test.py
+8
-3
No files found.
research/object_detection/core/preprocessor.py
View file @
213a9649
...
...
@@ -1414,6 +1414,7 @@ def _strict_random_crop_image(image,
label_confidences
=
None
,
multiclass_scores
=
None
,
masks
=
None
,
mask_weights
=
None
,
keypoints
=
None
,
keypoint_visibilities
=
None
,
densepose_num_points
=
None
,
...
...
@@ -1451,6 +1452,8 @@ def _strict_random_crop_image(image,
masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`.
mask_weights: (optional) rank 1 float32 tensor with shape [num_instances]
with instance masks weights.
keypoints: (optional) rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]. The keypoints are in y-x
normalized coordinates.
...
...
@@ -1488,7 +1491,7 @@ def _strict_random_crop_image(image,
Boxes are in normalized form.
labels: new labels.
If label_weights, multiclass_scores, masks, keypoints,
If label_weights, multiclass_scores, masks,
mask_weights,
keypoints,
keypoint_visibilities, densepose_num_points, densepose_part_ids, or
densepose_surface_coords is not None, the function also returns:
label_weights: rank 1 float32 tensor with shape [num_instances].
...
...
@@ -1496,6 +1499,8 @@ def _strict_random_crop_image(image,
[num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks.
mask_weights: rank 1 float32 tensor with shape [num_instances] with mask
weights.
keypoints: rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]
keypoint_visibilities: rank 2 bool tensor with shape
...
...
@@ -1605,6 +1610,12 @@ def _strict_random_crop_image(image,
0
]:
im_box_end
[
0
],
im_box_begin
[
1
]:
im_box_end
[
1
]]
result
.
append
(
new_masks
)
if
mask_weights
is
not
None
:
mask_weights_inside_window
=
tf
.
gather
(
mask_weights
,
inside_window_ids
)
mask_weights_completely_inside_window
=
tf
.
gather
(
mask_weights_inside_window
,
keep_ids
)
result
.
append
(
mask_weights_completely_inside_window
)
if
keypoints
is
not
None
:
keypoints_of_boxes_inside_window
=
tf
.
gather
(
keypoints
,
inside_window_ids
)
keypoints_of_boxes_completely_inside_window
=
tf
.
gather
(
...
...
@@ -1654,6 +1665,7 @@ def random_crop_image(image,
label_confidences
=
None
,
multiclass_scores
=
None
,
masks
=
None
,
mask_weights
=
None
,
keypoints
=
None
,
keypoint_visibilities
=
None
,
densepose_num_points
=
None
,
...
...
@@ -1701,6 +1713,8 @@ def random_crop_image(image,
masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`.
mask_weights: (optional) rank 1 float32 tensor with shape [num_instances]
containing weights for each instance mask.
keypoints: (optional) rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]. The keypoints are in y-x
normalized coordinates.
...
...
@@ -1751,6 +1765,7 @@ def random_crop_image(image,
[num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks.
mask_weights: rank 1 float32 tensor with shape [num_instances].
keypoints: rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]
keypoint_visibilities: rank 2 bool tensor with shape
...
...
@@ -1771,6 +1786,7 @@ def random_crop_image(image,
label_confidences
=
label_confidences
,
multiclass_scores
=
multiclass_scores
,
masks
=
masks
,
mask_weights
=
mask_weights
,
keypoints
=
keypoints
,
keypoint_visibilities
=
keypoint_visibilities
,
densepose_num_points
=
densepose_num_points
,
...
...
@@ -1803,6 +1819,8 @@ def random_crop_image(image,
outputs
.
append
(
multiclass_scores
)
if
masks
is
not
None
:
outputs
.
append
(
masks
)
if
mask_weights
is
not
None
:
outputs
.
append
(
mask_weights
)
if
keypoints
is
not
None
:
outputs
.
append
(
keypoints
)
if
keypoint_visibilities
is
not
None
:
...
...
@@ -4388,6 +4406,7 @@ def get_default_func_arg_map(include_label_weights=True,
include_label_confidences
=
False
,
include_multiclass_scores
=
False
,
include_instance_masks
=
False
,
include_instance_mask_weights
=
False
,
include_keypoints
=
False
,
include_keypoint_visibilities
=
False
,
include_dense_pose
=
False
,
...
...
@@ -4403,6 +4422,8 @@ def get_default_func_arg_map(include_label_weights=True,
multiclass scores, too.
include_instance_masks: If True, preprocessing functions will modify the
instance masks, too.
include_instance_mask_weights: If True, preprocessing functions will modify
the instance mask weights.
include_keypoints: If True, preprocessing functions will modify the
keypoints, too.
include_keypoint_visibilities: If True, preprocessing functions will modify
...
...
@@ -4434,6 +4455,11 @@ def get_default_func_arg_map(include_label_weights=True,
groundtruth_instance_masks
=
(
fields
.
InputDataFields
.
groundtruth_instance_masks
)
groundtruth_instance_mask_weights
=
None
if
include_instance_mask_weights
:
groundtruth_instance_mask_weights
=
(
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
)
groundtruth_keypoints
=
None
if
include_keypoints
:
groundtruth_keypoints
=
fields
.
InputDataFields
.
groundtruth_keypoints
...
...
@@ -4503,7 +4529,8 @@ def get_default_func_arg_map(include_label_weights=True,
fields
.
InputDataFields
.
groundtruth_boxes
,
fields
.
InputDataFields
.
groundtruth_classes
,
groundtruth_label_weights
,
groundtruth_label_confidences
,
multiclass_scores
,
groundtruth_instance_masks
,
groundtruth_keypoints
,
multiclass_scores
,
groundtruth_instance_masks
,
groundtruth_instance_mask_weights
,
groundtruth_keypoints
,
groundtruth_keypoint_visibilities
,
groundtruth_dp_num_points
,
groundtruth_dp_part_ids
,
groundtruth_dp_surface_coords
),
random_pad_image
:
...
...
research/object_detection/core/preprocessor_test.py
View file @
213a9649
...
...
@@ -1894,6 +1894,37 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
self
.
assertAllClose
(
new_boxes
.
flatten
(),
expected_boxes
.
flatten
())
def
testStrictRandomCropImageWithMaskWeights
(
self
):
def
graph_fn
():
image
=
self
.
createColorfulTestImage
()[
0
]
boxes
=
self
.
createTestBoxes
()
labels
=
self
.
createTestLabels
()
weights
=
self
.
createTestGroundtruthWeights
()
masks
=
tf
.
random_uniform
([
2
,
200
,
400
],
dtype
=
tf
.
float32
)
mask_weights
=
tf
.
constant
([
1.0
,
0.0
],
dtype
=
tf
.
float32
)
with
mock
.
patch
.
object
(
tf
.
image
,
'sample_distorted_bounding_box'
)
as
mock_sample_distorted_bounding_box
:
mock_sample_distorted_bounding_box
.
return_value
=
(
tf
.
constant
([
6
,
143
,
0
],
dtype
=
tf
.
int32
),
tf
.
constant
([
190
,
237
,
-
1
],
dtype
=
tf
.
int32
),
tf
.
constant
([[[
0.03
,
0.3575
,
0.98
,
0.95
]]],
dtype
=
tf
.
float32
))
results
=
preprocessor
.
_strict_random_crop_image
(
image
,
boxes
,
labels
,
weights
,
masks
=
masks
,
mask_weights
=
mask_weights
)
return
results
(
new_image
,
new_boxes
,
_
,
_
,
new_masks
,
new_mask_weights
)
=
self
.
execute_cpu
(
graph_fn
,
[])
expected_boxes
=
np
.
array
(
[[
0.0
,
0.0
,
0.75789469
,
1.0
],
[
0.23157893
,
0.24050637
,
0.75789469
,
1.0
]],
dtype
=
np
.
float32
)
self
.
assertAllEqual
(
new_image
.
shape
,
[
190
,
237
,
3
])
self
.
assertAllEqual
(
new_masks
.
shape
,
[
2
,
190
,
237
])
self
.
assertAllClose
(
new_mask_weights
,
[
1.0
,
0.0
])
self
.
assertAllClose
(
new_boxes
.
flatten
(),
expected_boxes
.
flatten
())
def
testStrictRandomCropImageWithKeypoints
(
self
):
def
graph_fn
():
image
=
self
.
createColorfulTestImage
()[
0
]
...
...
@@ -1947,6 +1978,7 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
labels
=
self
.
createTestLabels
()
weights
=
self
.
createTestGroundtruthWeights
()
masks
=
tf
.
random_uniform
([
2
,
200
,
400
],
dtype
=
tf
.
float32
)
mask_weights
=
tf
.
constant
([
1.0
,
0.0
],
dtype
=
tf
.
float32
)
tensor_dict
=
{
fields
.
InputDataFields
.
image
:
image
,
...
...
@@ -1954,10 +1986,12 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
fields
.
InputDataFields
.
groundtruth_classes
:
labels
,
fields
.
InputDataFields
.
groundtruth_weights
:
weights
,
fields
.
InputDataFields
.
groundtruth_instance_masks
:
masks
,
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
:
mask_weights
}
preprocessor_arg_map
=
preprocessor
.
get_default_func_arg_map
(
include_instance_masks
=
True
)
include_instance_masks
=
True
,
include_instance_mask_weights
=
True
)
preprocessing_options
=
[(
preprocessor
.
random_crop_image
,
{})]
...
...
@@ -1980,16 +2014,19 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
fields
.
InputDataFields
.
groundtruth_classes
]
distorted_masks
=
distorted_tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_masks
]
distorted_mask_weights
=
distorted_tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
]
return
[
distorted_image
,
distorted_boxes
,
distorted_labels
,
distorted_masks
]
distorted_masks
,
distorted_mask_weights
]
(
distorted_image_
,
distorted_boxes_
,
distorted_labels_
,
distorted_masks_
)
=
self
.
execute_cpu
(
graph_fn
,
[])
distorted_masks_
,
distorted_mask_weights_
)
=
self
.
execute_cpu
(
graph_fn
,
[])
expected_boxes
=
np
.
array
([
[
0.0
,
0.0
,
0.75789469
,
1.0
],
[
0.23157893
,
0.24050637
,
0.75789469
,
1.0
],
],
dtype
=
np
.
float32
)
self
.
assertAllEqual
(
distorted_image_
.
shape
,
[
1
,
190
,
237
,
3
])
self
.
assertAllEqual
(
distorted_masks_
.
shape
,
[
2
,
190
,
237
])
self
.
assertAllClose
(
distorted_mask_weights_
,
[
1.0
,
0.0
])
self
.
assertAllEqual
(
distorted_labels_
,
[
1
,
2
])
self
.
assertAllClose
(
distorted_boxes_
.
flatten
(),
expected_boxes
.
flatten
())
...
...
research/object_detection/core/standard_fields.py
View file @
213a9649
...
...
@@ -64,6 +64,7 @@ class InputDataFields(object):
proposal_boxes: coordinates of object proposal boxes.
proposal_objectness: objectness score of each proposal.
groundtruth_instance_masks: ground truth instance masks.
groundtruth_instance_mask_weights: ground truth instance masks weights.
groundtruth_instance_boundaries: ground truth instance boundaries.
groundtruth_instance_classes: instance mask-level class labels.
groundtruth_keypoints: ground truth keypoints.
...
...
@@ -122,6 +123,7 @@ class InputDataFields(object):
proposal_boxes
=
'proposal_boxes'
proposal_objectness
=
'proposal_objectness'
groundtruth_instance_masks
=
'groundtruth_instance_masks'
groundtruth_instance_mask_weights
=
'groundtruth_instance_mask_weights'
groundtruth_instance_boundaries
=
'groundtruth_instance_boundaries'
groundtruth_instance_classes
=
'groundtruth_instance_classes'
groundtruth_keypoints
=
'groundtruth_keypoints'
...
...
research/object_detection/data_decoders/tf_example_decoder.py
View file @
213a9649
...
...
@@ -373,6 +373,11 @@ class TfExampleDecoder(data_decoder.DataDecoder):
self
.
_decode_png_instance_masks
))
else
:
raise
ValueError
(
'Did not recognize the `instance_mask_type` option.'
)
self
.
keys_to_features
[
'image/object/mask/weight'
]
=
(
tf
.
VarLenFeature
(
tf
.
float32
))
self
.
items_to_handlers
[
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
]
=
(
slim_example_decoder
.
Tensor
(
'image/object/mask/weight'
))
if
load_dense_pose
:
self
.
keys_to_features
[
'image/object/densepose/num'
]
=
(
tf
.
VarLenFeature
(
tf
.
int64
))
...
...
@@ -491,6 +496,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
tensor of shape [None, num_keypoints] containing keypoint visibilites.
fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
shape [None, None, None] containing instance masks.
fields.InputDataFields.groundtruth_instance_mask_weights - 1D float32
tensor of shape [None] containing weights. These are typically values
in {0.0, 1.0} which indicate whether to consider the mask related to an
object.
fields.InputDataFields.groundtruth_image_classes - 1D int64 of shape
[None] containing classes for the boxes.
fields.InputDataFields.multiclass_scores - 1D float32 tensor of shape
...
...
@@ -531,6 +540,21 @@ class TfExampleDecoder(data_decoder.DataDecoder):
0
),
lambda
:
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_weights
],
default_groundtruth_weights
)
if
fields
.
InputDataFields
.
groundtruth_instance_masks
in
tensor_dict
:
gt_instance_masks
=
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_masks
]
num_gt_instance_masks
=
tf
.
shape
(
gt_instance_masks
)[
0
]
gt_instance_mask_weights
=
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
]
num_gt_instance_mask_weights
=
tf
.
shape
(
gt_instance_mask_weights
)[
0
]
def
default_groundtruth_instance_mask_weights
():
return
tf
.
ones
([
num_gt_instance_masks
],
dtype
=
tf
.
float32
)
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
]
=
(
tf
.
cond
(
tf
.
greater
(
num_gt_instance_mask_weights
,
0
),
lambda
:
gt_instance_mask_weights
,
default_groundtruth_instance_mask_weights
))
if
fields
.
InputDataFields
.
groundtruth_keypoints
in
tensor_dict
:
# Set all keypoints that are not labeled to NaN.
gt_kpt_fld
=
fields
.
InputDataFields
.
groundtruth_keypoints
...
...
research/object_detection/data_decoders/tf_example_decoder_test.py
View file @
213a9649
...
...
@@ -1225,6 +1225,9 @@ class TfExampleDecoderTest(test_case.TestCase):
self
.
assertAllEqual
(
instance_masks
.
astype
(
np
.
float32
),
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_masks
])
self
.
assertAllEqual
(
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
],
[
1
,
1
,
1
,
1
])
self
.
assertAllEqual
(
object_classes
,
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_classes
])
...
...
@@ -1272,6 +1275,71 @@ class TfExampleDecoderTest(test_case.TestCase):
self
.
assertNotIn
(
fields
.
InputDataFields
.
groundtruth_instance_masks
,
tensor_dict
)
def
testDecodeInstanceSegmentationWithWeights
(
self
):
num_instances
=
4
image_height
=
5
image_width
=
3
# Randomly generate image.
image_tensor
=
np
.
random
.
randint
(
256
,
size
=
(
image_height
,
image_width
,
3
)).
astype
(
np
.
uint8
)
encoded_jpeg
,
_
=
self
.
_create_encoded_and_decoded_data
(
image_tensor
,
'jpeg'
)
# Randomly generate instance segmentation masks.
instance_masks
=
(
np
.
random
.
randint
(
2
,
size
=
(
num_instances
,
image_height
,
image_width
)).
astype
(
np
.
float32
))
instance_masks_flattened
=
np
.
reshape
(
instance_masks
,
[
-
1
])
instance_mask_weights
=
np
.
array
([
1
,
1
,
0
,
1
],
dtype
=
np
.
float32
)
# Randomly generate class labels for each instance.
object_classes
=
np
.
random
.
randint
(
100
,
size
=
(
num_instances
)).
astype
(
np
.
int64
)
def
graph_fn
():
example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
{
'image/encoded'
:
dataset_util
.
bytes_feature
(
encoded_jpeg
),
'image/format'
:
dataset_util
.
bytes_feature
(
six
.
b
(
'jpeg'
)),
'image/height'
:
dataset_util
.
int64_feature
(
image_height
),
'image/width'
:
dataset_util
.
int64_feature
(
image_width
),
'image/object/mask'
:
dataset_util
.
float_list_feature
(
instance_masks_flattened
),
'image/object/mask/weight'
:
dataset_util
.
float_list_feature
(
instance_mask_weights
),
'image/object/class/label'
:
dataset_util
.
int64_list_feature
(
object_classes
)
})).
SerializeToString
()
example_decoder
=
tf_example_decoder
.
TfExampleDecoder
(
load_instance_masks
=
True
)
output
=
example_decoder
.
decode
(
tf
.
convert_to_tensor
(
example
))
self
.
assertAllEqual
(
(
output
[
fields
.
InputDataFields
.
groundtruth_instance_masks
].
get_shape
(
).
as_list
()),
[
4
,
5
,
3
])
self
.
assertAllEqual
(
output
[
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
],
[
1
,
1
,
0
,
1
])
self
.
assertAllEqual
((
output
[
fields
.
InputDataFields
.
groundtruth_classes
].
get_shape
().
as_list
()),
[
4
])
return
output
tensor_dict
=
self
.
execute_cpu
(
graph_fn
,
[])
self
.
assertAllEqual
(
instance_masks
.
astype
(
np
.
float32
),
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_masks
])
self
.
assertAllEqual
(
object_classes
,
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_classes
])
def
testDecodeImageLabels
(
self
):
image_tensor
=
np
.
random
.
randint
(
256
,
size
=
(
4
,
5
,
3
)).
astype
(
np
.
uint8
)
encoded_jpeg
,
_
=
self
.
_create_encoded_and_decoded_data
(
...
...
research/object_detection/inputs.py
View file @
213a9649
...
...
@@ -479,6 +479,7 @@ def pad_input_data_to_static_shapes(tensor_dict,
input_fields
.
groundtruth_instance_masks
:
[
max_num_boxes
,
height
,
width
],
input_fields
.
groundtruth_instance_mask_weights
:
[
max_num_boxes
],
input_fields
.
groundtruth_is_crowd
:
[
max_num_boxes
],
input_fields
.
groundtruth_group_of
:
[
max_num_boxes
],
input_fields
.
groundtruth_area
:
[
max_num_boxes
],
...
...
@@ -601,6 +602,8 @@ def augment_input_data(tensor_dict, data_augmentation_options):
include_instance_masks
=
(
fields
.
InputDataFields
.
groundtruth_instance_masks
in
tensor_dict
)
include_instance_mask_weights
=
(
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
in
tensor_dict
)
include_keypoints
=
(
fields
.
InputDataFields
.
groundtruth_keypoints
in
tensor_dict
)
include_keypoint_visibilities
=
(
...
...
@@ -624,6 +627,7 @@ def augment_input_data(tensor_dict, data_augmentation_options):
include_label_confidences
=
include_label_confidences
,
include_multiclass_scores
=
include_multiclass_scores
,
include_instance_masks
=
include_instance_masks
,
include_instance_mask_weights
=
include_instance_mask_weights
,
include_keypoints
=
include_keypoints
,
include_keypoint_visibilities
=
include_keypoint_visibilities
,
include_dense_pose
=
include_dense_pose
,
...
...
@@ -652,6 +656,7 @@ def _get_labels_dict(input_dict):
fields
.
InputDataFields
.
groundtruth_keypoint_depths
,
fields
.
InputDataFields
.
groundtruth_keypoint_depth_weights
,
fields
.
InputDataFields
.
groundtruth_instance_masks
,
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
,
fields
.
InputDataFields
.
groundtruth_area
,
fields
.
InputDataFields
.
groundtruth_is_crowd
,
fields
.
InputDataFields
.
groundtruth_group_of
,
...
...
@@ -804,6 +809,9 @@ def train_input(train_config, train_input_config,
labels[fields.InputDataFields.groundtruth_instance_masks] is a
[batch_size, num_boxes, H, W] float32 tensor containing only binary
values, which represent instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
[batch_size, num_boxes] float32 tensor containing groundtruth weights
for each instance mask.
labels[fields.InputDataFields.groundtruth_keypoints] is a
[batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
keypoints for each box.
...
...
@@ -961,6 +969,9 @@ def eval_input(eval_config, eval_input_config, model_config,
labels[fields.InputDataFields.groundtruth_instance_masks] is a
[1, num_boxes, H, W] float32 tensor containing only binary values,
which represent instance masks for objects.
labels[fields.InputDataFields.groundtruth_instance_mask_weights] is a
[1, num_boxes] float32 tensor containing groundtruth weights for each
instance mask.
labels[fields.InputDataFields.groundtruth_weights] is a
[batch_size, num_boxes, num_keypoints] float32 tensor containing
groundtruth weights for the keypoints.
...
...
research/object_detection/inputs_test.py
View file @
213a9649
...
...
@@ -795,15 +795,20 @@ class DataAugmentationFnTest(test_case.TestCase):
fields
.
InputDataFields
.
image
:
tf
.
constant
(
np
.
random
.
rand
(
10
,
10
,
3
).
astype
(
np
.
float32
)),
fields
.
InputDataFields
.
groundtruth_instance_masks
:
tf
.
constant
(
np
.
zeros
([
2
,
10
,
10
],
np
.
uint8
))
tf
.
constant
(
np
.
zeros
([
2
,
10
,
10
],
np
.
uint8
)),
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
:
tf
.
constant
([
1.0
,
0.0
],
np
.
float32
)
}
augmented_tensor_dict
=
data_augmentation_fn
(
tensor_dict
=
tensor_dict
)
return
(
augmented_tensor_dict
[
fields
.
InputDataFields
.
image
],
augmented_tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_masks
])
image
,
masks
=
self
.
execute_cpu
(
graph_fn
,
[])
groundtruth_instance_masks
],
augmented_tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_mask_weights
])
image
,
masks
,
mask_weights
=
self
.
execute_cpu
(
graph_fn
,
[])
self
.
assertAllEqual
(
image
.
shape
,
[
20
,
20
,
3
])
self
.
assertAllEqual
(
masks
.
shape
,
[
2
,
20
,
20
])
self
.
assertAllClose
(
mask_weights
,
[
1.0
,
0.0
])
def
test_include_keypoints_in_data_augmentation
(
self
):
data_augmentation_options
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment