Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cc30e3e9
Commit
cc30e3e9
authored
Aug 03, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 388566700
parent
7797ebad
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
31 additions
and
4 deletions
+31
-4
official/vision/beta/configs/decoders.py
official/vision/beta/configs/decoders.py
+1
-0
official/vision/beta/modeling/decoders/aspp.py
official/vision/beta/modeling/decoders/aspp.py
+7
-1
official/vision/beta/modeling/decoders/aspp_test.py
official/vision/beta/modeling/decoders/aspp_test.py
+1
-0
official/vision/keras_cv/layers/deeplab.py
official/vision/keras_cv/layers/deeplab.py
+22
-3
No files found.
official/vision/beta/configs/decoders.py
View file @
cc30e3e9
...
...
@@ -50,6 +50,7 @@ class ASPP(hyperparams.Config):
dilation_rates
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
dropout_rate
:
float
=
0.0
num_filters
:
int
=
256
use_depthwise_convolution
:
bool
=
False
pool_kernel_size
:
Optional
[
List
[
int
]]
=
None
# Use global average pooling.
...
...
official/vision/beta/modeling/decoders/aspp.py
View file @
cc30e3e9
...
...
@@ -42,6 +42,7 @@ class ASPP(tf.keras.layers.Layer):
kernel_initializer
:
str
=
'VarianceScaling'
,
kernel_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
,
interpolation
:
str
=
'bilinear'
,
use_depthwise_convolution
:
bool
=
False
,
**
kwargs
):
"""Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.
...
...
@@ -64,6 +65,8 @@ class ASPP(tf.keras.layers.Layer):
interpolation: A `str` of interpolation method. It should be one of
`bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
`gaussian`, or `mitchellcubic`.
use_depthwise_convolution: If True depthwise separable convolutions will
be added to the Atrous spatial pyramid pooling.
**kwargs: Additional keyword arguments to be passed.
"""
super
(
ASPP
,
self
).
__init__
(
**
kwargs
)
...
...
@@ -80,6 +83,7 @@ class ASPP(tf.keras.layers.Layer):
'kernel_initializer'
:
kernel_initializer
,
'kernel_regularizer'
:
kernel_regularizer
,
'interpolation'
:
interpolation
,
'use_depthwise_convolution'
:
use_depthwise_convolution
,
}
def
build
(
self
,
input_shape
):
...
...
@@ -100,7 +104,9 @@ class ASPP(tf.keras.layers.Layer):
dropout
=
self
.
_config_dict
[
'dropout_rate'
],
kernel_initializer
=
self
.
_config_dict
[
'kernel_initializer'
],
kernel_regularizer
=
self
.
_config_dict
[
'kernel_regularizer'
],
interpolation
=
self
.
_config_dict
[
'interpolation'
])
interpolation
=
self
.
_config_dict
[
'interpolation'
],
use_depthwise_convolution
=
self
.
_config_dict
[
'use_depthwise_convolution'
]
)
def
call
(
self
,
inputs
:
Mapping
[
str
,
tf
.
Tensor
])
->
Mapping
[
str
,
tf
.
Tensor
]:
"""Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.
...
...
official/vision/beta/modeling/decoders/aspp_test.py
View file @
cc30e3e9
...
...
@@ -70,6 +70,7 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
kernel_regularizer
=
None
,
interpolation
=
'bilinear'
,
dropout_rate
=
0.2
,
use_depthwise_convolution
=
'false'
,
)
network
=
aspp
.
ASPP
(
**
kwargs
)
...
...
official/vision/keras_cv/layers/deeplab.py
View file @
cc30e3e9
...
...
@@ -21,9 +21,11 @@ import tensorflow as tf
class
SpatialPyramidPooling
(
tf
.
keras
.
layers
.
Layer
):
"""Implements the Atrous Spatial Pyramid Pooling.
Reference:
Reference
s
:
[Rethinking Atrous Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1706.05587.pdf)
[Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
"""
def
__init__
(
...
...
@@ -39,6 +41,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
kernel_initializer
=
'glorot_uniform'
,
kernel_regularizer
=
None
,
interpolation
=
'bilinear'
,
use_depthwise_convolution
=
False
,
**
kwargs
):
"""Initializes `SpatialPyramidPooling`.
...
...
@@ -60,6 +63,10 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
interpolation: The interpolation method for upsampling. Defaults to
`bilinear`.
use_depthwise_convolution: Allows spatial pooling to be separable
depthwise convolusions. [Encoder-Decoder with Atrous Separable
Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1802.02611.pdf)
**kwargs: Other keyword arguments for the layer.
"""
super
(
SpatialPyramidPooling
,
self
).
__init__
(
**
kwargs
)
...
...
@@ -76,6 +83,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self
.
interpolation
=
interpolation
self
.
input_spec
=
tf
.
keras
.
layers
.
InputSpec
(
ndim
=
4
)
self
.
pool_kernel_size
=
pool_kernel_size
self
.
use_depthwise_convolution
=
use_depthwise_convolution
def
build
(
self
,
input_shape
):
height
=
input_shape
[
1
]
...
...
@@ -109,9 +117,20 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self
.
aspp_layers
.
append
(
conv_sequential
)
for
dilation_rate
in
self
.
dilation_rates
:
conv_sequential
=
tf
.
keras
.
Sequential
([
leading_layers
=
[]
kernel_size
=
(
3
,
3
)
if
self
.
use_depthwise_convolution
:
leading_layers
+=
[
tf
.
keras
.
layers
.
DepthwiseConv2D
(
depth_multiplier
=
1
,
kernel_size
=
kernel_size
,
padding
=
'same'
,
depthwise_regularizer
=
self
.
kernel_regularizer
,
depthwise_initializer
=
self
.
kernel_initializer
,
dilation_rate
=
dilation_rate
,
use_bias
=
False
)
]
kernel_size
=
(
1
,
1
)
conv_sequential
=
tf
.
keras
.
Sequential
(
leading_layers
+
[
tf
.
keras
.
layers
.
Conv2D
(
filters
=
self
.
output_channels
,
kernel_size
=
(
3
,
3
)
,
filters
=
self
.
output_channels
,
kernel_size
=
kernel_size
,
padding
=
'same'
,
kernel_regularizer
=
self
.
kernel_regularizer
,
kernel_initializer
=
self
.
kernel_initializer
,
dilation_rate
=
dilation_rate
,
use_bias
=
False
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment