ModelZoo / ResNet50_tensorflow · Commits

Commit b92025a9, authored Aug 18, 2021 by anivegesana

    Merge branch 'master' of https://github.com/tensorflow/models
    into detection_generator_pr_2

Parents: 1b425791, 37536370

Changes: 108 in total. Showing 8 changed files with 88 additions and 34 deletions (+88 −34).
official/vision/beta/tasks/maskrcnn.py                                          +4   −1
official/vision/keras_cv/layers/deeplab.py                                      +22  −3
research/object_detection/core/target_assigner.py                               +7   −2
research/object_detection/meta_architectures/center_net_meta_arch.py            +34  −12
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py   +9   −5
research/object_detection/packages/tf2/setup.py                                 +1   −5
research/object_detection/utils/ops.py                                          +6   −4
research/slim/nets/mobilenet/mobilenet_example.ipynb                            +5   −2
official/vision/beta/tasks/maskrcnn.py

@@ -261,12 +261,15 @@ class MaskRCNNTask(base_task.Task):
       metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
     else:
-      if self._task_config.annotation_file:
+      if (not self._task_config.model.include_mask
+         ) or self._task_config.annotation_file:
         self.coco_metric = coco_evaluator.COCOEvaluator(
             annotation_file=self._task_config.annotation_file,
             include_mask=self._task_config.model.include_mask,
             per_category_metrics=self._task_config.per_category_metrics)
       else:
+        # Builds COCO-style annotation file if include_mask is True, and
+        # annotation_file isn't provided.
         annotation_path = os.path.join(self._logging_dir, 'annotation.json')
         if tf.io.gfile.exists(annotation_path):
           logging.info(
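Net effect: the COCO evaluator is constructed directly whenever masks are disabled or an annotation file is supplied; only the masks-without-annotations case falls through to generating `annotation.json` under the task's logging directory. A condensed, hypothetical restatement of that branch (an illustration only, not the actual task code, which carries more plumbing):

import os

def evaluator_source(include_mask, annotation_file, logging_dir):
  # Mirrors the condition introduced in the hunk above.
  if (not include_mask) or annotation_file:
    return 'coco_evaluator', annotation_file
  # include_mask and no annotation file: build a COCO-style file instead.
  return 'build_annotation_file', os.path.join(logging_dir, 'annotation.json')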
official/vision/keras_cv/layers/deeplab.py

@@ -21,9 +21,11 @@ import tensorflow as tf
 class SpatialPyramidPooling(tf.keras.layers.Layer):
   """Implements the Atrous Spatial Pyramid Pooling.

-  Reference:
+  References:
     [Rethinking Atrous Convolution for Semantic Image Segmentation](
       https://arxiv.org/pdf/1706.05587.pdf)
+    [Encoder-Decoder with Atrous Separable Convolution for Semantic Image
+      Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
   """

   def __init__(
@@ -39,6 +41,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
       kernel_initializer='glorot_uniform',
       kernel_regularizer=None,
       interpolation='bilinear',
+      use_depthwise_convolution=False,
       **kwargs):
     """Initializes `SpatialPyramidPooling`.
@@ -60,6 +63,10 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
       kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
       interpolation: The interpolation method for upsampling. Defaults to
         `bilinear`.
+      use_depthwise_convolution: Allows spatial pooling to use separable
+        depthwise convolutions. [Encoder-Decoder with Atrous Separable
+        Convolution for Semantic Image Segmentation](
+        https://arxiv.org/pdf/1802.02611.pdf)
       **kwargs: Other keyword arguments for the layer.
     """
     super(SpatialPyramidPooling, self).__init__(**kwargs)
@@ -76,6 +83,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
     self.interpolation = interpolation
     self.input_spec = tf.keras.layers.InputSpec(ndim=4)
     self.pool_kernel_size = pool_kernel_size
+    self.use_depthwise_convolution = use_depthwise_convolution

   def build(self, input_shape):
     height = input_shape[1]
@@ -109,9 +117,20 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
     self.aspp_layers.append(conv_sequential)

     for dilation_rate in self.dilation_rates:
-      conv_sequential = tf.keras.Sequential([
+      leading_layers = []
+      kernel_size = (3, 3)
+      if self.use_depthwise_convolution:
+        leading_layers += [
+            tf.keras.layers.DepthwiseConv2D(
+                depth_multiplier=1, kernel_size=kernel_size, padding='same',
+                depthwise_regularizer=self.kernel_regularizer,
+                depthwise_initializer=self.kernel_initializer,
+                dilation_rate=dilation_rate, use_bias=False)
+        ]
+        kernel_size = (1, 1)
+      conv_sequential = tf.keras.Sequential(leading_layers + [
           tf.keras.layers.Conv2D(
-              filters=self.output_channels, kernel_size=(3, 3),
+              filters=self.output_channels, kernel_size=kernel_size,
               padding='same', kernel_regularizer=self.kernel_regularizer,
               kernel_initializer=self.kernel_initializer,
              dilation_rate=dilation_rate, use_bias=False),
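For orientation, a minimal usage sketch of the layer with the new flag. It assumes `output_channels` and `dilation_rates` are constructor arguments (the attribute assignments above suggest so) and that the models repo is importable; with `use_depthwise_convolution=True`, each dilated 3x3 conv becomes a 3x3 depthwise conv followed by a 1x1 pointwise conv:

import tensorflow as tf
from official.vision.keras_cv.layers import deeplab  # assumes models/ on PYTHONPATH

aspp = deeplab.SpatialPyramidPooling(
    output_channels=256,            # assumed constructor arg
    dilation_rates=[6, 12, 18],     # assumed constructor arg
    use_depthwise_convolution=True)
features = tf.random.normal([2, 64, 64, 2048])  # input_spec requires ndim=4
print(aspp(features).shape)  # expected: (2, 64, 64, 256)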
research/object_detection/core/target_assigner.py

@@ -961,7 +961,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
                                          width,
                                          gt_boxes_list,
                                          gt_classes_list,
-                                         gt_weights_list=None):
+                                         gt_weights_list=None,
+                                         maximum_normalized_coordinate=1.1):
     """Computes the object center heatmap target.

     Args:
@@ -977,6 +978,9 @@ class CenterNetCenterHeatmapTargetAssigner(object):
         in the gt_boxes_list.
       gt_weights_list: A list of float tensors with shape [num_boxes]
         representing the weight of each groundtruth detection box.
+      maximum_normalized_coordinate: Maximum coordinate value to be considered
+        as normalized, defaults to 1.1. This is used to check bounds when
+        converting normalized coordinates to absolute coordinates.

     Returns:
       heatmap: A Tensor of size [batch_size, output_height, output_width,
@@ -1002,7 +1006,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
       boxes = box_list_ops.to_absolute_coordinates(
           boxes,
           tf.maximum(height // self._stride, 1),
-          tf.maximum(width // self._stride, 1))
+          tf.maximum(width // self._stride, 1),
+          maximum_normalized_coordinate=maximum_normalized_coordinate)

       # Get the box center coordinates. Each returned tensor has the shape of
       # [num_instances]
       (y_center, x_center, boxes_height,
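The new keyword is forwarded to `box_list_ops.to_absolute_coordinates`, which treats coordinates as normalized only while they stay under the bound. A self-contained sketch of that kind of check (an illustration of the idea, not the library code; normalized boxes can drift slightly past 1.0 after augmentation, hence the 1.1 default):

import tensorflow as tf

def assert_normalized(boxes, maximum_normalized_coordinate=1.1):
  # boxes: [N, 4] in [ymin, xmin, ymax, xmax] normalized form.
  tf.debugging.assert_less_equal(
      tf.reduce_max(boxes), maximum_normalized_coordinate,
      message='Coordinates exceed the normalized bound.')
  return boxes

boxes = tf.constant([[0.0, 0.0, 1.02, 0.98]])  # slight overshoot, still OK
assert_normalized(boxes)  # passes; a max coordinate of 1.21 would raise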
research/object_detection/meta_architectures/center_net_meta_arch.py

@@ -2714,7 +2714,8 @@ class CenterNetMetaArch(model.DetectionModel):
     return target_assigners

   def _compute_object_center_loss(self, input_height, input_width,
-                                  object_center_predictions, per_pixel_weights):
+                                  object_center_predictions, per_pixel_weights,
+                                  maximum_normalized_coordinate=1.1):
     """Computes the object center loss.

     Args:
@@ -2726,6 +2727,9 @@ class CenterNetMetaArch(model.DetectionModel):
       per_pixel_weights: A float tensor of shape [batch_size,
         out_height * out_width, 1] with 1s in locations where the spatial
         coordinates fall within the height and width in true_image_shapes.
+      maximum_normalized_coordinate: Maximum coordinate value to be considered
+        as normalized, defaults to 1.1. This is used to check bounds when
+        converting normalized coordinates to absolute coordinates.

     Returns:
       A float scalar tensor representing the object center loss per instance.
@@ -2752,7 +2756,8 @@ class CenterNetMetaArch(model.DetectionModel):
           width=input_width,
           gt_classes_list=gt_classes_list,
           gt_keypoints_list=gt_keypoints_list,
-          gt_weights_list=gt_weights_list)
+          gt_weights_list=gt_weights_list,
+          maximum_normalized_coordinate=maximum_normalized_coordinate)
     else:
       gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
       heatmap_targets = assigner.assign_center_targets_from_boxes(
@@ -2760,7 +2765,8 @@ class CenterNetMetaArch(model.DetectionModel):
           width=input_width,
           gt_boxes_list=gt_boxes_list,
           gt_classes_list=gt_classes_list,
-          gt_weights_list=gt_weights_list)
+          gt_weights_list=gt_weights_list,
+          maximum_normalized_coordinate=maximum_normalized_coordinate)

     flattened_heatmap_targets = _flatten_spatial_dimensions(heatmap_targets)
     num_boxes = _to_float32(get_num_instances_from_weights(gt_weights_list))
@@ -3577,7 +3583,9 @@ class CenterNetMetaArch(model.DetectionModel):
       self._batched_prediction_tensor_names = predictions.keys()
     return predictions

-  def loss(self, prediction_dict, true_image_shapes, scope=None):
+  def loss(self, prediction_dict, true_image_shapes, scope=None,
+           maximum_normalized_coordinate=1.1):
     """Computes scalar loss tensors with respect to provided groundtruth.

     This function implements the various CenterNet losses.
@@ -3589,6 +3597,9 @@ class CenterNetMetaArch(model.DetectionModel):
         the form [height, width, channels] indicating the shapes of true images
         in the resized images, as resized images can be padded with zeros.
       scope: Optional scope name.
+      maximum_normalized_coordinate: Maximum coordinate value to be considered
+        as normalized, defaults to 1.1. This is used to check bounds when
+        converting normalized coordinates to absolute coordinates.

     Returns:
       A dictionary mapping the keys [
@@ -3616,7 +3627,7 @@ class CenterNetMetaArch(model.DetectionModel):
     # TODO(vighneshb) Explore whether using floor here is safe.
     output_true_image_shapes = tf.ceil(
-        tf.to_float(true_image_shapes) / self._stride)
+        tf.cast(true_image_shapes, tf.float32) / self._stride)

     valid_anchor_weights = get_valid_anchor_weights_in_flattened_image(
         output_true_image_shapes, output_height, output_width)

     valid_anchor_weights = tf.expand_dims(valid_anchor_weights, 2)
@@ -3625,7 +3636,8 @@ class CenterNetMetaArch(model.DetectionModel):
         object_center_predictions=prediction_dict[OBJECT_CENTER],
         input_height=input_height,
         input_width=input_width,
-        per_pixel_weights=valid_anchor_weights)
+        per_pixel_weights=valid_anchor_weights,
+        maximum_normalized_coordinate=maximum_normalized_coordinate)
     losses = {
         OBJECT_CENTER:
             self._center_params.object_center_loss_weight * object_center_loss
@@ -3742,10 +3754,20 @@ class CenterNetMetaArch(model.DetectionModel):
     """
     object_center_prob = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1])
-    # Mask object centers by true_image_shape. [batch, h, w, 1]
-    object_center_mask = mask_from_true_image_shape(
-        _get_shape(object_center_prob, 4), true_image_shapes)
-    object_center_prob *= object_center_mask
+    if true_image_shapes is None:
+      # If true_image_shapes is not provided, we assume the whole image is
+      # valid and infer the true_image_shapes from the object_center_prob
+      # shape.
+      batch_size, strided_height, strided_width, _ = _get_shape(
+          object_center_prob, 4)
+      true_image_shapes = tf.stack(
+          [strided_height * self._stride, strided_width * self._stride,
+           # pylint: disable=protected-access
+           tf.constant(len(self._feature_extractor._channel_means))])
+      true_image_shapes = tf.stack([true_image_shapes] * batch_size, axis=0)
+    else:
+      # Mask object centers by true_image_shape. [batch, h, w, 1]
+      object_center_mask = mask_from_true_image_shape(
+          _get_shape(object_center_prob, 4), true_image_shapes)
+      object_center_prob *= object_center_mask

     # Get x, y and channel indices corresponding to the top indices in the
     # class center predictions.
@@ -3755,8 +3777,8 @@ class CenterNetMetaArch(model.DetectionModel):
             k=self._center_params.max_box_predictions))
     multiclass_scores = tf.gather_nd(
         object_center_prob, tf.stack([y_indices, x_indices], -1), batch_dims=1)
     num_detections = tf.reduce_sum(
-        tf.to_int32(detection_scores > 0), axis=1)
+        tf.cast(detection_scores > 0, tf.int32), axis=1)
     postprocess_dict = {
         fields.DetectionResultFields.detection_scores: detection_scores,
         fields.DetectionResultFields.detection_multiclass_scores:
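Taken together, these hunks thread the bound from `loss()` through `_compute_object_center_loss()` into the target assigner, and let `postprocess()` run without true image shapes. A hedged usage sketch (`model`, `prediction_dict`, and `true_image_shapes` are assumed to be built the usual way):

# loss() forwards the bound down to assign_center_targets_from_boxes;
# 1.1 is the default shown in the diff.
losses = model.loss(prediction_dict, true_image_shapes,
                    maximum_normalized_coordinate=1.1)

# postprocess() now accepts None and treats the whole image as valid,
# inferring shapes from the center heatmap and the model stride.
detections = model.postprocess(prediction_dict, None)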
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py

@@ -2056,10 +2056,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
                        cnma.TEMPORAL_OFFSET)])

   @parameterized.parameters(
-      {'target_class_id': 1},
-      {'target_class_id': 2},
+      {'target_class_id': 1, 'with_true_image_shape': True},
+      {'target_class_id': 2, 'with_true_image_shape': True},
+      {'target_class_id': 1, 'with_true_image_shape': False},
   )
-  def test_postprocess(self, target_class_id):
+  def test_postprocess(self, target_class_id, with_true_image_shape):
     """Test the postprocess function."""
     model = build_center_net_meta_arch()
     max_detection = model._center_params.max_box_predictions
@@ -2140,8 +2141,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
     }

     def graph_fn():
-      detections = model.postprocess(prediction_dict,
-                                     tf.constant([[128, 128, 3]]))
+      if with_true_image_shape:
+        detections = model.postprocess(prediction_dict,
+                                       tf.constant([[128, 128, 3]]))
+      else:
+        detections = model.postprocess(prediction_dict, None)
       return detections

     detections = self.execute_cpu(graph_fn, [])
research/object_detection/packages/tf2/setup.py

@@ -21,11 +21,7 @@ REQUIRED_PACKAGES = [
     'lvis',
     'scipy',
     'pandas',
-    # tensorflow 2.5.0 requires grpcio~=1.34.0.
-    'tf-models-official>=2.5.1',
-    # tf-models-official (which requires google-could-bigquery) ends
-    # up installing the latest grpcio which causes problems later.
-    'google-cloud-bigquery==1.21.0',
+    'tf-models-official',
 ]

 setup(
research/object_detection/utils/ops.py

@@ -948,7 +948,8 @@ def merge_boxes_with_multiple_labels(boxes,
 def nearest_neighbor_upsampling(input_tensor, scale=None, height_scale=None,
-                                width_scale=None):
+                                width_scale=None,
+                                name='nearest_neighbor_upsampling'):
   """Nearest neighbor upsampling implementation.

   Nearest neighbor upsampling function that maps input tensor with shape
@@ -965,6 +966,7 @@ def nearest_neighbor_upsampling(input_tensor, scale=None, height_scale=None,
       option when provided overrides `scale` option.
     width_scale: An integer multiple to scale the width of input image. This
       option when provided overrides `scale` option.
+    name: A name for the operation (optional).

   Returns:
     data_up: A float32 tensor of size
       [batch, height_in*scale, width_in*scale, channels].
@@ -976,13 +978,13 @@ def nearest_neighbor_upsampling(input_tensor, scale=None, height_scale=None,
   if not scale and (height_scale is None or width_scale is None):
     raise ValueError('Provide either `scale` or `height_scale` and'
                      ' `width_scale`.')
-  with tf.name_scope('nearest_neighbor_upsampling'):
+  with tf.name_scope(name):
     h_scale = scale if height_scale is None else height_scale
     w_scale = scale if width_scale is None else width_scale
     (batch_size, height, width,
      channels) = shape_utils.combined_static_and_dynamic_shape(input_tensor)
-    output_tensor = tf.stack([input_tensor] * w_scale, axis=3)
-    output_tensor = tf.stack([output_tensor] * h_scale, axis=2)
+    output_tensor = tf.stack([input_tensor] * w_scale, axis=3, name='w_stack')
+    output_tensor = tf.stack([output_tensor] * h_scale, axis=2, name='h_stack')
     return tf.reshape(output_tensor,
                       [batch_size, height * h_scale, width * w_scale, channels])
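The behavioral surface added here is just the `name` argument; the upsampling itself is the classic stack-and-reshape trick. A self-contained sketch of that trick with the new naming, runnable on plain TensorFlow 2 (an illustration with a single uniform scale; the library function also supports separate height/width scales):

import tensorflow as tf

def nn_upsample(x, scale, name='nearest_neighbor_upsampling'):
  with tf.name_scope(name):
    b, h, w, c = x.shape
    out = tf.stack([x] * scale, axis=3, name='w_stack')    # [b, h, w, scale, c]
    out = tf.stack([out] * scale, axis=2, name='h_stack')  # [b, h, scale, w, scale, c]
    return tf.reshape(out, [b, h * scale, w * scale, c])

x = tf.reshape(tf.range(4, dtype=tf.float32), [1, 2, 2, 1])
print(nn_upsample(x, 2).shape)  # (1, 4, 4, 1): each pixel becomes a 2x2 block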
research/slim/nets/mobilenet/mobilenet_example.ipynb

@@ -197,9 +197,10 @@
    },
    "outputs": [],
    "source": [
-    "# setup path\n",
+    "# setup path and install tf-slim\n",
     "import sys\n",
     "sys.path.append('/content/models/research/slim')",
+    "!pip install tf_slim",
    ]
   },
   {
@@ -228,8 +229,10 @@
    "outputs": [],
    "source": [
     "import tensorflow.compat.v1 as tf\n",
+    "import tf_slim as slim\n",
     "from nets.mobilenet import mobilenet_v2\n",
     "\n",
+    "tf.compat.v1.disable_eager_execution()\n",
     "tf.reset_default_graph()\n",
     "\n",
     "# For simplicity we just decode jpeg inside tensorflow.\n",
@@ -244,7 +247,7 @@
     "images = tf.image.resize_images(images, (224, 224))\n",
     "\n",
     "# Note: arg_scope is optional for inference.\n",
-    "with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):\n",
+    "with slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):\n",
     "  logits, endpoints = mobilenet_v2.mobilenet(images)\n",
     "  \n",
     "# Restore using exponential moving average since it produces (1.5-2%) higher \n",