Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
889dc12a
Commit
889dc12a
authored
Jun 07, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 453477038
parent
ed0d9c71
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
6 deletions
+22
-6
official/vision/modeling/backbones/spinenet.py
official/vision/modeling/backbones/spinenet.py
+10
-6
official/vision/modeling/backbones/spinenet_test.py
official/vision/modeling/backbones/spinenet_test.py
+12
-0
No files found.
official/vision/modeling/backbones/spinenet.py
View file @
889dc12a
...
@@ -199,15 +199,11 @@ class SpineNet(tf.keras.Model):
...
@@ -199,15 +199,11 @@ class SpineNet(tf.keras.Model):
self
.
_use_sync_bn
=
use_sync_bn
self
.
_use_sync_bn
=
use_sync_bn
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
self
.
_norm_epsilon
=
norm_epsilon
if
activation
==
'relu'
:
self
.
_activation_fn
=
tf
.
nn
.
relu
elif
activation
==
'swish'
:
self
.
_activation_fn
=
tf
.
nn
.
swish
else
:
raise
ValueError
(
'Activation {} not implemented.'
.
format
(
activation
))
self
.
_init_block_fn
=
'bottleneck'
self
.
_init_block_fn
=
'bottleneck'
self
.
_num_init_blocks
=
2
self
.
_num_init_blocks
=
2
self
.
_set_activation_fn
(
activation
)
if
use_sync_bn
:
if
use_sync_bn
:
self
.
_norm
=
layers
.
experimental
.
SyncBatchNormalization
self
.
_norm
=
layers
.
experimental
.
SyncBatchNormalization
else
:
else
:
...
@@ -232,6 +228,14 @@ class SpineNet(tf.keras.Model):
...
@@ -232,6 +228,14 @@ class SpineNet(tf.keras.Model):
self
.
_output_specs
=
{
l
:
endpoints
[
l
].
get_shape
()
for
l
in
endpoints
}
self
.
_output_specs
=
{
l
:
endpoints
[
l
].
get_shape
()
for
l
in
endpoints
}
super
(
SpineNet
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
endpoints
)
super
(
SpineNet
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
endpoints
)
def
_set_activation_fn
(
self
,
activation
):
if
activation
==
'relu'
:
self
.
_activation_fn
=
tf
.
nn
.
relu
elif
activation
==
'swish'
:
self
.
_activation_fn
=
tf
.
nn
.
swish
else
:
raise
ValueError
(
'Activation {} not implemented.'
.
format
(
activation
))
def
_block_group
(
self
,
def
_block_group
(
self
,
inputs
:
tf
.
Tensor
,
inputs
:
tf
.
Tensor
,
filters
:
int
,
filters
:
int
,
...
...
official/vision/modeling/backbones/spinenet_test.py
View file @
889dc12a
...
@@ -122,6 +122,18 @@ class SpineNetTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -122,6 +122,18 @@ class SpineNetTest(parameterized.TestCase, tf.test.TestCase):
# If the serialization was successful, the new config should match the old.
# If the serialization was successful, the new config should match the old.
self
.
assertAllEqual
(
network
.
get_config
(),
new_network
.
get_config
())
self
.
assertAllEqual
(
network
.
get_config
(),
new_network
.
get_config
())
@
parameterized
.
parameters
(
(
'relu'
,
tf
.
nn
.
relu
),
(
'swish'
,
tf
.
nn
.
swish
)
)
def
test_activation
(
self
,
activation
,
activation_fn
):
model
=
spinenet
.
SpineNet
(
activation
=
activation
)
self
.
assertEqual
(
model
.
_activation_fn
,
activation_fn
)
def
test_invalid_activation_raises_valurerror
(
self
):
with
self
.
assertRaises
(
ValueError
):
spinenet
.
SpineNet
(
activation
=
'invalid_activation_name'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment