Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d87fd224
Commit
d87fd224
authored
Mar 11, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 362396049
parent
d2a9e4c2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
24 deletions
+32
-24
official/vision/beta/modeling/classification_model.py
official/vision/beta/modeling/classification_model.py
+25
-24
official/vision/beta/serving/image_classification_test.py
official/vision/beta/serving/image_classification_test.py
+7
-0
No files found.
official/vision/beta/modeling/classification_model.py
View file @
d87fd224
...
...
@@ -59,28 +59,10 @@ class ClassificationModel(tf.keras.Model):
skip_logits_layer: `bool`, whether to skip the prediction layer.
**kwargs: keyword arguments to be passed.
"""
self
.
_self_setattr_tracking
=
False
self
.
_config_dict
=
{
'backbone'
:
backbone
,
'num_classes'
:
num_classes
,
'input_specs'
:
input_specs
,
'dropout_rate'
:
dropout_rate
,
'kernel_initializer'
:
kernel_initializer
,
'kernel_regularizer'
:
kernel_regularizer
,
'bias_regularizer'
:
bias_regularizer
,
'add_head_batch_norm'
:
add_head_batch_norm
,
'use_sync_bn'
:
use_sync_bn
,
'norm_momentum'
:
norm_momentum
,
'norm_epsilon'
:
norm_epsilon
,
}
self
.
_input_specs
=
input_specs
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
self
.
_backbone
=
backbone
if
use_sync_bn
:
self
.
_
norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_
norm
=
tf
.
keras
.
layers
.
BatchNormalization
norm
=
tf
.
keras
.
layers
.
BatchNormalization
axis
=
-
1
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
else
1
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
...
...
@@ -88,18 +70,37 @@ class ClassificationModel(tf.keras.Model):
x
=
endpoints
[
max
(
endpoints
.
keys
())]
if
add_head_batch_norm
:
x
=
self
.
_
norm
(
axis
=
axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
)(
x
)
x
=
norm
(
axis
=
axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
)(
x
)
x
=
tf
.
keras
.
layers
.
GlobalAveragePooling2D
()(
x
)
if
not
skip_logits_layer
:
x
=
tf
.
keras
.
layers
.
Dropout
(
dropout_rate
)(
x
)
x
=
tf
.
keras
.
layers
.
Dense
(
num_classes
,
kernel_initializer
=
kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)(
num_classes
,
kernel_initializer
=
kernel_initializer
,
kernel_regularizer
=
kernel_regularizer
,
bias_regularizer
=
bias_regularizer
)(
x
)
super
(
ClassificationModel
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
x
,
**
kwargs
)
self
.
_config_dict
=
{
'backbone'
:
backbone
,
'num_classes'
:
num_classes
,
'input_specs'
:
input_specs
,
'dropout_rate'
:
dropout_rate
,
'kernel_initializer'
:
kernel_initializer
,
'kernel_regularizer'
:
kernel_regularizer
,
'bias_regularizer'
:
bias_regularizer
,
'add_head_batch_norm'
:
add_head_batch_norm
,
'use_sync_bn'
:
use_sync_bn
,
'norm_momentum'
:
norm_momentum
,
'norm_epsilon'
:
norm_epsilon
,
}
self
.
_input_specs
=
input_specs
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
self
.
_backbone
=
backbone
self
.
_norm
=
norm
@
property
def
checkpoint_items
(
self
):
...
...
official/vision/beta/serving/image_classification_test.py
View file @
d87fd224
...
...
@@ -74,6 +74,9 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
def
test_export
(
self
,
input_type
=
'image_tensor'
):
tmp_dir
=
self
.
get_temp_dir
()
module
=
self
.
_get_classification_module
()
# Test that the model restores any attrs that are trackable objects
# (eg: tables, resource variables, keras models/layers, tf.hub modules).
module
.
model
.
test_trackable
=
tf
.
keras
.
layers
.
InputLayer
(
input_shape
=
(
4
,))
self
.
_export_from_module
(
module
,
input_type
,
tmp_dir
)
...
...
@@ -96,6 +99,10 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
shape
=
[
224
,
224
,
3
],
dtype
=
tf
.
float32
)))
expected_output
=
module
.
model
(
processed_images
,
training
=
False
)
out
=
classification_fn
(
tf
.
constant
(
images
))
# The imported model should contain any trackable attrs that the original
# model had.
self
.
assertTrue
(
hasattr
(
imported
.
model
,
'test_trackable'
))
self
.
assertAllClose
(
out
[
'outputs'
].
numpy
(),
expected_output
.
numpy
())
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment