Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
81ad46bf
Commit
81ad46bf
authored
Aug 17, 2021
by
Dan Kondratyuk
Committed by
A. Unique TensorFlower
Aug 17, 2021
Browse files
Fix MoViNet TF Lite state init by replacing '/' with '_' in state names.
PiperOrigin-RevId: 391332984
parent
cb62fdcc
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
46 additions
and
45 deletions
+46
-45
official/vision/beta/modeling/layers/nn_layers.py
official/vision/beta/modeling/layers/nn_layers.py
+6
-6
official/vision/beta/projects/movinet/README.md
official/vision/beta/projects/movinet/README.md
+2
-2
official/vision/beta/projects/movinet/export_saved_model_test.py
...l/vision/beta/projects/movinet/export_saved_model_test.py
+2
-4
official/vision/beta/projects/movinet/modeling/movinet.py
official/vision/beta/projects/movinet/modeling/movinet.py
+24
-21
official/vision/beta/projects/movinet/modeling/movinet_layers.py
...l/vision/beta/projects/movinet/modeling/movinet_layers.py
+2
-2
official/vision/beta/projects/movinet/modeling/movinet_test.py
...ial/vision/beta/projects/movinet/modeling/movinet_test.py
+10
-10
No files found.
official/vision/beta/modeling/layers/nn_layers.py
View file @
81ad46bf
...
@@ -425,7 +425,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
...
@@ -425,7 +425,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
self
.
_rezero
=
Scale
(
initializer
=
initializer
,
name
=
'rezero'
)
self
.
_rezero
=
Scale
(
initializer
=
initializer
,
name
=
'rezero'
)
state_prefix
=
state_prefix
if
state_prefix
is
not
None
else
''
state_prefix
=
state_prefix
if
state_prefix
is
not
None
else
''
self
.
_state_prefix
=
state_prefix
self
.
_state_prefix
=
state_prefix
self
.
_frame_count_name
=
f
'
{
state_prefix
}
/
pos_enc_frame_count'
self
.
_frame_count_name
=
f
'
{
state_prefix
}
_
pos_enc_frame_count'
def
get_config
(
self
):
def
get_config
(
self
):
"""Returns a dictionary containing the config used for initialization."""
"""Returns a dictionary containing the config used for initialization."""
...
@@ -523,7 +523,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
...
@@ -523,7 +523,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
inputs: An input `tf.Tensor`.
inputs: An input `tf.Tensor`.
states: A `dict` of states such that, if any of the keys match for this
states: A `dict` of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s). Expected keys
layer, will overwrite the contents of the buffer(s). Expected keys
include `state_prefix + '
/
pos_enc_frame_count'`.
include `state_prefix + '
_
pos_enc_frame_count'`.
output_states: A `bool`. If True, returns the output tensor and output
output_states: A `bool`. If True, returns the output tensor and output
states. Returns just the output tensor otherwise.
states. Returns just the output tensor otherwise.
...
@@ -587,8 +587,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
...
@@ -587,8 +587,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
state_prefix
=
state_prefix
if
state_prefix
is
not
None
else
''
state_prefix
=
state_prefix
if
state_prefix
is
not
None
else
''
self
.
_state_prefix
=
state_prefix
self
.
_state_prefix
=
state_prefix
self
.
_state_name
=
f
'
{
state_prefix
}
/
pool_buffer'
self
.
_state_name
=
f
'
{
state_prefix
}
_
pool_buffer'
self
.
_frame_count_name
=
f
'
{
state_prefix
}
/
pool_frame_count'
self
.
_frame_count_name
=
f
'
{
state_prefix
}
_
pool_frame_count'
def
get_config
(
self
):
def
get_config
(
self
):
"""Returns a dictionary containing the config used for initialization."""
"""Returns a dictionary containing the config used for initialization."""
...
@@ -611,8 +611,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
...
@@ -611,8 +611,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
inputs: An input `tf.Tensor`.
inputs: An input `tf.Tensor`.
states: A `dict` of states such that, if any of the keys match for this
states: A `dict` of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s).
layer, will overwrite the contents of the buffer(s).
Expected keys include `state_prefix + '
/
pool_buffer'` and
Expected keys include `state_prefix + '
__
pool_buffer'` and
`state_prefix + '
/
pool_frame_count'`.
`state_prefix + '
__
pool_frame_count'`.
output_states: A `bool`. If True, returns the output tensor and output
output_states: A `bool`. If True, returns the output tensor and output
states. Returns just the output tensor otherwise.
states. Returns just the output tensor otherwise.
...
...
official/vision/beta/projects/movinet/README.md
View file @
81ad46bf
...
@@ -338,7 +338,7 @@ with the Python API:
...
@@ -338,7 +338,7 @@ with the Python API:
```
python
```
python
# Create the interpreter and signature runner
# Create the interpreter and signature runner
interpreter
=
tf
.
lite
.
Interpreter
(
'/tmp/movinet_a0_stream.tflite'
)
interpreter
=
tf
.
lite
.
Interpreter
(
'/tmp/movinet_a0_stream.tflite'
)
signature
=
interpreter
.
get_signature_runner
()
runner
=
interpreter
.
get_signature_runner
()
# Extract state names and create the initial (zero) states
# Extract state names and create the initial (zero) states
def
state_name
(
name
:
str
)
->
str
:
def
state_name
(
name
:
str
)
->
str
:
...
@@ -358,7 +358,7 @@ clips = tf.split(video, video.shape[1], axis=1)
...
@@ -358,7 +358,7 @@ clips = tf.split(video, video.shape[1], axis=1)
states
=
init_states
states
=
init_states
for
clip
in
clips
:
for
clip
in
clips
:
# Input shape: [1, 1, 172, 172, 3]
# Input shape: [1, 1, 172, 172, 3]
outputs
=
signature
(
**
states
,
image
=
clip
)
outputs
=
runner
(
**
states
,
image
=
clip
)
logits
=
outputs
.
pop
(
'logits'
)
logits
=
outputs
.
pop
(
'logits'
)
states
=
outputs
states
=
outputs
```
```
...
...
official/vision/beta/projects/movinet/export_saved_model_test.py
View file @
81ad46bf
...
@@ -99,8 +99,6 @@ class ExportSavedModelTest(tf.test.TestCase):
...
@@ -99,8 +99,6 @@ class ExportSavedModelTest(tf.test.TestCase):
self
.
assertAllClose
(
outputs
,
expected_outputs
,
1e-5
,
1e-5
)
self
.
assertAllClose
(
outputs
,
expected_outputs
,
1e-5
,
1e-5
)
def
test_movinet_export_a0_stream_with_tflite
(
self
):
def
test_movinet_export_a0_stream_with_tflite
(
self
):
self
.
skipTest
(
'b/195800800'
)
saved_model_path
=
self
.
get_temp_dir
()
saved_model_path
=
self
.
get_temp_dir
()
FLAGS
.
export_path
=
saved_model_path
FLAGS
.
export_path
=
saved_model_path
...
@@ -123,7 +121,7 @@ class ExportSavedModelTest(tf.test.TestCase):
...
@@ -123,7 +121,7 @@ class ExportSavedModelTest(tf.test.TestCase):
tflite_model
=
converter
.
convert
()
tflite_model
=
converter
.
convert
()
interpreter
=
tf
.
lite
.
Interpreter
(
model_content
=
tflite_model
)
interpreter
=
tf
.
lite
.
Interpreter
(
model_content
=
tflite_model
)
signature
=
interpreter
.
get_signature_runner
()
runner
=
interpreter
.
get_signature_runner
(
'serving_default'
)
def
state_name
(
name
:
str
)
->
str
:
def
state_name
(
name
:
str
)
->
str
:
return
name
[
len
(
'serving_default_'
):
-
len
(
':0'
)]
return
name
[
len
(
'serving_default_'
):
-
len
(
':0'
)]
...
@@ -139,7 +137,7 @@ class ExportSavedModelTest(tf.test.TestCase):
...
@@ -139,7 +137,7 @@ class ExportSavedModelTest(tf.test.TestCase):
states
=
init_states
states
=
init_states
for
clip
in
clips
:
for
clip
in
clips
:
outputs
=
signature
(
**
states
,
image
=
clip
)
outputs
=
runner
(
**
states
,
image
=
clip
)
logits
=
outputs
.
pop
(
'logits'
)
logits
=
outputs
.
pop
(
'logits'
)
states
=
outputs
states
=
outputs
...
...
official/vision/beta/projects/movinet/modeling/movinet.py
View file @
81ad46bf
...
@@ -17,10 +17,10 @@
...
@@ -17,10 +17,10 @@
Reference: https://arxiv.org/pdf/2103.11511.pdf
Reference: https://arxiv.org/pdf/2103.11511.pdf
"""
"""
import
dataclasses
import
math
import
math
from
typing
import
Dict
,
Mapping
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Dict
,
Mapping
,
Optional
,
Sequence
,
Tuple
,
Union
import
dataclasses
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
...
@@ -454,7 +454,7 @@ class Movinet(tf.keras.Model):
...
@@ -454,7 +454,7 @@ class Movinet(tf.keras.Model):
stochastic_depth_idx
=
1
stochastic_depth_idx
=
1
for
block_idx
,
block
in
enumerate
(
self
.
_block_specs
):
for
block_idx
,
block
in
enumerate
(
self
.
_block_specs
):
if
isinstance
(
block
,
StemSpec
):
if
isinstance
(
block
,
StemSpec
):
x
,
states
=
movinet_layers
.
Stem
(
layer_obj
=
movinet_layers
.
Stem
(
block
.
filters
,
block
.
filters
,
block
.
kernel_size
,
block
.
kernel_size
,
block
.
strides
,
block
.
strides
,
...
@@ -466,9 +466,9 @@ class Movinet(tf.keras.Model):
...
@@ -466,9 +466,9 @@ class Movinet(tf.keras.Model):
batch_norm_layer
=
self
.
_norm
,
batch_norm_layer
=
self
.
_norm
,
batch_norm_momentum
=
self
.
_norm_momentum
,
batch_norm_momentum
=
self
.
_norm_momentum
,
batch_norm_epsilon
=
self
.
_norm_epsilon
,
batch_norm_epsilon
=
self
.
_norm_epsilon
,
state_prefix
=
'state
/
stem'
,
state_prefix
=
'state
_
stem'
,
name
=
'stem'
)
(
name
=
'stem'
)
x
,
states
=
states
)
x
,
states
=
layer_obj
(
x
,
states
=
states
)
endpoints
[
'stem'
]
=
x
endpoints
[
'stem'
]
=
x
elif
isinstance
(
block
,
MovinetBlockSpec
):
elif
isinstance
(
block
,
MovinetBlockSpec
):
if
not
(
len
(
block
.
expand_filters
)
==
len
(
block
.
kernel_sizes
)
==
if
not
(
len
(
block
.
expand_filters
)
==
len
(
block
.
kernel_sizes
)
==
...
@@ -486,8 +486,8 @@ class Movinet(tf.keras.Model):
...
@@ -486,8 +486,8 @@ class Movinet(tf.keras.Model):
self
.
_stochastic_depth_drop_rate
*
stochastic_depth_idx
/
self
.
_stochastic_depth_drop_rate
*
stochastic_depth_idx
/
num_layers
)
num_layers
)
expand_filters
,
kernel_size
,
strides
=
layer
expand_filters
,
kernel_size
,
strides
=
layer
name
=
f
'b
{
block_idx
-
1
}
/l
{
layer_idx
}
'
name
=
f
'b
lock
{
block_idx
-
1
}
_layer
{
layer_idx
}
'
x
,
states
=
movinet_layers
.
MovinetBlock
(
layer_obj
=
movinet_layers
.
MovinetBlock
(
block
.
base_filters
,
block
.
base_filters
,
expand_filters
,
expand_filters
,
kernel_size
=
kernel_size
,
kernel_size
=
kernel_size
,
...
@@ -505,13 +505,14 @@ class Movinet(tf.keras.Model):
...
@@ -505,13 +505,14 @@ class Movinet(tf.keras.Model):
batch_norm_layer
=
self
.
_norm
,
batch_norm_layer
=
self
.
_norm
,
batch_norm_momentum
=
self
.
_norm_momentum
,
batch_norm_momentum
=
self
.
_norm_momentum
,
batch_norm_epsilon
=
self
.
_norm_epsilon
,
batch_norm_epsilon
=
self
.
_norm_epsilon
,
state_prefix
=
f
'state/
{
name
}
'
,
state_prefix
=
f
'state_
{
name
}
'
,
name
=
name
)(
name
=
name
)
x
,
states
=
states
)
x
,
states
=
layer_obj
(
x
,
states
=
states
)
endpoints
[
name
]
=
x
endpoints
[
name
]
=
x
stochastic_depth_idx
+=
1
stochastic_depth_idx
+=
1
elif
isinstance
(
block
,
HeadSpec
):
elif
isinstance
(
block
,
HeadSpec
):
x
,
states
=
movinet_layers
.
Head
(
layer_obj
=
movinet_layers
.
Head
(
project_filters
=
block
.
project_filters
,
project_filters
=
block
.
project_filters
,
conv_type
=
self
.
_conv_type
,
conv_type
=
self
.
_conv_type
,
activation
=
self
.
_activation
,
activation
=
self
.
_activation
,
...
@@ -520,9 +521,9 @@ class Movinet(tf.keras.Model):
...
@@ -520,9 +521,9 @@ class Movinet(tf.keras.Model):
batch_norm_layer
=
self
.
_norm
,
batch_norm_layer
=
self
.
_norm
,
batch_norm_momentum
=
self
.
_norm_momentum
,
batch_norm_momentum
=
self
.
_norm_momentum
,
batch_norm_epsilon
=
self
.
_norm_epsilon
,
batch_norm_epsilon
=
self
.
_norm_epsilon
,
state_prefix
=
'state
/
head'
,
state_prefix
=
'state
_
head'
,
name
=
'head'
)
(
name
=
'head'
)
x
,
states
=
states
)
x
,
states
=
layer_obj
(
x
,
states
=
states
)
endpoints
[
'head'
]
=
x
endpoints
[
'head'
]
=
x
else
:
else
:
raise
ValueError
(
'Unknown block type {}'
.
format
(
block
))
raise
ValueError
(
'Unknown block type {}'
.
format
(
block
))
...
@@ -567,7 +568,7 @@ class Movinet(tf.keras.Model):
...
@@ -567,7 +568,7 @@ class Movinet(tf.keras.Model):
for
block_idx
,
block
in
enumerate
(
block_specs
):
for
block_idx
,
block
in
enumerate
(
block_specs
):
if
isinstance
(
block
,
StemSpec
):
if
isinstance
(
block
,
StemSpec
):
if
block
.
kernel_size
[
0
]
>
1
:
if
block
.
kernel_size
[
0
]
>
1
:
states
[
'state
/
stem
/
stream_buffer'
]
=
(
states
[
'state
_
stem
_
stream_buffer'
]
=
(
input_shape
[
0
],
input_shape
[
0
],
input_shape
[
1
],
input_shape
[
1
],
divide_resolution
(
input_shape
[
2
],
num_downsamples
),
divide_resolution
(
input_shape
[
2
],
num_downsamples
),
...
@@ -590,8 +591,10 @@ class Movinet(tf.keras.Model):
...
@@ -590,8 +591,10 @@ class Movinet(tf.keras.Model):
self
.
_conv_type
in
[
'2plus1d'
,
'3d_2plus1d'
]):
self
.
_conv_type
in
[
'2plus1d'
,
'3d_2plus1d'
]):
num_downsamples
+=
1
num_downsamples
+=
1
prefix
=
f
'state_block
{
block_idx
}
_layer
{
layer_idx
}
'
if
kernel_size
[
0
]
>
1
:
if
kernel_size
[
0
]
>
1
:
states
[
f
'
state/b
{
block_idx
}
/l
{
layer_idx
}
/
stream_buffer'
]
=
(
states
[
f
'
{
prefix
}
_
stream_buffer'
]
=
(
input_shape
[
0
],
input_shape
[
0
],
kernel_size
[
0
]
-
1
,
kernel_size
[
0
]
-
1
,
divide_resolution
(
input_shape
[
2
],
num_downsamples
),
divide_resolution
(
input_shape
[
2
],
num_downsamples
),
...
@@ -599,13 +602,13 @@ class Movinet(tf.keras.Model):
...
@@ -599,13 +602,13 @@ class Movinet(tf.keras.Model):
expand_filters
,
expand_filters
,
)
)
states
[
f
'
state/b
{
block_idx
}
/l
{
layer_idx
}
/
pool_buffer'
]
=
(
states
[
f
'
{
prefix
}
_
pool_buffer'
]
=
(
input_shape
[
0
],
1
,
1
,
1
,
expand_filters
,
input_shape
[
0
],
1
,
1
,
1
,
expand_filters
,
)
)
states
[
f
'
state/b
{
block_idx
}
/l
{
layer_idx
}
/
pool_frame_count'
]
=
(
1
,)
states
[
f
'
{
prefix
}
_
pool_frame_count'
]
=
(
1
,)
if
use_positional_encoding
:
if
use_positional_encoding
:
name
=
f
'
state/b
{
block_idx
}
/l
{
layer_idx
}
/
pos_enc_frame_count'
name
=
f
'
{
prefix
}
_
pos_enc_frame_count'
states
[
name
]
=
(
1
,)
states
[
name
]
=
(
1
,)
if
strides
[
1
]
!=
strides
[
2
]:
if
strides
[
1
]
!=
strides
[
2
]:
...
@@ -618,10 +621,10 @@ class Movinet(tf.keras.Model):
...
@@ -618,10 +621,10 @@ class Movinet(tf.keras.Model):
self
.
_conv_type
not
in
[
'2plus1d'
,
'3d_2plus1d'
]):
self
.
_conv_type
not
in
[
'2plus1d'
,
'3d_2plus1d'
]):
num_downsamples
+=
1
num_downsamples
+=
1
elif
isinstance
(
block
,
HeadSpec
):
elif
isinstance
(
block
,
HeadSpec
):
states
[
'state
/
head
/
pool_buffer'
]
=
(
states
[
'state
_
head
_
pool_buffer'
]
=
(
input_shape
[
0
],
1
,
1
,
1
,
block
.
project_filters
,
input_shape
[
0
],
1
,
1
,
1
,
block
.
project_filters
,
)
)
states
[
'state
/
head
/
pool_frame_count'
]
=
(
1
,)
states
[
'state
_
head
_
pool_frame_count'
]
=
(
1
,)
return
states
return
states
...
...
official/vision/beta/projects/movinet/modeling/movinet_layers.py
View file @
81ad46bf
...
@@ -478,7 +478,7 @@ class StreamBuffer(tf.keras.layers.Layer):
...
@@ -478,7 +478,7 @@ class StreamBuffer(tf.keras.layers.Layer):
state_prefix
=
state_prefix
if
state_prefix
is
not
None
else
''
state_prefix
=
state_prefix
if
state_prefix
is
not
None
else
''
self
.
_state_prefix
=
state_prefix
self
.
_state_prefix
=
state_prefix
self
.
_state_name
=
f
'
{
state_prefix
}
/
stream_buffer'
self
.
_state_name
=
f
'
{
state_prefix
}
_
stream_buffer'
self
.
_buffer_size
=
buffer_size
self
.
_buffer_size
=
buffer_size
def
get_config
(
self
):
def
get_config
(
self
):
...
@@ -501,7 +501,7 @@ class StreamBuffer(tf.keras.layers.Layer):
...
@@ -501,7 +501,7 @@ class StreamBuffer(tf.keras.layers.Layer):
inputs: the input tensor.
inputs: the input tensor.
states: a dict of states such that, if any of the keys match for this
states: a dict of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s).
layer, will overwrite the contents of the buffer(s).
Expected keys include `state_prefix + '
/
stream_buffer'`.
Expected keys include `state_prefix + '
_
stream_buffer'`.
Returns:
Returns:
the output tensor and states
the output tensor and states
...
...
official/vision/beta/projects/movinet/modeling/movinet_test.py
View file @
81ad46bf
...
@@ -35,11 +35,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -35,11 +35,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints
,
states
=
network
(
inputs
)
endpoints
,
states
=
network
(
inputs
)
self
.
assertAllEqual
(
endpoints
[
'stem'
].
shape
,
[
1
,
8
,
64
,
64
,
8
])
self
.
assertAllEqual
(
endpoints
[
'stem'
].
shape
,
[
1
,
8
,
64
,
64
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
0/l
0'
].
shape
,
[
1
,
8
,
32
,
32
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
lock0_layer
0'
].
shape
,
[
1
,
8
,
32
,
32
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
1/l
0'
].
shape
,
[
1
,
8
,
16
,
16
,
32
])
self
.
assertAllEqual
(
endpoints
[
'b
lock1_layer
0'
].
shape
,
[
1
,
8
,
16
,
16
,
32
])
self
.
assertAllEqual
(
endpoints
[
'b
2/l
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
lock2_layer
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
3/l
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
lock3_layer
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
4/l
0'
].
shape
,
[
1
,
8
,
4
,
4
,
104
])
self
.
assertAllEqual
(
endpoints
[
'b
lock4_layer
0'
].
shape
,
[
1
,
8
,
4
,
4
,
104
])
self
.
assertAllEqual
(
endpoints
[
'head'
].
shape
,
[
1
,
1
,
1
,
1
,
480
])
self
.
assertAllEqual
(
endpoints
[
'head'
].
shape
,
[
1
,
1
,
1
,
1
,
480
])
self
.
assertNotEmpty
(
states
)
self
.
assertNotEmpty
(
states
)
...
@@ -59,11 +59,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -59,11 +59,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints
,
new_states
=
backbone
({
**
init_states
,
'image'
:
inputs
})
endpoints
,
new_states
=
backbone
({
**
init_states
,
'image'
:
inputs
})
self
.
assertAllEqual
(
endpoints
[
'stem'
].
shape
,
[
1
,
8
,
64
,
64
,
8
])
self
.
assertAllEqual
(
endpoints
[
'stem'
].
shape
,
[
1
,
8
,
64
,
64
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
0/l
0'
].
shape
,
[
1
,
8
,
32
,
32
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
lock0_layer
0'
].
shape
,
[
1
,
8
,
32
,
32
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
1/l
0'
].
shape
,
[
1
,
8
,
16
,
16
,
32
])
self
.
assertAllEqual
(
endpoints
[
'b
lock1_layer
0'
].
shape
,
[
1
,
8
,
16
,
16
,
32
])
self
.
assertAllEqual
(
endpoints
[
'b
2/l
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
lock2_layer
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
3/l
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
lock3_layer
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
4/l
0'
].
shape
,
[
1
,
8
,
4
,
4
,
104
])
self
.
assertAllEqual
(
endpoints
[
'b
lock4_layer
0'
].
shape
,
[
1
,
8
,
4
,
4
,
104
])
self
.
assertAllEqual
(
endpoints
[
'head'
].
shape
,
[
1
,
1
,
1
,
1
,
480
])
self
.
assertAllEqual
(
endpoints
[
'head'
].
shape
,
[
1
,
1
,
1
,
1
,
480
])
self
.
assertNotEmpty
(
init_states
)
self
.
assertNotEmpty
(
init_states
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment