Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
aabe1231
Commit
aabe1231
authored
Jul 22, 2020
by
Kaushik Shivakumar
Browse files
Context RCNN support for Tensorflow 2.x
parent
e3f88e11
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
17 deletions
+10
-17
research/object_detection/meta_architectures/context_rcnn_lib_tf2.py
...ject_detection/meta_architectures/context_rcnn_lib_tf2.py
+8
-12
research/object_detection/meta_architectures/context_rcnn_lib_tf2_test.py
...detection/meta_architectures/context_rcnn_lib_tf2_test.py
+2
-4
research/object_detection/meta_architectures/context_rcnn_meta_arch.py
...ct_detection/meta_architectures/context_rcnn_meta_arch.py
+0
-1
No files found.
research/object_detection/meta_architectures/context_rcnn_lib_tf2.py
View file @
aabe1231
...
...
@@ -23,13 +23,13 @@ _NEGATIVE_PADDING_VALUE = -100000
class
ContextProjection
(
tf
.
keras
.
layers
.
Layer
):
"""Custom layer to do batch normalization and projection."""
def
__init__
(
self
,
projection_dimension
,
freeze_batchnorm
,
**
kwargs
):
def
__init__
(
self
,
projection_dimension
,
**
kwargs
):
self
.
batch_norm
=
freezable_batch_norm
.
FreezableBatchNorm
(
epsilon
=
0.001
,
center
=
True
,
scale
=
True
,
momentum
=
0.97
,
trainable
=
(
not
freeze_batchnorm
)
)
trainable
=
True
)
self
.
projection
=
tf
.
keras
.
layers
.
Dense
(
units
=
projection_dimension
,
activation
=
tf
.
nn
.
relu6
,
use_bias
=
True
)
...
...
@@ -45,33 +45,29 @@ class ContextProjection(tf.keras.layers.Layer):
class
AttentionBlock
(
tf
.
keras
.
layers
.
Layer
):
"""Custom layer to perform all attention."""
def
__init__
(
self
,
bottleneck_dimension
,
attention_temperature
,
freeze_batchnorm
,
output_dimension
=
None
,
output_dimension
=
None
,
is_training
=
False
,
name
=
'AttentionBlock'
,
**
kwargs
):
self
.
_key_proj
=
ContextProjection
(
bottleneck_dimension
,
freeze_batchnorm
)
self
.
_val_proj
=
ContextProjection
(
bottleneck_dimension
,
freeze_batchnorm
)
self
.
_query_proj
=
ContextProjection
(
bottleneck_dimension
,
freeze_batchnorm
)
self
.
_key_proj
=
ContextProjection
(
bottleneck_dimension
)
self
.
_val_proj
=
ContextProjection
(
bottleneck_dimension
)
self
.
_query_proj
=
ContextProjection
(
bottleneck_dimension
)
self
.
_feature_proj
=
None
self
.
_attention_temperature
=
attention_temperature
self
.
_freeze_batchnorm
=
freeze_batchnorm
self
.
_bottleneck_dimension
=
bottleneck_dimension
self
.
_is_training
=
is_training
self
.
_output_dimension
=
output_dimension
if
self
.
_output_dimension
:
self
.
_feature_proj
=
ContextProjection
(
self
.
_output_dimension
,
self
.
_freeze_batchnorm
)
self
.
_feature_proj
=
ContextProjection
(
self
.
_output_dimension
)
super
(
AttentionBlock
,
self
).
__init__
(
name
=
name
,
**
kwargs
)
def
build
(
self
,
input_shapes
):
if
not
self
.
_feature_proj
:
self
.
_output_dimension
=
input_shapes
[
-
1
]
self
.
_feature_proj
=
ContextProjection
(
self
.
_output_dimension
,
self
.
_freeze_batchnorm
)
self
.
_feature_proj
=
ContextProjection
(
self
.
_output_dimension
)
def
call
(
self
,
box_features
,
context_features
,
valid_context_size
):
"""Handles a call by performing attention."""
_
,
context_size
,
_
=
context_features
.
shape
valid_mask
=
compute_valid_mask
(
valid_context_size
,
context_size
)
channels
=
box_features
.
shape
[
-
1
]
# Average pools over height and width dimension so that the shape of
# box_features becomes [batch_size, max_num_proposals, channels].
...
...
research/object_detection/meta_architectures/context_rcnn_lib_tf2_test.py
View file @
aabe1231
...
...
@@ -80,7 +80,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
features
,
projection_dimension
,
is_training
,
context_rcnn_lib
.
ContextProjection
(
projection_dimension
,
False
),
context_rcnn_lib
.
ContextProjection
(
projection_dimension
),
normalize
=
normalize
)
# Makes sure the shape is correct.
...
...
@@ -97,10 +97,8 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
attention_temperature
):
input_features
=
tf
.
ones
([
2
,
8
,
3
,
3
,
3
],
tf
.
float32
)
context_features
=
tf
.
ones
([
2
,
20
,
10
],
tf
.
float32
)
is_training
=
False
attention_block
=
context_rcnn_lib
.
AttentionBlock
(
bottleneck_dimension
,
attention_temperature
,
freeze_batchnorm
=
False
,
attention_temperature
,
output_dimension
=
output_dimension
,
is_training
=
False
)
valid_context_size
=
tf
.
random_uniform
((
2
,),
...
...
research/object_detection/meta_architectures/context_rcnn_meta_arch.py
View file @
aabe1231
...
...
@@ -275,7 +275,6 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
self
.
_context_feature_extract_fn
=
context_rcnn_lib_tf2
.
AttentionBlock
(
bottleneck_dimension
=
attention_bottleneck_dimension
,
attention_temperature
=
attention_temperature
,
freeze_batchnorm
=
freeze_batchnorm
,
is_training
=
is_training
)
@
staticmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment