Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
66e8a904
Commit
66e8a904
authored
Jun 30, 2020
by
Kaushik Shivakumar
Browse files
make fixes
parent
3475ebda
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
5 deletions
+6
-5
research/object_detection/meta_architectures/context_rcnn_lib.py
...h/object_detection/meta_architectures/context_rcnn_lib.py
+6
-5
No files found.
research/object_detection/meta_architectures/context_rcnn_lib.py
View file @
66e8a904
...
@@ -61,10 +61,11 @@ class AttentionBlock(tf.keras.layers.Layer):
...
@@ -61,10 +61,11 @@ class AttentionBlock(tf.keras.layers.Layer):
def
build
(
self
,
input_shapes
):
def
build
(
self
,
input_shapes
):
self
.
feature_proj
=
ContextProjection
(
input_shapes
[
0
][
-
1
],
self
.
freeze_batchnorm
)
self
.
feature_proj
=
ContextProjection
(
input_shapes
[
0
][
-
1
],
self
.
freeze_batchnorm
)
self
.
key_proj
.
build
(
input_shapes
[
0
])
#self.key_proj.build(input_shapes[0])
self
.
val_proj
.
build
(
input_shapes
[
0
])
#self.val_proj.build(input_shapes[0])
self
.
query_proj
.
build
(
input_shapes
[
0
])
#self.query_proj.build(input_shapes[0])
self
.
feature_proj
.
build
(
input_shapes
[
0
])
#self.feature_proj.build(input_shapes[0])
pass
def
filter_weight_value
(
self
,
weights
,
values
,
valid_mask
):
def
filter_weight_value
(
self
,
weights
,
values
,
valid_mask
):
"""Filters weights and values based on valid_mask.
"""Filters weights and values based on valid_mask.
...
@@ -140,7 +141,7 @@ class AttentionBlock(tf.keras.layers.Layer):
...
@@ -140,7 +141,7 @@ class AttentionBlock(tf.keras.layers.Layer):
projected_features
=
layer
(
features
,
is_training
)
projected_features
=
layer
(
features
,
is_training
)
projected_features
=
tf
.
reshape
(
(
batch_size
,
-
1
,
bottleneck_dimension
))(
projected_features
)
projected_features
=
tf
.
reshape
(
projected_features
,
[
batch_size
,
-
1
,
bottleneck_dimension
]
)
print
(
projected_features
.
shape
)
print
(
projected_features
.
shape
)
if
normalize
:
if
normalize
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment