ModelZoo / ResNet50_tensorflow · Commits

Commit 842e5a3e (unverified)
Authored Dec 27, 2018 by Shining Sun; committed by GitHub on Dec 27, 2018.

Merge pull request #5928 from tensorflow/cifar_keras_refactor

Cifar keras refactor

Parents: df122b10, 03c35ec6

Showing 12 changed files with 1332 additions and 32 deletions (+1332, -32).
official/resnet/cifar10_main.py                  +16   -16
official/resnet/imagenet_main.py                 +11   -11
official/resnet/keras/__init__.py                 +0    -0
official/resnet/keras/keras_cifar_benchmark.py   +89    -0
official/resnet/keras/keras_cifar_main.py       +193    -0
official/resnet/keras/keras_common.py           +212    -0
official/resnet/keras/keras_common_test.py       +74    -0
official/resnet/keras/keras_imagenet_main.py    +180    -0
official/resnet/keras/resnet_cifar_model.py     +285    -0
official/resnet/keras/resnet_model.py           +246    -0
official/resnet/resnet_run_loop.py                +4    -1
official/utils/misc/distribution_utils.py        +22    -4
official/resnet/cifar10_main.py

@@ -29,17 +29,17 @@ from official.utils.logs import logger
 from official.resnet import resnet_model
 from official.resnet import resnet_run_loop

-_HEIGHT = 32
-_WIDTH = 32
-_NUM_CHANNELS = 3
-_DEFAULT_IMAGE_BYTES = _HEIGHT * _WIDTH * _NUM_CHANNELS
+HEIGHT = 32
+WIDTH = 32
+NUM_CHANNELS = 3
+_DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS
 # The record is the image plus a one-byte label
 _RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
-_NUM_CLASSES = 10
+NUM_CLASSES = 10
 _NUM_DATA_FILES = 5

 # TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
-_NUM_IMAGES = {
+NUM_IMAGES = {
     'train': 50000,
     'validation': 10000,
 }

@@ -79,7 +79,7 @@ def parse_record(raw_record, is_training, dtype):
   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
   depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
-                           [_NUM_CHANNELS, _HEIGHT, _WIDTH])
+                           [NUM_CHANNELS, HEIGHT, WIDTH])

   # Convert from [depth, height, width] to [height, width, depth], and cast as
   # float32.

@@ -96,10 +96,10 @@ def preprocess_image(image, is_training):
   if is_training:
     # Resize the image to add four extra pixels on each side.
     image = tf.image.resize_image_with_crop_or_pad(
-        image, _HEIGHT + 8, _WIDTH + 8)
+        image, HEIGHT + 8, WIDTH + 8)

-    # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
-    image = tf.random_crop(image, [_HEIGHT, _WIDTH, _NUM_CHANNELS])
+    # Randomly crop a [HEIGHT, WIDTH] section of the image.
+    image = tf.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])

     # Randomly flip the image horizontally.
     image = tf.image.random_flip_left_right(image)

@@ -134,7 +134,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
       dataset=dataset,
       is_training=is_training,
       batch_size=batch_size,
-      shuffle_buffer=_NUM_IMAGES['train'],
+      shuffle_buffer=NUM_IMAGES['train'],
       parse_record_fn=parse_record_fn,
       num_epochs=num_epochs,
       dtype=dtype,

@@ -145,7 +145,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
 def get_synth_input_fn(dtype):
   return resnet_run_loop.get_synth_input_fn(
-      _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES, dtype=dtype)
+      HEIGHT, WIDTH, NUM_CHANNELS, NUM_CLASSES, dtype=dtype)

 ###############################################################################

@@ -154,7 +154,7 @@ def get_synth_input_fn(dtype):
 class Cifar10Model(resnet_model.Model):
   """Model class with appropriate defaults for CIFAR-10 data."""

-  def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
+  def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
                resnet_version=resnet_model.DEFAULT_VERSION,
                dtype=resnet_model.DEFAULT_DTYPE):
     """These are the parameters that work for CIFAR-10 data.

@@ -196,11 +196,11 @@ class Cifar10Model(resnet_model.Model):
 def cifar10_model_fn(features, labels, mode, params):
   """Model function for CIFAR-10."""
-  features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])
+  features = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])
   # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
       batch_size=params['batch_size'], batch_denom=128,
-      num_images=_NUM_IMAGES['train'], boundary_epochs=[91, 136, 182],
+      num_images=NUM_IMAGES['train'], boundary_epochs=[91, 136, 182],
       decay_rates=[1, 0.1, 0.01, 0.001])
   # Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper

@@ -261,7 +261,7 @@ def run_cifar(flags_obj):
       input_fn)
   resnet_run_loop.resnet_main(
       flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
-      shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
+      shape=[HEIGHT, WIDTH, NUM_CHANNELS])

 def main(_):
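The point of dropping the underscore prefixes above is to make the dataset metadata importable from the new Keras entry points rather than private to this module. A minimal sketch of the resulting usage, mirroring the imports that appear in keras_cifar_main.py later in this commit:

from official.resnet import cifar10_main as cifar_main

# Public dataset metadata, now reusable outside the Estimator code path.
print(cifar_main.HEIGHT, cifar_main.WIDTH, cifar_main.NUM_CHANNELS)  # 32 32 3
print(cifar_main.NUM_CLASSES)          # 10
print(cifar_main.NUM_IMAGES['train'])  # 50000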
official/resnet/imagenet_main.py

@@ -30,11 +30,11 @@ from official.resnet import imagenet_preprocessing
 from official.resnet import resnet_model
 from official.resnet import resnet_run_loop

-_DEFAULT_IMAGE_SIZE = 224
-_NUM_CHANNELS = 3
-_NUM_CLASSES = 1001
+DEFAULT_IMAGE_SIZE = 224
+NUM_CHANNELS = 3
+NUM_CLASSES = 1001

-_NUM_IMAGES = {
+NUM_IMAGES = {
     'train': 1281167,
     'validation': 50000,
 }

@@ -149,9 +149,9 @@ def parse_record(raw_record, is_training, dtype):
   image = imagenet_preprocessing.preprocess_image(
       image_buffer=image_buffer,
       bbox=bbox,
-      output_height=_DEFAULT_IMAGE_SIZE,
-      output_width=_DEFAULT_IMAGE_SIZE,
-      num_channels=_NUM_CHANNELS,
+      output_height=DEFAULT_IMAGE_SIZE,
+      output_width=DEFAULT_IMAGE_SIZE,
+      num_channels=NUM_CHANNELS,
       is_training=is_training)
   image = tf.cast(image, dtype)

@@ -206,7 +206,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
 def get_synth_input_fn(dtype):
   return resnet_run_loop.get_synth_input_fn(
-      _DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES,
+      DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS, NUM_CLASSES,
       dtype=dtype)

@@ -216,7 +216,7 @@ def get_synth_input_fn(dtype):
 class ImagenetModel(resnet_model.Model):
   """Model class with appropriate defaults for Imagenet data."""

-  def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
+  def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
                resnet_version=resnet_model.DEFAULT_VERSION,
                dtype=resnet_model.DEFAULT_DTYPE):
     """These are the parameters that work for Imagenet data.

@@ -303,7 +303,7 @@ def imagenet_model_fn(features, labels, mode, params):
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
       batch_size=params['batch_size'], batch_denom=256,
-      num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
+      num_images=NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
       decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
       warmup=warmup, base_lr=base_lr)
   return resnet_run_loop.resnet_model_fn(

@@ -343,7 +343,7 @@ def run_imagenet(flags_obj):
   resnet_run_loop.resnet_main(
       flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
-      shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
+      shape=[DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS])

 def main(_):
official/resnet/keras/__init__.py (new file, mode 100644, empty)
official/resnet/keras/keras_cifar_benchmark.py (new file, mode 100644)

"""Executes Keras benchmarks and accuracy tests."""
from __future__ import print_function

import os

from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
import official.resnet.keras.keras_cifar_main as keras_cifar_main
import official.resnet.keras.keras_common as keras_common

DATA_DIR = '/data/cifar10_data/'


class KerasCifar10BenchmarkTests(object):
  """Benchmarks and accuracy tests for KerasCifar10."""

  local_flags = None

  def __init__(self, output_dir=None):
    self.oss_report_object = None
    self.output_dir = output_dir

  def keras_resnet56_1_gpu(self):
    """Test keras based model with Keras fit and distribution strategies."""
    self._setup()
    flags.FLAGS.num_gpus = 1
    flags.FLAGS.data_dir = DATA_DIR
    flags.FLAGS.batch_size = 128
    flags.FLAGS.train_epochs = 182
    flags.FLAGS.model_dir = self._get_model_dir('keras_resnet56_1_gpu')
    flags.FLAGS.resnet_size = 56
    flags.FLAGS.dtype = 'fp32'
    stats = keras_cifar_main.run(flags.FLAGS)
    self._fill_report_object(stats)

  def keras_resnet56_4_gpu(self):
    """Test keras based model with Keras fit and distribution strategies."""
    self._setup()
    flags.FLAGS.num_gpus = 4
    flags.FLAGS.data_dir = self._get_model_dir('keras_resnet56_4_gpu')
    flags.FLAGS.batch_size = 128
    flags.FLAGS.train_epochs = 182
    flags.FLAGS.model_dir = ''
    flags.FLAGS.resnet_size = 56
    flags.FLAGS.dtype = 'fp32'
    stats = keras_cifar_main.run(flags.FLAGS)
    self._fill_report_object(stats)

  def keras_resnet56_no_dist_strat_1_gpu(self):
    """Test keras based model with Keras fit but not distribution strategies."""
    self._setup()
    flags.FLAGS.turn_off_distribution_strategy = True
    flags.FLAGS.num_gpus = 1
    flags.FLAGS.data_dir = DATA_DIR
    flags.FLAGS.batch_size = 128
    flags.FLAGS.train_epochs = 182
    flags.FLAGS.model_dir = self._get_model_dir(
        'keras_resnet56_no_dist_strat_1_gpu')
    flags.FLAGS.resnet_size = 56
    flags.FLAGS.dtype = 'fp32'
    stats = keras_cifar_main.run(flags.FLAGS)
    self._fill_report_object(stats)

  def _fill_report_object(self, stats):
    if self.oss_report_object:
      self.oss_report_object.top_1 = stats['accuracy_top_1']
      self.oss_report_object.add_other_quality(
          stats['training_accuracy_top_1'], 'top_1_train_accuracy')
    else:
      raise ValueError('oss_report_object has not been set.')

  def _get_model_dir(self, folder_name):
    return os.path.join(self.output_dir, folder_name)

  def _setup(self):
    """Sets up and resets flags before each test."""
    tf.logging.set_verbosity(tf.logging.DEBUG)
    if KerasCifar10BenchmarkTests.local_flags is None:
      keras_common.define_keras_flags()
      cifar_main.define_cifar_flags()
      # Loads flags to get defaults to then override.
      flags.FLAGS(['foo'])
      saved_flag_values = flagsaver.save_flag_values()
      KerasCifar10BenchmarkTests.local_flags = saved_flag_values
      return
    flagsaver.restore_flag_values(KerasCifar10BenchmarkTests.local_flags)
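Nothing in this file sets oss_report_object; the class is driven by an external benchmark harness that injects it before the test methods run. A hypothetical direct invocation, with a Mock stand-in and an illustrative output path that are not part of the commit:

from unittest import mock

from official.resnet.keras import keras_cifar_benchmark

bench = keras_cifar_benchmark.KerasCifar10BenchmarkTests(
    output_dir='/tmp/benchmarks')
bench.oss_report_object = mock.Mock()  # normally injected by the harness
bench.keras_resnet56_1_gpu()  # trains ResNet-56, then fills the report object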
official/resnet/keras/keras_cifar_main.py (new file, mode 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the Cifar-10 dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_cifar_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils

LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (0.1, 91), (0.01, 136), (0.001, 182)
]


def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch,
                           batch_size):
  """Handles linear scaling rule and LR decay.

  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
  provided scaling factor.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of steps in an epoch.
    batch_size: integer, total batch size.

  Returns:
    Adjusted learning rate.
  """
  initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128
  learning_rate = initial_learning_rate
  for mult, start_epoch in LR_SCHEDULE:
    if current_epoch >= start_epoch:
      learning_rate = initial_learning_rate * mult
    else:
      break
  return learning_rate


def parse_record_keras(raw_record, is_training, dtype):
  """Parses a record containing a training example of an image.

  The input record is parsed into a label and image, and the image is passed
  through preprocessing steps (cropping, flipping, and so on).

  This method converts the label to one hot to fit the loss function.

  Args:
    raw_record: scalar Tensor tf.string containing a serialized
      Example protocol buffer.
    is_training: A boolean denoting whether the input is for training.
    dtype: Data type to use for input images.

  Returns:
    Tuple with processed image tensor and one-hot-encoded label tensor.
  """
  image, label = cifar_main.parse_record(raw_record, is_training, dtype)
  label = tf.sparse_to_dense(label, (cifar_main.NUM_CLASSES,), 1)
  return image, label


def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  per_device_batch_size = distribution_utils.per_device_batch_size(
      flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_main.HEIGHT,
        width=cifar_main.WIDTH,
        num_channels=cifar_main.NUM_CHANNELS,
        num_classes=cifar_main.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj))
  else:
    input_fn = cifar_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=per_device_batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=per_device_batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  optimizer = keras_common.get_optimizer()
  strategy = distribution_utils.get_distribution_strategy(
      flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)

  model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

  model.compile(loss='categorical_crossentropy',
                optimizer=optimizer,
                metrics=['categorical_accuracy'],
                distribute=strategy)

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

  train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=1)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)
  stats = keras_common.build_stats(history, eval_output)
  return stats


def main(_):
  with logger.benchmark_context(flags.FLAGS):
    run(flags.FLAGS)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  cifar_main.define_cifar_flags()
  keras_common.define_keras_flags()
  absl_app.run(main)
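To make the schedule above concrete: with batch_size=128 the linear scaling factor is 1, so training starts at BASE_LEARNING_RATE = 0.1 and steps down by 10x at epochs 91, 136, and 182. A quick check (batches_per_epoch is arbitrary here because this schedule ignores the batch index):

from official.resnet.keras import keras_cifar_main

for epoch in (0, 90, 91, 136, 182):
  lr = keras_cifar_main.learning_rate_schedule(
      current_epoch=epoch, current_batch=0,
      batches_per_epoch=390, batch_size=128)
  print(epoch, lr)
# -> 0.1, 0.1, 0.01, 0.001, 0.0001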
official/resnet/keras/keras_common.py (new file, mode 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common util functions and classes used by both keras cifar and imagenet."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
# pylint: disable=g-bad-import-order
from absl import flags
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
                                                  gradient_descent_v2)

FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1  # This matches Jing's version.
TRAIN_TOP_1 = 'training_accuracy_top_1'


class TimeHistory(tf.keras.callbacks.Callback):
  """Callback for Keras models."""

  def __init__(self, batch_size):
    """Callback for logging performance (# image/second).

    Args:
      batch_size: Total batch size.
    """
    self._batch_size = batch_size
    super(TimeHistory, self).__init__()
    self.log_steps = 100

  def on_train_begin(self, logs=None):
    self.record_batch = True

  def on_batch_begin(self, batch, logs=None):
    if self.record_batch:
      self.start_time = time.time()
      self.record_batch = False

  def on_batch_end(self, batch, logs=None):
    if batch % self.log_steps == 0:
      elapsed_time = time.time() - self.start_time
      examples_per_second = (self._batch_size * self.log_steps) / elapsed_time
      self.record_batch = True
      # TODO(anjalisridhar): add timestamp as well.
      if batch != 0:
        tf.logging.info("BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
                        "'images_per_second': %f}" %
                        (batch, elapsed_time, examples_per_second))


class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
  """Callback to update learning rate on every batch (not epoch boundaries).

  N.B. Only support Keras optimizers, not TF optimizers.

  Args:
    schedule: a function that takes an epoch index and a batch index as input
      (both integer, indexed from 0) and returns a new learning rate as
      output (float).
  """

  def __init__(self, schedule, batch_size, num_images):
    super(LearningRateBatchScheduler, self).__init__()
    self.schedule = schedule
    self.batches_per_epoch = num_images / batch_size
    self.batch_size = batch_size
    self.epochs = -1
    self.prev_lr = -1

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'learning_rate'):
      raise ValueError('Optimizer must have a "learning_rate" attribute.')
    self.epochs += 1

  def on_batch_begin(self, batch, logs=None):
    """Executes before step begins."""
    lr = self.schedule(self.epochs,
                       batch,
                       self.batches_per_epoch,
                       self.batch_size)
    if not isinstance(lr, (float, np.float32, np.float64)):
      raise ValueError('The output of the "schedule" function should be float.')
    if lr != self.prev_lr:
      self.model.optimizer.learning_rate = lr  # lr should be a float here
      self.prev_lr = lr
      tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler '
                       'change learning rate to %s.', self.epochs, batch, lr)


def get_optimizer():
  """Returns optimizer to use."""
  # The learning_rate is overwritten at the beginning of each step by callback.
  return gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)


def get_callbacks(learning_rate_schedule_fn, num_images):
  """Returns common callbacks."""
  time_callback = TimeHistory(FLAGS.batch_size)
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=FLAGS.model_dir)
  lr_callback = LearningRateBatchScheduler(
      learning_rate_schedule_fn,
      batch_size=FLAGS.batch_size,
      num_images=num_images)
  return time_callback, tensorboard_callback, lr_callback


def build_stats(history, eval_output):
  """Normalizes and returns dictionary of stats.

  Args:
    history: Results of the training step. Supports both categorical_accuracy
      and sparse_categorical_accuracy.
    eval_output: Output of the eval step. Assumes first value is eval_loss and
      second value is accuracy_top_1.

  Returns:
    Dictionary of normalized results.
  """
  stats = {}
  if eval_output:
    stats['accuracy_top_1'] = eval_output[1].item()
    stats['eval_loss'] = eval_output[0].item()
  if history and history.history:
    train_hist = history.history
    # Gets final loss from training.
    stats['loss'] = train_hist['loss'][-1].item()
    # Gets top_1 training accuracy.
    if 'categorical_accuracy' in train_hist:
      stats[TRAIN_TOP_1] = train_hist['categorical_accuracy'][-1].item()
    elif 'sparse_categorical_accuracy' in train_hist:
      stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item()
  return stats


def define_keras_flags():
  flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
  flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
  flags.DEFINE_integer(
      name='train_steps', default=None,
      help='The number of steps to run for training. If it is larger than '
      '# batches per epoch, then use # batches per epoch. When this flag is '
      'set, only one epoch is going to run for training.')


def get_synth_input_fn(height, width, num_channels, num_classes,
                       dtype=tf.float32):
  """Returns an input function that returns a dataset with random data.

  This input_fn returns a data set that iterates over a set of random data and
  bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
  copy is still included. This is used to find the upper throughput bound when
  tuning the full input pipeline.

  Args:
    height: Integer height that will be used to create a fake image tensor.
    width: Integer width that will be used to create a fake image tensor.
    num_channels: Integer depth that will be used to create a fake image
      tensor.
    num_classes: Number of classes that should be represented in the fake
      labels tensor.
    dtype: Data type for features/images.

  Returns:
    An input_fn that can be used in place of a real one to return a dataset
    that can be used for iteration.
  """
  # pylint: disable=unused-argument
  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
    """Returns dataset filled with random data."""
    # Synthetic input should be within [0, 255].
    inputs = tf.truncated_normal([batch_size] + [height, width, num_channels],
                                 dtype=dtype,
                                 mean=127,
                                 stddev=60,
                                 name='synthetic_inputs')

    labels = tf.random_uniform([batch_size] + [1],
                               minval=0,
                               maxval=num_classes - 1,
                               dtype=tf.int32,
                               name='synthetic_labels')
    data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
    data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
    return data

  return input_fn
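A small sketch of the synthetic pipeline above in isolation; data_dir is accepted but ignored by the returned input_fn:

from official.resnet.keras import keras_common

input_fn = keras_common.get_synth_input_fn(
    height=32, width=32, num_channels=3, num_classes=10)
dataset = input_fn(is_training=True, data_dir=None, batch_size=128)
# Each element is a (128, 32, 32, 3) image batch and a (128, 1) label batch.
# Note tf.random_uniform's maxval is exclusive, so with maxval=num_classes - 1
# the synthetic labels never actually reach class num_classes - 1.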
official/resnet/keras/keras_common_test.py (new file, mode 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the keras_common module."""

from __future__ import absolute_import
from __future__ import print_function

from mock import Mock
import numpy as np
import tensorflow as tf
# pylint: disable=g-bad-import-order
from official.resnet.keras import keras_common

tf.logging.set_verbosity(tf.logging.ERROR)


class KerasCommonTests(tf.test.TestCase):
  """Tests for keras_common."""

  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    super(KerasCommonTests, cls).setUpClass()

  def test_build_stats(self):
    history = self._build_history(1.145, cat_accuracy=.99988)
    eval_output = self._build_eval_output(.56432111, 5.990)
    stats = keras_common.build_stats(history, eval_output)

    self.assertEqual(1.145, stats['loss'])
    self.assertEqual(.99988, stats['training_accuracy_top_1'])
    self.assertEqual(.56432111, stats['accuracy_top_1'])
    self.assertEqual(5.990, stats['eval_loss'])

  def test_build_stats_sparse(self):
    history = self._build_history(1.145, cat_accuracy_sparse=.99988)
    eval_output = self._build_eval_output(.928, 1.9844)
    stats = keras_common.build_stats(history, eval_output)

    self.assertEqual(1.145, stats['loss'])
    self.assertEqual(.99988, stats['training_accuracy_top_1'])
    self.assertEqual(.928, stats['accuracy_top_1'])
    self.assertEqual(1.9844, stats['eval_loss'])

  def _build_history(self, loss, cat_accuracy=None, cat_accuracy_sparse=None):
    history_p = Mock()
    history = {}
    history_p.history = history
    history['loss'] = [np.float64(loss)]
    if cat_accuracy:
      history['categorical_accuracy'] = [np.float64(cat_accuracy)]
    if cat_accuracy_sparse:
      history['sparse_categorical_accuracy'] = [
          np.float64(cat_accuracy_sparse)]
    return history_p

  def _build_eval_output(self, top_1, eval_loss):
    eval_output = [np.float64(eval_loss), np.float64(top_1)]
    return eval_output
official/resnet/keras/keras_imagenet_main.py (new file, mode 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: disable=g-bad-import-order
from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils

LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]


def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch,
                           batch_size):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
  provided scaling factor.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of steps in an epoch.
    batch_size: integer, total batch size.

  Returns:
    Adjusted learning rate.
  """
  initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Learning rate increases linearly per step.
    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = initial_lr * mult
    else:
      break
  return learning_rate


def parse_record_keras(raw_record, is_training, dtype):
  """Adjust the shape of label."""
  image, label = imagenet_main.parse_record(raw_record, is_training, dtype)

  # Subtract one so that labels are in [0, 1000), and cast to float32 for
  # Keras model.
  label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
                  dtype=tf.float32)
  return image, label


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  per_device_batch_size = distribution_utils.per_device_batch_size(
      flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj))
  else:
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=per_device_batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=per_device_batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  optimizer = keras_common.get_optimizer()
  strategy = distribution_utils.get_distribution_strategy(
      flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)

  model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=['sparse_categorical_accuracy'],
                distribute=strategy)

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    num_eval_steps = None
    validation_data = None

  model.fit(train_input_dataset,
            epochs=train_epochs,
            steps_per_epoch=train_steps,
            callbacks=[
                time_callback,
                lr_callback,
                tensorboard_callback
            ],
            validation_steps=num_eval_steps,
            validation_data=validation_data,
            verbose=1)

  if not flags_obj.skip_eval:
    model.evaluate(eval_input_dataset,
                   steps=num_eval_steps,
                   verbose=1)


def main(_):
  with logger.benchmark_context(flags.FLAGS):
    run(flags.FLAGS)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  imagenet_main.define_imagenet_flags()
  keras_common.define_keras_flags()
  absl_app.run(main)
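Worked values for the warmup schedule above, assuming batch_size=256 (scaling factor 1, so initial_lr = 0.1): the rate ramps linearly toward 0.1 over the first five epochs, then follows the step decays at epochs 30, 60, and 80. The batches_per_epoch value below is illustrative:

from official.resnet.keras import keras_imagenet_main

for epoch in (1, 5, 30, 60, 80):
  lr = keras_imagenet_main.learning_rate_schedule(
      current_epoch=epoch, current_batch=0,
      batches_per_epoch=5004, batch_size=256)
  print(epoch, lr)
# -> 0.02, 0.1, 0.01, 0.001, 0.0001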
official/resnet/keras/resnet_cifar_model.py (new file, mode 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet56 model for Keras adapted from tf.keras.applications.ResNet50.

# Reference:
  - [Deep Residual Learning for Image Recognition](
      https://arxiv.org/abs/1512.03385)
Adapted from code contributed by BigMoyan.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers

BATCH_NORM_DECAY = 0.997
BATCH_NORM_EPSILON = 1e-5
L2_WEIGHT_DECAY = 2e-4


def identity_building_block(input_tensor, kernel_size, filters, stage, block,
                            training=None):
  """The identity block is the block that has no conv layer at shortcut.

  Arguments:
    input_tensor: input tensor
    kernel_size: default 3, the kernel size of
        middle conv layer at main path
    filters: list of integers, the filters of 3 conv layer at main path
    stage: integer, current stage label, used for generating layer names
    block: 'a','b'..., current block label, used for generating layer names
    training: Only used if training keras model with Estimator. In other
      scenarios it is handled automatically.

  Returns:
    Output tensor for the block.
  """
  filters1, filters2 = filters
  if tf.keras.backend.image_data_format() == 'channels_last':
    bn_axis = 3
  else:
    bn_axis = 1
  conv_name_base = 'res' + str(stage) + block + '_branch'
  bn_name_base = 'bn' + str(stage) + block + '_branch'

  x = tf.keras.layers.Conv2D(
      filters1, kernel_size, padding='same',
      kernel_initializer='he_normal',
      kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      name=conv_name_base + '2a')(input_tensor)
  x = tf.keras.layers.BatchNormalization(
      axis=bn_axis, name=bn_name_base + '2a',
      momentum=BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON)(x, training=training)
  x = tf.keras.layers.Activation('relu')(x)

  x = tf.keras.layers.Conv2D(
      filters2, kernel_size, padding='same',
      kernel_initializer='he_normal',
      kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      name=conv_name_base + '2b')(x)
  x = tf.keras.layers.BatchNormalization(
      axis=bn_axis, name=bn_name_base + '2b',
      momentum=BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON)(x, training=training)

  x = tf.keras.layers.add([x, input_tensor])
  x = tf.keras.layers.Activation('relu')(x)
  return x


def conv_building_block(input_tensor, kernel_size, filters, stage, block,
                        strides=(2, 2), training=None):
  """A block that has a conv layer at shortcut.

  Arguments:
    input_tensor: input tensor
    kernel_size: default 3, the kernel size of
        middle conv layer at main path
    filters: list of integers, the filters of 3 conv layer at main path
    stage: integer, current stage label, used for generating layer names
    block: 'a','b'..., current block label, used for generating layer names
    strides: Strides for the first conv layer in the block.
    training: Only used if training keras model with Estimator. In other
      scenarios it is handled automatically.

  Returns:
    Output tensor for the block.

  Note that from stage 3,
  the first conv layer at main path is with strides=(2, 2)
  And the shortcut should have strides=(2, 2) as well
  """
  filters1, filters2 = filters
  if tf.keras.backend.image_data_format() == 'channels_last':
    bn_axis = 3
  else:
    bn_axis = 1
  conv_name_base = 'res' + str(stage) + block + '_branch'
  bn_name_base = 'bn' + str(stage) + block + '_branch'

  x = tf.keras.layers.Conv2D(
      filters1, kernel_size, strides=strides, padding='same',
      kernel_initializer='he_normal',
      kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      name=conv_name_base + '2a')(input_tensor)
  x = tf.keras.layers.BatchNormalization(
      axis=bn_axis, name=bn_name_base + '2a',
      momentum=BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON)(x, training=training)
  x = tf.keras.layers.Activation('relu')(x)

  x = tf.keras.layers.Conv2D(
      filters2, kernel_size, padding='same',
      kernel_initializer='he_normal',
      kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      name=conv_name_base + '2b')(x)
  x = tf.keras.layers.BatchNormalization(
      axis=bn_axis, name=bn_name_base + '2b',
      momentum=BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON)(x, training=training)

  shortcut = tf.keras.layers.Conv2D(
      filters2, (1, 1), strides=strides,
      kernel_initializer='he_normal',
      kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      name=conv_name_base + '1')(input_tensor)
  shortcut = tf.keras.layers.BatchNormalization(
      axis=bn_axis, name=bn_name_base + '1',
      momentum=BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON)(shortcut, training=training)

  x = tf.keras.layers.add([x, shortcut])
  x = tf.keras.layers.Activation('relu')(x)
  return x


def resnet56(classes=100, training=None):
  """Instantiates the ResNet56 architecture.

  Arguments:
    classes: optional number of classes to classify images into
    training: Only used if training keras model with Estimator. In other
      scenarios it is handled automatically.

  Returns:
    A Keras model instance.
  """
  # Determine proper input shape
  if backend.image_data_format() == 'channels_first':
    input_shape = (3, 32, 32)
    bn_axis = 1
  else:  # channels_last
    input_shape = (32, 32, 3)
    bn_axis = 3

  img_input = layers.Input(shape=input_shape)
  x = tf.keras.layers.ZeroPadding2D(padding=(1, 1),
                                    name='conv1_pad')(img_input)
  x = tf.keras.layers.Conv2D(
      16, (3, 3), strides=(1, 1), padding='valid',
      kernel_initializer='he_normal',
      kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      name='conv1')(x)
  x = tf.keras.layers.BatchNormalization(
      axis=bn_axis, name='bn_conv1',
      momentum=BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON)(x, training=training)
  x = tf.keras.layers.Activation('relu')(x)

  x = conv_building_block(x, 3, [16, 16], stage=2, block='a',
                          strides=(1, 1), training=training)
  x = identity_building_block(x, 3, [16, 16], stage=2, block='b',
                              training=training)
  x = identity_building_block(x, 3, [16, 16], stage=2, block='c',
                              training=training)
  x = identity_building_block(x, 3, [16, 16], stage=2, block='d',
                              training=training)
  x = identity_building_block(x, 3, [16, 16], stage=2, block='e',
                              training=training)
  x = identity_building_block(x, 3, [16, 16], stage=2, block='f',
                              training=training)
  x = identity_building_block(x, 3, [16, 16], stage=2, block='g',
                              training=training)
  x = identity_building_block(x, 3, [16, 16], stage=2, block='h',
                              training=training)
  x = identity_building_block(x, 3, [16, 16], stage=2, block='i',
                              training=training)

  x = conv_building_block(x, 3, [32, 32], stage=3, block='a',
                          training=training)
  x = identity_building_block(x, 3, [32, 32], stage=3, block='b',
                              training=training)
  x = identity_building_block(x, 3, [32, 32], stage=3, block='c',
                              training=training)
  x = identity_building_block(x, 3, [32, 32], stage=3, block='d',
                              training=training)
  x = identity_building_block(x, 3, [32, 32], stage=3, block='e',
                              training=training)
  x = identity_building_block(x, 3, [32, 32], stage=3, block='f',
                              training=training)
  x = identity_building_block(x, 3, [32, 32], stage=3, block='g',
                              training=training)
  x = identity_building_block(x, 3, [32, 32], stage=3, block='h',
                              training=training)
  x = identity_building_block(x, 3, [32, 32], stage=3, block='i',
                              training=training)

  x = conv_building_block(x, 3, [64, 64], stage=4, block='a',
                          training=training)
  x = identity_building_block(x, 3, [64, 64], stage=4, block='b',
                              training=training)
  x = identity_building_block(x, 3, [64, 64], stage=4, block='c',
                              training=training)
  x = identity_building_block(x, 3, [64, 64], stage=4, block='d',
                              training=training)
  x = identity_building_block(x, 3, [64, 64], stage=4, block='e',
                              training=training)
  x = identity_building_block(x, 3, [64, 64], stage=4, block='f',
                              training=training)
  x = identity_building_block(x, 3, [64, 64], stage=4, block='g',
                              training=training)
  x = identity_building_block(x, 3, [64, 64], stage=4, block='h',
                              training=training)
  x = identity_building_block(x, 3, [64, 64], stage=4, block='i',
                              training=training)

  x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)
  x = tf.keras.layers.Dense(
      classes, activation='softmax',
      kernel_initializer='he_normal',
      kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
      name='fc10')(x)

  inputs = img_input
  # Create model.
  model = tf.keras.models.Model(inputs, x, name='resnet56')

  return model
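The depth arithmetic behind the name: each of the three stages stacks one conv block plus eight identity blocks, each with two 3x3 convolutions, giving 3 * 9 * 2 = 54 conv layers; the initial 3x3 conv and the final dense layer bring the total to 56 (the 1x1 projection on each stage's 'a' shortcut is conventionally not counted). A quick sanity check:

from official.resnet.keras import resnet_cifar_model

model = resnet_cifar_model.resnet56(classes=10)
print(model.input_shape)   # (None, 32, 32, 3) under channels_last
print(model.output_shape)  # (None, 10)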
official/resnet/keras/resnet_model.py (new file, mode 100644)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet50 model for Keras.

Adapted from tf.keras.applications.resnet50.ResNet50().
This is ResNet model version 1.5.

Related papers/blogs:
- https://arxiv.org/abs/1512.03385
- https://arxiv.org/pdf/1603.05027v2.pdf
- http://torch.ch/blog/2016/02/04/resnets.html
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import warnings

from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import utils

L2_WEIGHT_DECAY = 1e-4
BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5


def identity_block(input_tensor, kernel_size, filters, stage, block):
  """The identity block is the block that has no conv layer at shortcut.

  # Arguments
    input_tensor: input tensor
    kernel_size: default 3, the kernel size of
        middle conv layer at main path
    filters: list of integers, the filters of 3 conv layer at main path
    stage: integer, current stage label, used for generating layer names
    block: 'a','b'..., current block label, used for generating layer names

  # Returns
    Output tensor for the block.
  """
  filters1, filters2, filters3 = filters
  if backend.image_data_format() == 'channels_last':
    bn_axis = 3
  else:
    bn_axis = 1
  conv_name_base = 'res' + str(stage) + block + '_branch'
  bn_name_base = 'bn' + str(stage) + block + '_branch'

  x = layers.Conv2D(filters1, (1, 1),
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2a')(input_tensor)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2a')(x)
  x = layers.Activation('relu')(x)

  x = layers.Conv2D(filters2, kernel_size,
                    padding='same',
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2b')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2b')(x)
  x = layers.Activation('relu')(x)

  x = layers.Conv2D(filters3, (1, 1),
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2c')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2c')(x)

  x = layers.add([x, input_tensor])
  x = layers.Activation('relu')(x)
  return x


def conv_block(input_tensor, kernel_size, filters, stage, block,
               strides=(2, 2)):
  """A block that has a conv layer at shortcut.

  # Arguments
    input_tensor: input tensor
    kernel_size: default 3, the kernel size of
        middle conv layer at main path
    filters: list of integers, the filters of 3 conv layer at main path
    stage: integer, current stage label, used for generating layer names
    block: 'a','b'..., current block label, used for generating layer names
    strides: Strides for the second conv layer in the block.

  # Returns
    Output tensor for the block.

  Note that from stage 3,
  the second conv layer at main path is with strides=(2, 2)
  And the shortcut should have strides=(2, 2) as well
  """
  filters1, filters2, filters3 = filters
  if backend.image_data_format() == 'channels_last':
    bn_axis = 3
  else:
    bn_axis = 1
  conv_name_base = 'res' + str(stage) + block + '_branch'
  bn_name_base = 'bn' + str(stage) + block + '_branch'

  x = layers.Conv2D(filters1, (1, 1),
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2a')(input_tensor)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2a')(x)
  x = layers.Activation('relu')(x)

  x = layers.Conv2D(filters2, kernel_size, strides=strides,
                    padding='same',
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2b')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2b')(x)
  x = layers.Activation('relu')(x)

  x = layers.Conv2D(filters3, (1, 1),
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2c')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name=bn_name_base + '2c')(x)

  shortcut = layers.Conv2D(filters3, (1, 1), strides=strides,
                           kernel_initializer='he_normal',
                           kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                           bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                           name=conv_name_base + '1')(input_tensor)
  shortcut = layers.BatchNormalization(axis=bn_axis,
                                       momentum=BATCH_NORM_DECAY,
                                       epsilon=BATCH_NORM_EPSILON,
                                       name=bn_name_base + '1')(shortcut)

  x = layers.add([x, shortcut])
  x = layers.Activation('relu')(x)
  return x


def resnet50(num_classes):
  # TODO(tfboyd): add training argument, just like resnet56.
  """Instantiates the ResNet50 architecture.

  Args:
    num_classes: `int` number of classes for image classification.

  Returns:
    A Keras model instance.
  """
  # Determine proper input shape
  if backend.image_data_format() == 'channels_first':
    input_shape = (3, 224, 224)
    bn_axis = 1
  else:
    input_shape = (224, 224, 3)
    bn_axis = 3

  img_input = layers.Input(shape=input_shape)
  x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
  x = layers.Conv2D(64, (7, 7),
                    strides=(2, 2),
                    padding='valid',
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name='conv1')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name='bn_conv1')(x)
  x = layers.Activation('relu')(x)
  x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
  x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

  x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
  x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
  x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

  x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
  x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
  x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
  x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

  x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
  x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

  x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
  x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

  x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
  x = layers.Dense(num_classes,
                   activation='softmax',
                   kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                   bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                   name='fc1000')(x)

  # Create model.
  return models.Model(img_input, x, name='resnet50')
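The "version 1.5" note in the docstring matches what conv_block does: the stride-2 downsampling sits on the middle 3x3 convolution rather than on the leading 1x1, the v1.5 variant discussed in the torch blog post cited above. Instantiation sketch:

from official.resnet.keras import resnet_model

model = resnet_model.resnet50(num_classes=1001)
print(model.input_shape)   # (None, 224, 224, 3) under channels_last
print(model.output_shape)  # (None, 1001)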
official/resnet/resnet_run_loop.py

@@ -628,7 +628,10 @@ def define_resnet_flags(resnet_size_choices=None):
           'the expense of image resize/cropping being done as part of model '
           'inference. Note, this flag only applies to ImageNet and cannot '
           'be used for CIFAR.'))
+  flags.DEFINE_boolean(
+      name='turn_off_distribution_strategy', default=False,
+      help=flags_core.help_wrap('Set to True to not use distribution '
+                                'strategies.'))

   choice_kwargs = dict(
       name='resnet_size', short_name='rs', default='50',
       help=flags_core.help_wrap('The size of the ResNet model to use.'))
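The new boolean flag rides on the existing absl flag machinery, so it can be set like any other run-loop flag. A minimal parsing sketch, assuming define_resnet_flags registers the base flags as it does elsewhere in this repo (the 'prog' argv entry is a placeholder):

from absl import flags

from official.resnet import resnet_run_loop

resnet_run_loop.define_resnet_flags()
flags.FLAGS(['prog', '--turn_off_distribution_strategy=true'])
print(flags.FLAGS.turn_off_distribution_strategy)  # True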
official/utils/misc/distribution_utils.py

@@ -21,7 +21,9 @@ from __future__ import print_function

 import tensorflow as tf


-def get_distribution_strategy(num_gpus, all_reduce_alg=None):
+def get_distribution_strategy(num_gpus, all_reduce_alg=None,
+                              turn_off_distribution_strategy=False):
   """Return a DistributionStrategy for running the model.

   Args:

@@ -30,15 +32,31 @@ def get_distribution_strategy(num_gpus, all_reduce_alg=None):
       See tf.contrib.distribute.AllReduceCrossDeviceOps for available
       algorithms. If None, DistributionStrategy will choose based on device
       topology.
+    turn_off_distribution_strategy: when set to True, do not use any
+      distribution strategy. Note that when it is True, and num_gpus is
+      larger than 1, it will raise a ValueError.

   Returns:
     tf.contrib.distribute.DistributionStrategy object.
+
+  Raises:
+    ValueError: if turn_off_distribution_strategy is True and num_gpus is
+      larger than 1.
   """
   if num_gpus == 0:
-    return tf.contrib.distribute.OneDeviceStrategy("device:CPU:0")
+    if turn_off_distribution_strategy:
+      return None
+    else:
+      return tf.contrib.distribute.OneDeviceStrategy("device:CPU:0")
   elif num_gpus == 1:
-    return tf.contrib.distribute.OneDeviceStrategy("device:GPU:0")
-  else:
+    if turn_off_distribution_strategy:
+      return None
+    else:
+      return tf.contrib.distribute.OneDeviceStrategy("device:GPU:0")
+  elif turn_off_distribution_strategy:
+    raise ValueError("When {} GPUs are specified, "
+                     "turn_off_distribution_strategy flag cannot be set to "
+                     "True.".format(num_gpus))
+  else:  # num_gpus > 1 and not turn_off_distribution_strategy
     if all_reduce_alg:
       return tf.contrib.distribute.MirroredStrategy(
           num_gpus=num_gpus,
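The branch structure above reduces to a small contract: None means "no strategy" (the Keras callers then effectively compile without distribution), and requesting multiple GPUs while opting out is an error. An illustration, using the keyword argument so the flag is not mistaken for all_reduce_alg:

from official.utils.misc import distribution_utils

# Opting out on a single device returns None rather than a strategy.
assert distribution_utils.get_distribution_strategy(
    num_gpus=1, turn_off_distribution_strategy=True) is None

# Opting out with more than one GPU is rejected.
try:
  distribution_utils.get_distribution_strategy(
      num_gpus=4, turn_off_distribution_strategy=True)
except ValueError as e:
  print(e)  # "When 4 GPUs are specified, ..."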