Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e1df7597
Commit
e1df7597
authored
Jun 11, 2021
by
Dan Kondratyuk
Committed by
A. Unique TensorFlower
Jun 11, 2021
Browse files
Apply stream buffer after the spatial convolution in (2+1)D mode.
PiperOrigin-RevId: 378923791
parent
e77956d6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
10 deletions
+54
-10
official/vision/beta/projects/movinet/modeling/movinet.py
official/vision/beta/projects/movinet/modeling/movinet.py
+11
-2
official/vision/beta/projects/movinet/modeling/movinet_layers.py
...l/vision/beta/projects/movinet/modeling/movinet_layers.py
+21
-2
official/vision/beta/projects/movinet/modeling/movinet_model.py
...al/vision/beta/projects/movinet/modeling/movinet_model.py
+22
-6
No files found.
official/vision/beta/projects/movinet/modeling/movinet.py
View file @
e1df7597
...
@@ -525,7 +525,6 @@ class Movinet(tf.keras.Model):
...
@@ -525,7 +525,6 @@ class Movinet(tf.keras.Model):
Returns:
Returns:
A dict mapping state names to state shapes.
A dict mapping state names to state shapes.
"""
"""
def
divide_resolution
(
shape
,
num_downsamples
):
def
divide_resolution
(
shape
,
num_downsamples
):
"""Downsamples the dimension to calculate strided convolution shape."""
"""Downsamples the dimension to calculate strided convolution shape."""
if
shape
is
None
:
if
shape
is
None
:
...
@@ -564,6 +563,12 @@ class Movinet(tf.keras.Model):
...
@@ -564,6 +563,12 @@ class Movinet(tf.keras.Model):
for
layer_idx
,
layer
in
enumerate
(
params
):
for
layer_idx
,
layer
in
enumerate
(
params
):
expand_filters
,
kernel_size
,
strides
=
layer
expand_filters
,
kernel_size
,
strides
=
layer
# If we use a 2D kernel, we apply spatial downsampling
# before the buffer.
if
(
tuple
(
strides
[
1
:
3
])
!=
(
1
,
1
)
and
self
.
_conv_type
in
[
'2plus1d'
,
'3d_2plus1d'
]):
num_downsamples
+=
1
if
kernel_size
[
0
]
>
1
:
if
kernel_size
[
0
]
>
1
:
states
[
f
'state/b
{
block_idx
}
/l
{
layer_idx
}
/stream_buffer'
]
=
(
states
[
f
'state/b
{
block_idx
}
/l
{
layer_idx
}
/stream_buffer'
]
=
(
input_shape
[
0
],
input_shape
[
0
],
...
@@ -585,7 +590,11 @@ class Movinet(tf.keras.Model):
...
@@ -585,7 +590,11 @@ class Movinet(tf.keras.Model):
if
strides
[
1
]
!=
strides
[
2
]:
if
strides
[
1
]
!=
strides
[
2
]:
raise
ValueError
(
'Strides must match in the spatial dimensions, '
raise
ValueError
(
'Strides must match in the spatial dimensions, '
'got {}'
.
format
(
strides
))
'got {}'
.
format
(
strides
))
if
strides
[
1
]
!=
1
or
strides
[
2
]
!=
1
:
# If we use a 3D kernel, we apply spatial downsampling
# after the buffer.
if
(
tuple
(
strides
[
1
:
3
])
!=
(
1
,
1
)
and
self
.
_conv_type
not
in
[
'2plus1d'
,
'3d_2plus1d'
]):
num_downsamples
+=
1
num_downsamples
+=
1
elif
isinstance
(
block
,
HeadSpec
):
elif
isinstance
(
block
,
HeadSpec
):
states
[
'state/head/pool_buffer'
]
=
(
states
[
'state/head/pool_buffer'
]
=
(
...
...
official/vision/beta/projects/movinet/modeling/movinet_layers.py
View file @
e1df7597
...
@@ -633,9 +633,28 @@ class StreamConvBlock(ConvBlock):
...
@@ -633,9 +633,28 @@ class StreamConvBlock(ConvBlock):
states
=
dict
(
states
)
if
states
is
not
None
else
{}
states
=
dict
(
states
)
if
states
is
not
None
else
{}
x
=
inputs
x
=
inputs
# If we have no separate temporal conv, use the buffer before the 3D conv.
if
self
.
_conv_temporal
is
None
and
self
.
_stream_buffer
is
not
None
:
x
,
states
=
self
.
_stream_buffer
(
x
,
states
=
states
)
x
=
self
.
_conv
(
x
)
if
self
.
_batch_norm
is
not
None
:
x
=
self
.
_batch_norm
(
x
)
if
self
.
_activation_layer
is
not
None
:
x
=
self
.
_activation_layer
(
x
)
if
self
.
_conv_temporal
is
not
None
:
if
self
.
_stream_buffer
is
not
None
:
if
self
.
_stream_buffer
is
not
None
:
# If we have a separate temporal conv, use the buffer before the
# 1D conv instead (otherwise, we may waste computation on the 2D conv).
x
,
states
=
self
.
_stream_buffer
(
x
,
states
=
states
)
x
,
states
=
self
.
_stream_buffer
(
x
,
states
=
states
)
x
=
super
(
StreamConvBlock
,
self
).
call
(
x
)
x
=
self
.
_conv_temporal
(
x
)
if
self
.
_batch_norm_temporal
is
not
None
:
x
=
self
.
_batch_norm_temporal
(
x
)
if
self
.
_activation_layer
is
not
None
:
x
=
self
.
_activation_layer
(
x
)
return
x
,
states
return
x
,
states
...
...
official/vision/beta/projects/movinet/modeling/movinet_model.py
View file @
e1df7597
...
@@ -115,15 +115,31 @@ class MovinetClassifier(tf.keras.Model):
...
@@ -115,15 +115,31 @@ class MovinetClassifier(tf.keras.Model):
inputs
=
{
**
states
,
'image'
:
image
}
inputs
=
{
**
states
,
'image'
:
image
}
if
backbone
.
use_external_states
:
if
backbone
.
use_external_states
:
before_states
=
set
(
states
)
before_states
=
states
endpoints
,
states
=
backbone
(
inputs
)
endpoints
,
states
=
backbone
(
inputs
)
after_states
=
set
(
states
)
after_states
=
states
new_states
=
after_states
-
before_states
new_states
=
set
(
after_states
)
-
set
(
before_states
)
if
new_states
:
if
new_states
:
raise
AttributeError
(
'Expected input and output states to be the same. '
raise
ValueError
(
'Got extra states {}, expected {}'
.
format
(
'Expected input and output states to be the same. Got extra states '
new_states
,
before_states
))
'{}, expected {}'
.
format
(
new_states
,
set
(
before_states
)))
mismatched_shapes
=
{}
for
name
in
after_states
:
before_shape
=
before_states
[
name
].
shape
after_shape
=
after_states
[
name
].
shape
if
len
(
before_shape
)
!=
len
(
after_shape
):
mismatched_shapes
[
name
]
=
(
before_shape
,
after_shape
)
continue
for
before
,
after
in
zip
(
before_shape
,
after_shape
):
if
before
is
not
None
and
after
is
not
None
and
before
!=
after
:
mismatched_shapes
[
name
]
=
(
before_shape
,
after_shape
)
break
if
mismatched_shapes
:
raise
ValueError
(
'Got mismatched input and output state shapes: {}'
.
format
(
mismatched_shapes
))
else
:
else
:
endpoints
,
states
=
backbone
(
inputs
)
endpoints
,
states
=
backbone
(
inputs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment