Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
85a6db17
Commit
85a6db17
authored
Apr 14, 2021
by
Scott Zhu
Committed by
A. Unique TensorFlower
Apr 14, 2021
Browse files
Internal change
PiperOrigin-RevId: 368573031
parent
d0879611
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
12 deletions
+9
-12
official/vision/image_classification/resnet/resnet_model.py
official/vision/image_classification/resnet/resnet_model.py
+9
-12
No files found.
official/vision/image_classification/resnet/resnet_model.py
View file @
85a6db17
...
@@ -28,18 +28,14 @@ from __future__ import division
...
@@ -28,18 +28,14 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
from
tensorflow.python.keras
import
initializers
from
tensorflow.python.keras
import
models
from
tensorflow.python.keras
import
regularizers
from
official.vision.image_classification.resnet
import
imagenet_preprocessing
from
official.vision.image_classification.resnet
import
imagenet_preprocessing
layers
=
tf
.
keras
.
layers
layers
=
tf
.
keras
.
layers
def
_gen_l2_regularizer
(
use_l2_regularizer
=
True
,
l2_weight_decay
=
1e-4
):
def
_gen_l2_regularizer
(
use_l2_regularizer
=
True
,
l2_weight_decay
=
1e-4
):
return
regularizers
.
l2
(
l2_weight_decay
)
if
use_l2_regularizer
else
None
return
tf
.
keras
.
regularizers
.
L2
(
l2_weight_decay
)
if
use_l2_regularizer
else
None
def
identity_block
(
input_tensor
,
def
identity_block
(
input_tensor
,
...
@@ -66,7 +62,7 @@ def identity_block(input_tensor,
...
@@ -66,7 +62,7 @@ def identity_block(input_tensor,
Output tensor for the block.
Output tensor for the block.
"""
"""
filters1
,
filters2
,
filters3
=
filters
filters1
,
filters2
,
filters3
=
filters
if
backend
.
image_data_format
()
==
'channels_last'
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
bn_axis
=
3
bn_axis
=
3
else
:
else
:
bn_axis
=
1
bn_axis
=
1
...
@@ -154,7 +150,7 @@ def conv_block(input_tensor,
...
@@ -154,7 +150,7 @@ def conv_block(input_tensor,
Output tensor for the block.
Output tensor for the block.
"""
"""
filters1
,
filters2
,
filters3
=
filters
filters1
,
filters2
,
filters3
=
filters
if
backend
.
image_data_format
()
==
'channels_last'
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
bn_axis
=
3
bn_axis
=
3
else
:
else
:
bn_axis
=
1
bn_axis
=
1
...
@@ -253,7 +249,7 @@ def resnet50(num_classes,
...
@@ -253,7 +249,7 @@ def resnet50(num_classes,
# Hub image modules expect inputs in the range [0, 1]. This rescales these
# Hub image modules expect inputs in the range [0, 1]. This rescales these
# inputs to the range expected by the trained model.
# inputs to the range expected by the trained model.
x
=
layers
.
Lambda
(
x
=
layers
.
Lambda
(
lambda
x
:
x
*
255.0
-
backend
.
constant
(
lambda
x
:
x
*
255.0
-
tf
.
keras
.
backend
.
constant
(
# pylint: disable=g-long-lambda
imagenet_preprocessing
.
CHANNEL_MEANS
,
imagenet_preprocessing
.
CHANNEL_MEANS
,
shape
=
[
1
,
1
,
3
],
shape
=
[
1
,
1
,
3
],
dtype
=
x
.
dtype
),
dtype
=
x
.
dtype
),
...
@@ -262,7 +258,7 @@ def resnet50(num_classes,
...
@@ -262,7 +258,7 @@ def resnet50(num_classes,
else
:
else
:
x
=
img_input
x
=
img_input
if
backend
.
image_data_format
()
==
'channels_first'
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_first'
:
x
=
layers
.
Permute
((
3
,
1
,
2
))(
x
)
x
=
layers
.
Permute
((
3
,
1
,
2
))(
x
)
bn_axis
=
1
bn_axis
=
1
else
:
# channels_last
else
:
# channels_last
...
@@ -315,7 +311,8 @@ def resnet50(num_classes,
...
@@ -315,7 +311,8 @@ def resnet50(num_classes,
x
=
layers
.
GlobalAveragePooling2D
()(
x
)
x
=
layers
.
GlobalAveragePooling2D
()(
x
)
x
=
layers
.
Dense
(
x
=
layers
.
Dense
(
num_classes
,
num_classes
,
kernel_initializer
=
initializers
.
RandomNormal
(
stddev
=
0.01
),
kernel_initializer
=
tf
.
compat
.
v1
.
keras
.
initializers
.
random_normal
(
stddev
=
0.01
),
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
bias_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
bias_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
name
=
'fc1000'
)(
name
=
'fc1000'
)(
...
@@ -326,4 +323,4 @@ def resnet50(num_classes,
...
@@ -326,4 +323,4 @@ def resnet50(num_classes,
x
=
layers
.
Activation
(
'softmax'
,
dtype
=
'float32'
)(
x
)
x
=
layers
.
Activation
(
'softmax'
,
dtype
=
'float32'
)(
x
)
# Create model.
# Create model.
return
model
s
.
Model
(
img_input
,
x
,
name
=
'resnet50'
)
return
tf
.
kera
s
.
Model
(
img_input
,
x
,
name
=
'resnet50'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment