Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b44de920
Commit
b44de920
authored
Jul 23, 2020
by
Kaushik Shivakumar
Browse files
finalize tf2 context rcnn
parent
f72dfa8d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
6 deletions
+8
-6
research/object_detection/meta_architectures/context_rcnn_lib_tf2.py
...ject_detection/meta_architectures/context_rcnn_lib_tf2.py
+4
-2
research/object_detection/meta_architectures/context_rcnn_lib_tf2_test.py
...detection/meta_architectures/context_rcnn_lib_tf2_test.py
+4
-4
No files found.
research/object_detection/meta_architectures/context_rcnn_lib_tf2.py
View file @
b44de920
...
...
@@ -61,11 +61,13 @@ class AttentionBlock(tf.keras.layers.Layer):
def
build
(
self
,
input_shapes
):
if
not
self
.
_feature_proj
:
self
.
_output_dimension
=
input_shapes
[
-
1
]
self
.
_output_dimension
=
input_shapes
[
0
][
-
1
]
self
.
_feature_proj
=
ContextProjection
(
self
.
_output_dimension
)
def
call
(
self
,
box_
features
,
context_features
,
valid_context_size
):
def
call
(
self
,
box_
and_
context_features
,
valid_context_size
):
"""Handles a call by performing attention."""
box_features
,
context_features
=
box_and_context_features
_
,
context_size
,
_
=
context_features
.
shape
valid_mask
=
compute_valid_mask
(
valid_context_size
,
context_size
)
...
...
research/object_detection/meta_architectures/context_rcnn_lib_tf2_test.py
View file @
b44de920
...
...
@@ -90,9 +90,9 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
@
parameterized
.
parameters
(
(
2
,
10
,
1
),
(
3
,
10
,
2
),
(
4
,
20
,
3
),
(
4
,
None
,
3
),
(
5
,
20
,
4
),
(
7
,
20
,
5
),
(
7
,
None
,
5
),
)
def
test_attention_block
(
self
,
bottleneck_dimension
,
output_dimension
,
attention_temperature
):
...
...
@@ -106,10 +106,10 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
minval
=
0
,
maxval
=
10
,
dtype
=
tf
.
int32
)
output_features
=
attention_block
(
input_features
,
context_features
,
valid_context_size
)
output_features
=
attention_block
(
[
input_features
,
context_features
]
,
valid_context_size
)
# Makes sure the shape is correct.
self
.
assertAllEqual
(
output_features
.
shape
,
[
2
,
8
,
1
,
1
,
output_dimension
])
self
.
assertAllEqual
(
output_features
.
shape
,
[
2
,
8
,
1
,
1
,
(
output_dimension
or
3
)
])
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment