ModelZoo / ResNet50_tensorflow

Commit 93e0022d
Authored Dec 10, 2018 by Toby Boyd

Merge branch 'cifar_keras' of github.com:tensorflow/models into cifar_keras

Parents: 2f2a04c7, c58a3b44
Changes: 2 changed files with 262 additions and 5 deletions (+262, -5).
official/resnet/keras/keras_imagenet_main.py   +18   -5
official/resnet/keras/resnet_model_tpu.py      +244  -0
official/resnet/keras/keras_imagenet_main.py (view file @ 93e0022d)
@@ -29,6 +29,7 @@ from official.resnet import imagenet_main
 from official.resnet import imagenet_preprocessing
 from official.resnet import resnet_run_loop
 from official.resnet.keras import keras_resnet_model
+from official.resnet.keras import resnet_model_tpu
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
@@ -81,7 +82,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
 LR_SCHEDULE = [    # (multiplier, epoch to start) tuples
     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
 ]
-BASE_LEARNING_RATE = 0.1
+BASE_LEARNING_RATE = 0.1  # This matches Jing's version.

 def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch,
                            batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.
@@ -189,6 +190,9 @@ def run_imagenet_with_keras(flags_obj):
   Raises:
     ValueError: If fp16 is passed as it is not currently supported.
   """
+  if flags_obj.enable_eager:
+    tf.enable_eager_execution()
+
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':
     raise ValueError('dtype fp16 is not supported in Keras. Use the default '
@@ -247,13 +251,16 @@ def run_imagenet_with_keras(flags_obj):
   # learning_rate = BASE_LEARNING_RATE * flags_obj.batch_size / 256
   # opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)

   strategy = distribution_utils.get_distribution_strategy(
       num_gpus=flags_obj.num_gpus)

-  model = keras_resnet_model.ResNet50(classes=imagenet_main._NUM_CLASSES,
-                                      weights=None)
+  if flags_obj.use_tpu_model:
+    model = resnet_model_tpu.ResNet50(num_classes=imagenet_main._NUM_CLASSES)
+  else:
+    model = keras_resnet_model.ResNet50(classes=imagenet_main._NUM_CLASSES,
+                                        weights=None)

   loss = 'categorical_crossentropy'
   accuracy = 'categorical_accuracy'
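The compile call itself is outside this hunk. As a hedged sketch only, the model selected above would typically be wired to the loss and metric named here roughly as follows; the optimizer is an assumption that echoes the commented-out momentum-optimizer lines, not something this diff shows.

# Sketch only, not taken from the diff: compiling the model chosen above.
opt = tf.keras.optimizers.SGD(lr=BASE_LEARNING_RATE, momentum=0.9)  # assumed
model.compile(optimizer=opt,
              loss=loss,           # 'categorical_crossentropy' (from the diff)
              metrics=[accuracy])  # 'categorical_accuracy' (from the diff)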
@@ -268,7 +275,7 @@ def run_imagenet_with_keras(flags_obj):
   tesorboard_callback = tf.keras.callbacks.TensorBoard(
       log_dir=flags_obj.model_dir)
-      # update_freq="batch") # Add this if want per batch logging.
+      #update_freq="batch") # Add this if want per batch logging.
   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule,
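For reference, a hedged sketch of what the commented-out option would look like if enabled inside run_imagenet_with_keras (the variable name is spelled as in the script). It assumes a TensorFlow version whose TensorBoard callback accepts update_freq, which the comment in the diff implies.

# Sketch only, not part of the commit:
tesorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=flags_obj.model_dir,
    update_freq='batch')  # log metrics after every batch instead of per epoch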
@@ -295,6 +302,10 @@ def run_imagenet_with_keras(flags_obj):
                         verbose=1)
   print('Test loss:', eval_output[0])

+def define_keras_imagenet_flags():
+  flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
+
+
 def main(_):
   with logger.benchmark_context(flags.FLAGS):
     run_imagenet_with_keras(flags.FLAGS)
@@ -302,5 +313,7 @@ def main(_):
 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
+  define_keras_imagenet_flags()
   imagenet_main.define_imagenet_flags()
+  flags.DEFINE_boolean(name='use_tpu_model', default=False,
+                       help='Use resnet model from tpu.')
   absl_app.run(main)
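A hedged sketch of how the two new boolean flags parse through absl. The flag names come from this diff; the argv values and the --num_gpus flag (defined by the existing official.utils flags and read as flags_obj.num_gpus above) are illustrative only.

# Sketch only: assumes the DEFINE_boolean calls above have already run.
from absl import flags

argv = ['keras_imagenet_main.py', '--enable_eager', '--use_tpu_model',
        '--num_gpus=1']
flags.FLAGS(argv)  # parse the command line
print(flags.FLAGS.enable_eager, flags.FLAGS.use_tpu_model)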
official/resnet/keras/resnet_model_tpu.py (new file, mode 0 → 100644; view file @ 93e0022d)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet50 model for Keras.
Adapted from tf.keras.applications.resnet50.ResNet50().
Related papers/blogs:
- https://arxiv.org/abs/1512.03385
- https://arxiv.org/pdf/1603.05027v2.pdf
- http://torch.ch/blog/2016/02/04/resnets.html
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import warnings

from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import utils

L2_WEIGHT_DECAY = 1e-4
BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5
def identity_block(input_tensor, kernel_size, filters, stage, block):
  """The identity block is the block that has no conv layer at shortcut.

  # Arguments
      input_tensor: input tensor
      kernel_size: default 3, the kernel size of
          middle conv layer at main path
      filters: list of integers, the filters of 3 conv layer at main path
      stage: integer, current stage label, used for generating layer names
      block: 'a','b'..., current block label, used for generating layer names

  # Returns
      Output tensor for the block.
  """
  filters1, filters2, filters3 = filters
  if backend.image_data_format() == 'channels_last':
    bn_axis = 3
  else:
    bn_axis = 1
  conv_name_base = 'res' + str(stage) + block + '_branch'
  bn_name_base = 'bn' + str(stage) + block + '_branch'

  x = layers.Conv2D(filters1, (1, 1),
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2a')(input_tensor)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2a')(x)
  x = layers.Activation('relu')(x)

  x = layers.Conv2D(filters2, kernel_size,
                    padding='same',
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2b')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2b')(x)
  x = layers.Activation('relu')(x)

  x = layers.Conv2D(filters3, (1, 1),
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2c')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2c')(x)

  x = layers.add([x, input_tensor])
  x = layers.Activation('relu')(x)
  return x
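As an illustration only (not part of the commit): identity_block preserves both spatial size and channel count, so the input's channel count must equal filters3 for the residual add to be valid.

# Illustration, assuming channels_last data format.
inputs = layers.Input(shape=(56, 56, 256))
outputs = identity_block(inputs, 3, [64, 64, 256], stage=2, block='b')
print(backend.int_shape(outputs))  # (None, 56, 56, 256): shape is unchanged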
def conv_block(input_tensor, kernel_size, filters, stage, block,
               strides=(2, 2)):
  """A block that has a conv layer at shortcut.

  # Arguments
      input_tensor: input tensor
      kernel_size: default 3, the kernel size of
          middle conv layer at main path
      filters: list of integers, the filters of 3 conv layer at main path
      stage: integer, current stage label, used for generating layer names
      block: 'a','b'..., current block label, used for generating layer names
      strides: Strides for the second conv layer in the block.

  # Returns
      Output tensor for the block.

  Note that from stage 3,
  the second conv layer at main path is with strides=(2, 2)
  And the shortcut should have strides=(2, 2) as well
  """
  filters1, filters2, filters3 = filters
  if backend.image_data_format() == 'channels_last':
    bn_axis = 3
  else:
    bn_axis = 1
  conv_name_base = 'res' + str(stage) + block + '_branch'
  bn_name_base = 'bn' + str(stage) + block + '_branch'

  x = layers.Conv2D(filters1, (1, 1),
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2a')(input_tensor)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2a')(x)
  x = layers.Activation('relu')(x)

  x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same',
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2b')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2b')(x)
  x = layers.Activation('relu')(x)

  x = layers.Conv2D(filters3, (1, 1),
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2c')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2c')(x)

  shortcut = layers.Conv2D(filters3, (1, 1), strides=strides,
                           kernel_initializer='he_normal',
                           kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                           bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                           name=conv_name_base + '1')(input_tensor)
  shortcut = layers.BatchNormalization(axis=bn_axis,
                                       momentum=BATCH_NORM_DECAY,
                                       epsilon=BATCH_NORM_EPSILON,
                                       name=bn_name_base + '1')(shortcut)

  x = layers.add([x, shortcut])
  x = layers.Activation('relu')(x)
  return x
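Again for illustration only: conv_block projects the shortcut through a 1x1 convolution, so it can change the channel count, and its strides control downsampling (stage 2 uses strides=(1, 1), later stages use the default (2, 2)).

# Illustration only; shapes assume channels_last data format.
inputs = layers.Input(shape=(56, 56, 256))
down = conv_block(inputs, 3, [128, 128, 512], stage=3, block='a')
same = conv_block(inputs, 3, [64, 64, 256], stage=2, block='a',
                  strides=(1, 1))
print(backend.int_shape(down))  # (None, 28, 28, 512): downsampled, wider
print(backend.int_shape(same))  # (None, 56, 56, 256): spatial size kept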
def ResNet50(num_classes):
  """Instantiates the ResNet50 architecture.

  Args:
    num_classes: `int` number of classes for image classification.

  Returns:
    A Keras model instance.
  """
  # Determine proper input shape
  if backend.image_data_format() == 'channels_first':
    input_shape = (3, 224, 224)
    bn_axis = 1
  else:
    input_shape = (224, 224, 3)
    bn_axis = 3

  img_input = layers.Input(shape=input_shape)
  x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
  x = layers.Conv2D(64, (7, 7),
                    strides=(2, 2),
                    padding='valid',
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name='conv1')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name='bn_conv1')(x)
  x = layers.Activation('relu')(x)
  x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
  x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

  x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
  x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
  x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

  x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
  x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
  x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
  x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

  x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

  x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

  x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
  x = layers.Dense(num_classes,
                   activation='softmax',
                   kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                   bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                   name='fc1000')(x)

  # Create model.
  return models.Model(img_input, x, name='resnet50')
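A hedged usage sketch (not part of the commit): building this model the way keras_imagenet_main.py selects it when --use_tpu_model is set, then compiling it with the loss and metric named in that script. The class count and optimizer here are assumptions, not taken from this diff.

# Sketch only: 1001 assumes the label set used by imagenet_main._NUM_CLASSES;
# the optimizer choice is illustrative.
model = ResNet50(num_classes=1001)
model.compile(optimizer='sgd',
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])
model.summary()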