Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
12110e64
Commit
12110e64
authored
Jul 23, 2020
by
Kaushik Shivakumar
Browse files
finalize context rcnn tf2
parent
b44de920
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
5 deletions
+4
-5
research/object_detection/meta_architectures/context_rcnn_lib_tf2.py
...ject_detection/meta_architectures/context_rcnn_lib_tf2.py
+3
-4
research/object_detection/meta_architectures/context_rcnn_lib_tf2_test.py
...detection/meta_architectures/context_rcnn_lib_tf2_test.py
+1
-1
No files found.
research/object_detection/meta_architectures/context_rcnn_lib_tf2.py
View file @
12110e64
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Library functions for Context
R
CNN."""
"""Library functions for Context
R-
CNN."""
import
tensorflow
as
tf
import
tensorflow
as
tf
from
object_detection.core
import
freezable_batch_norm
from
object_detection.core
import
freezable_batch_norm
...
@@ -61,12 +61,11 @@ class AttentionBlock(tf.keras.layers.Layer):
...
@@ -61,12 +61,11 @@ class AttentionBlock(tf.keras.layers.Layer):
def
build
(
self
,
input_shapes
):
def
build
(
self
,
input_shapes
):
if
not
self
.
_feature_proj
:
if
not
self
.
_feature_proj
:
self
.
_output_dimension
=
input_shapes
[
0
][
-
1
]
self
.
_output_dimension
=
input_shapes
[
-
1
]
self
.
_feature_proj
=
ContextProjection
(
self
.
_output_dimension
)
self
.
_feature_proj
=
ContextProjection
(
self
.
_output_dimension
)
def
call
(
self
,
box_
and_
context_features
,
valid_context_size
):
def
call
(
self
,
box_
features
,
context_features
,
valid_context_size
):
"""Handles a call by performing attention."""
"""Handles a call by performing attention."""
box_features
,
context_features
=
box_and_context_features
_
,
context_size
,
_
=
context_features
.
shape
_
,
context_size
,
_
=
context_features
.
shape
valid_mask
=
compute_valid_mask
(
valid_context_size
,
context_size
)
valid_mask
=
compute_valid_mask
(
valid_context_size
,
context_size
)
...
...
research/object_detection/meta_architectures/context_rcnn_lib_tf2_test.py
View file @
12110e64
...
@@ -106,7 +106,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
...
@@ -106,7 +106,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
minval
=
0
,
minval
=
0
,
maxval
=
10
,
maxval
=
10
,
dtype
=
tf
.
int32
)
dtype
=
tf
.
int32
)
output_features
=
attention_block
(
[
input_features
,
context_features
]
,
valid_context_size
)
output_features
=
attention_block
(
input_features
,
context_features
,
valid_context_size
)
# Makes sure the shape is correct.
# Makes sure the shape is correct.
self
.
assertAllEqual
(
output_features
.
shape
,
[
2
,
8
,
1
,
1
,
(
output_dimension
or
3
)])
self
.
assertAllEqual
(
output_features
.
shape
,
[
2
,
8
,
1
,
1
,
(
output_dimension
or
3
)])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment