ModelZoo / ResNet50_tensorflow

Commit 424c2045, authored Dec 19, 2018 by Shining Sun

    Before all the data related change

parent 53ff5d90
Showing 6 changed files with 77 additions and 104 deletions.
    official/resnet/cifar10_main.py               +13  -13
    official/resnet/imagenet_main.py              +11  -11
    official/resnet/keras/keras_cifar_main.py     +22  -22
    official/resnet/keras/keras_common.py         +12  -35
    official/resnet/keras/keras_imagenet_main.py  +18  -22
    official/resnet/keras/resnet56.py              +1   -1
official/resnet/cifar10_main.py  (+13 -13)

@@ -29,13 +29,13 @@ from official.utils.logs import logger
 from official.resnet import resnet_model
 from official.resnet import resnet_run_loop
 
-_HEIGHT = 32
-_WIDTH = 32
-_NUM_CHANNELS = 3
-_DEFAULT_IMAGE_BYTES = _HEIGHT * _WIDTH * _NUM_CHANNELS
+HEIGHT = 32
+WIDTH = 32
+NUM_CHANNELS = 3
+_DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS
 # The record is the image plus a one-byte label
 _RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
-_NUM_CLASSES = 10
+NUM_CLASSES = 10
 _NUM_DATA_FILES = 5
 
 # TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.

@@ -79,7 +79,7 @@ def parse_record(raw_record, is_training, dtype):
   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
   depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
-                           [_NUM_CHANNELS, _HEIGHT, _WIDTH])
+                           [NUM_CHANNELS, HEIGHT, WIDTH])
 
   # Convert from [depth, height, width] to [height, width, depth], and cast as
   # float32.

@@ -96,10 +96,10 @@ def preprocess_image(image, is_training):
   if is_training:
     # Resize the image to add four extra pixels on each side.
     image = tf.image.resize_image_with_crop_or_pad(
-        image, _HEIGHT + 8, _WIDTH + 8)
+        image, HEIGHT + 8, WIDTH + 8)
 
-    # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
-    image = tf.random_crop(image, [_HEIGHT, _WIDTH, _NUM_CHANNELS])
+    # Randomly crop a [HEIGHT, WIDTH] section of the image.
+    image = tf.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])
 
     # Randomly flip the image horizontally.
     image = tf.image.random_flip_left_right(image)

@@ -145,7 +145,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
 def get_synth_input_fn(dtype):
   return resnet_run_loop.get_synth_input_fn(
-      _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES, dtype=dtype)
+      HEIGHT, WIDTH, NUM_CHANNELS, NUM_CLASSES, dtype=dtype)
 
 
 ###############################################################################

@@ -154,7 +154,7 @@ def get_synth_input_fn(dtype):
 class Cifar10Model(resnet_model.Model):
   """Model class with appropriate defaults for CIFAR-10 data."""
 
-  def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
+  def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
                resnet_version=resnet_model.DEFAULT_VERSION,
                dtype=resnet_model.DEFAULT_DTYPE):
     """These are the parameters that work for CIFAR-10 data.

@@ -196,7 +196,7 @@ class Cifar10Model(resnet_model.Model):
 def cifar10_model_fn(features, labels, mode, params):
   """Model function for CIFAR-10."""
-  features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])
+  features = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])
   # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
       batch_size=params['batch_size'], batch_denom=128,

@@ -261,7 +261,7 @@ def run_cifar(flags_obj):
                     input_fn)
   resnet_run_loop.resnet_main(
       flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
-      shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
+      shape=[HEIGHT, WIDTH, NUM_CHANNELS])
 
 
 def main(_):
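A note for readers of the hunks above (not part of the commit): parse_record reshapes the flat CIFAR-10 record into channels-first form, and the surrounding code then converts it to channels-last, as the comments in the hunk describe. A minimal standalone sketch of that layout conversion, using NumPy in place of the TensorFlow ops so it can be checked in isolation; the real pipeline does the equivalent with tf.reshape followed by a transpose to [1, 2, 0]:

    import numpy as np

    HEIGHT, WIDTH, NUM_CHANNELS = 32, 32, 3
    record = np.arange(HEIGHT * WIDTH * NUM_CHANNELS)  # stand-in for the decoded image bytes

    # [depth * height * width] -> [depth, height, width], as in parse_record.
    depth_major = record.reshape(NUM_CHANNELS, HEIGHT, WIDTH)

    # [depth, height, width] -> [height, width, depth] (channels-last).
    image = depth_major.transpose(1, 2, 0)
    assert image.shape == (HEIGHT, WIDTH, NUM_CHANNELS)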
official/resnet/imagenet_main.py  (+11 -11)

@@ -30,11 +30,11 @@ from official.resnet import imagenet_preprocessing
 from official.resnet import resnet_model
 from official.resnet import resnet_run_loop
 
-_DEFAULT_IMAGE_SIZE = 224
-_NUM_CHANNELS = 3
-_NUM_CLASSES = 1001
+DEFAULT_IMAGE_SIZE = 224
+NUM_CHANNELS = 3
+NUM_CLASSES = 1001
 
-_NUM_IMAGES = {
+NUM_IMAGES = {
     'train': 1281167,
     'validation': 50000,
 }

@@ -149,9 +149,9 @@ def parse_record(raw_record, is_training, dtype):
   image = imagenet_preprocessing.preprocess_image(
       image_buffer=image_buffer,
       bbox=bbox,
-      output_height=_DEFAULT_IMAGE_SIZE,
-      output_width=_DEFAULT_IMAGE_SIZE,
-      num_channels=_NUM_CHANNELS,
+      output_height=DEFAULT_IMAGE_SIZE,
+      output_width=DEFAULT_IMAGE_SIZE,
+      num_channels=NUM_CHANNELS,
       is_training=is_training)
   image = tf.cast(image, dtype)

@@ -206,7 +206,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
 def get_synth_input_fn(dtype):
   return resnet_run_loop.get_synth_input_fn(
-      _DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES,
+      DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS, NUM_CLASSES,
       dtype=dtype)

@@ -216,7 +216,7 @@ def get_synth_input_fn(dtype):
 class ImagenetModel(resnet_model.Model):
   """Model class with appropriate defaults for Imagenet data."""
 
-  def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
+  def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
                resnet_version=resnet_model.DEFAULT_VERSION,
                dtype=resnet_model.DEFAULT_DTYPE):
     """These are the parameters that work for Imagenet data.

@@ -303,7 +303,7 @@ def imagenet_model_fn(features, labels, mode, params):
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
       batch_size=params['batch_size'], batch_denom=256,
-      num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
+      num_images=NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
       decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr)
 
   return resnet_run_loop.resnet_model_fn(

@@ -343,7 +343,7 @@ def run_imagenet(flags_obj):
   resnet_run_loop.resnet_main(
       flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
-      shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
+      shape=[DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS])
 
 
 def main(_):
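For context on the learning_rate_with_decay call in the hunk above (not part of the commit): the boundary_epochs/decay_rates pair is most naturally read as a piecewise-constant schedule, where each boundary epoch switches to the next decay rate. The real helper lives in resnet_run_loop and also handles warmup and batch-size scaling, which are not reproduced here; the sketch below is a simplified, hypothetical re-statement of just the boundary/decay-rate semantics under that assumption:

    def piecewise_lr(epoch, base_lr=0.1, boundary_epochs=(30, 60, 80, 90),
                     decay_rates=(1, 0.1, 0.01, 0.001, 1e-4)):
      # Pick the decay rate for the interval this epoch falls into.
      rate = decay_rates[0]
      for boundary, next_rate in zip(boundary_epochs, decay_rates[1:]):
        if epoch >= boundary:
          rate = next_rate
      return base_lr * rate

    # Epochs 0-29 train at base_lr, 30-59 at base_lr * 0.1, 60-79 at base_lr * 0.01, and so on.
    assert piecewise_lr(0) == 0.1
    assert piecewise_lr(45) == 0.1 * 0.1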
official/resnet/keras/keras_cifar_main.py  (+22 -22)

@@ -18,11 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
 from absl import app as absl_app
 from absl import flags
-import numpy as np
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
 from official.resnet import cifar10_main as cifar_main

@@ -68,6 +65,8 @@ def parse_record_keras(raw_record, is_training, dtype):
   The input record is parsed into a label and image, and the image is passed
   through preprocessing steps (cropping, flipping, and so on).
 
+  This method converts the label to onhot to fit the loss function.
+
   Args:
     raw_record: scalar Tensor tf.string containing a serialized
       Example protocol buffer.

@@ -78,7 +77,7 @@ def parse_record_keras(raw_record, is_training, dtype):
     Tuple with processed image tensor and one-hot-encoded label tensor.
   """
   image, label = cifar_main.parse_record(raw_record, is_training, dtype)
-  label = tf.sparse_to_dense(label, (cifar_main._NUM_CLASSES,), 1)
+  label = tf.sparse_to_dense(label, (cifar_main.NUM_CLASSES,), 1)
   return image, label

@@ -105,26 +104,26 @@ def run(flags_obj):
   # pylint: disable=protected-access
   if flags_obj.use_synthetic_data:
     synth_input_fn = resnet_run_loop.get_synth_input_fn(
-        cifar_main._HEIGHT, cifar_main._WIDTH,
-        cifar_main._NUM_CHANNELS, cifar_main._NUM_CLASSES,
+        cifar_main.HEIGHT, cifar_main.WIDTH,
+        cifar_main.NUM_CHANNELS, cifar_main.NUM_CLASSES,
         dtype=flags_core.get_tf_dtype(flags_obj))
     train_input_dataset = synth_input_fn(
         True,
         flags_obj.data_dir,
         batch_size=per_device_batch_size,
-        height=cifar_main._HEIGHT,
-        width=cifar_main._WIDTH,
-        num_channels=cifar_main._NUM_CHANNELS,
-        num_classes=cifar_main._NUM_CLASSES,
+        height=cifar_main.HEIGHT,
+        width=cifar_main.WIDTH,
+        num_channels=cifar_main.NUM_CHANNELS,
+        num_classes=cifar_main.NUM_CLASSES,
         dtype=dtype)
     eval_input_dataset = synth_input_fn(
         False,
         flags_obj.data_dir,
         batch_size=per_device_batch_size,
-        height=cifar_main._HEIGHT,
-        width=cifar_main._WIDTH,
-        num_channels=cifar_main._NUM_CHANNELS,
-        num_classes=cifar_main._NUM_CLASSES,
+        height=cifar_main.HEIGHT,
+        width=cifar_main.WIDTH,
+        num_channels=cifar_main.NUM_CHANNELS,
+        num_classes=cifar_main.NUM_CLASSES,
         dtype=dtype)
   # pylint: enable=protected-access

@@ -144,20 +143,22 @@ def run(flags_obj):
       parse_record_fn=parse_record_keras)
 
   optimizer = keras_common.get_optimizer()
-  strategy = keras_common.get_dist_strategy()
+  strategy = distribution_utils.get_distribution_strategy(
+      flags_obj.num_gpus, flags_obj.use_one_device_strategy)
 
   model = resnet56.ResNet56(input_shape=(32, 32, 3),
-                            classes=cifar_main._NUM_CLASSES)
+                            classes=cifar_main.NUM_CLASSES)
 
   model.compile(loss='categorical_crossentropy',
                 optimizer=optimizer,
                 metrics=['categorical_accuracy'],
+                strategy=strategy)
 
-  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks(
-      learning_rate_schedule)
+  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
+      learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
 
-  steps_per_epoch = cifar_main._NUM_IMAGES['train'] // flags_obj.batch_size
-  num_eval_steps = (cifar_main._NUM_IMAGES['validation'] //
+  steps_per_epoch = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
+  num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                     flags_obj.batch_size)
 
   history = model.fit(train_input_dataset,

@@ -176,7 +177,6 @@ def run(flags_obj):
                               steps=num_eval_steps,
                               verbose=1)
-  print('Test loss:', eval_output[0])
   stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
 
   return stats

@@ -188,6 +188,6 @@ def main(_):
 if __name__ == '__main__':
-  tf.logging.set_verbosity(tf.logging.DEBUG)
+  tf.logging.set_verbosity(tf.logging.INFO)
   cifar_main.define_cifar_flags()
   absl_app.run(main)
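A side note on the label handling in parse_record_keras above (not part of the commit): tf.sparse_to_dense(label, (cifar_main.NUM_CLASSES,), 1) builds a length-NUM_CLASSES vector with a 1 at the label index, which is what the categorical_crossentropy loss expects. A minimal sketch of the same conversion, assuming a scalar integer label and the TF 1.x API this repository uses; tf.one_hot is the more direct spelling of the same thing:

    import tensorflow as tf

    NUM_CLASSES = 10
    label = tf.constant(3)

    # As written in parse_record_keras (TF 1.x): a dense vector with a 1 at index `label`.
    one_hot_a = tf.sparse_to_dense(label, (NUM_CLASSES,), 1)

    # Equivalent and more idiomatic.
    one_hot_b = tf.one_hot(label, NUM_CLASSES, dtype=tf.int32)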
official/resnet/keras/keras_common.py  (+12 -35)

-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Common util functions an classes used by both keras cifar and imagenet."""
+"""Common util functions and classes used by both keras cifar and imagenet."""
 
 from __future__ import absolute_import
 from __future__ import division

@@ -20,13 +20,10 @@ from __future__ import print_function
 import time
 
-from absl import app as absl_app
 from absl import flags
 import numpy as np
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
-from official.resnet import imagenet_main
-from official.utils.misc import distribution_utils
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2

@@ -37,7 +34,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
   """Callback for Keras models."""
 
   def __init__(self, batch_size):
-    """Callback for Keras models.
+    """Callback for logging performance (# image/second).
 
     Args:
       batch_size: Total batch size.

@@ -45,28 +42,22 @@ class TimeHistory(tf.keras.callbacks.Callback):
     """
     self._batch_size = batch_size
     super(TimeHistory, self).__init__()
+    self.log_batch_size = 100
 
   def on_train_begin(self, logs=None):
-    self.epoch_times_secs = []
     self.batch_times_secs = []
     self.record_batch = True
 
-  def on_epoch_begin(self, epoch, logs=None):
-    self.epoch_time_start = time.time()
-
-  def on_epoch_end(self, epoch, logs=None):
-    self.epoch_times_secs.append(time.time() - self.epoch_time_start)
-
   def on_batch_begin(self, batch, logs=None):
     if self.record_batch:
       self.batch_time_start = time.time()
       self.record_batch = False
 
   def on_batch_end(self, batch, logs=None):
-    n = 100
-    if batch % n == 0:
+    if batch % self.log_batch_size == 0:
       last_n_batches = time.time() - self.batch_time_start
-      examples_per_second = (self._batch_size * n) / last_n_batches
+      examples_per_second = (self._batch_size * self.log_batch_size) / last_n_batches
       self.batch_times_secs.append(last_n_batches)
       self.record_batch = True
       # TODO(anjalisridhar): add timestamp as well.

@@ -95,8 +86,8 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     self.prev_lr = -1
 
   def on_epoch_begin(self, epoch, logs=None):
-    # if not hasattr(self.model.optimizer, 'learning_rate'):
-    #   raise ValueError('Optimizer must have a "learning_rate" attribute.')
+    if not hasattr(self.model.optimizer, 'learning_rate'):
+      raise ValueError('Optimizer must have a "learning_rate" attribute.')
     self.epochs += 1
 
   def on_batch_begin(self, batch, logs=None):

@@ -120,31 +111,16 @@ def get_optimizer():
   return optimizer
 
-def get_dist_strategy():
-  if FLAGS.num_gpus == 1 and not FLAGS.use_one_device_strategy:
-    print('Not using distribution strategies.')
-    strategy = None
-  elif FLAGS.num_gpus > 1 and FLAGS.use_one_device_strategy:
-    rase ValueError("When %d GPUs are specified, use_one_device_strategy'
-                    'flag cannot be set to True.")
-  else:
-    strategy = distribution_utils.get_distribution_strategy(
-        num_gpus=FLAGS.num_gpus)
-
-  return strategy
-
-def get_fit_callbacks(learning_rate_schedule_fn):
+def get_callbacks(learning_rate_schedule_fn, num_images):
   time_callback = TimeHistory(FLAGS.batch_size)
 
   tensorboard_callback = tf.keras.callbacks.TensorBoard(
       log_dir=FLAGS.model_dir)
+      # update_freq="batch")  # Add this if want per batch logging.
   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule_fn,
       batch_size=FLAGS.batch_size,
-      num_images=imagenet_main._NUM_IMAGES['train'])
+      num_images=num_images)
 
   return time_callback, tensorboard_callback, lr_callback

@@ -155,6 +131,7 @@ def analyze_fit_and_eval_result(history, eval_output):
   stats['training_loss'] = history.history['loss'][-1]
   stats['training_accuracy_top_1'] = history.history['categorical_accuracy'][-1]
 
+  print('Test loss:{}'.format(stats['']))
   print('top_1 accuracy:{}'.format(stats['accuracy_top_1']))
   print('top_1_training_accuracy:{}'.format(stats['training_accuracy_top_1']))
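For context on the TimeHistory changes above (not part of the commit): the class now samples wall-clock time once every log_batch_size batches and converts it into an examples/second figure. A condensed, standalone sketch of that pattern is below; ThroughputLogger is a hypothetical name, and the real class (which also keeps batch_times_secs for later analysis) is the one defined in keras_common.py:

    import time
    import tensorflow as tf

    class ThroughputLogger(tf.keras.callbacks.Callback):
      """Reports examples/sec once every `log_batch_size` batches, as TimeHistory does."""

      def __init__(self, batch_size, log_batch_size=100):
        super(ThroughputLogger, self).__init__()
        self._batch_size = batch_size
        self.log_batch_size = log_batch_size
        self.record_batch = True

      def on_batch_begin(self, batch, logs=None):
        # Start the clock on the first batch after each report.
        if self.record_batch:
          self.batch_time_start = time.time()
          self.record_batch = False

      def on_batch_end(self, batch, logs=None):
        # Mirrors the diff: report whenever the batch index is a multiple of log_batch_size.
        if batch % self.log_batch_size == 0:
          elapsed = time.time() - self.batch_time_start
          examples_per_second = (self._batch_size * self.log_batch_size) / elapsed
          print('examples/sec over the last window: %.1f' % examples_per_second)
          self.record_batch = True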
official/resnet/keras/keras_imagenet_main.py  (+18 -22)

-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -18,15 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
 from absl import app as absl_app
 from absl import flags
-import numpy as np
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
 from official.resnet import imagenet_main
-from official.resnet import imagenet_preprocessing
 from official.resnet import resnet_run_loop
 from official.resnet.keras import keras_common
 from official.resnet.keras import resnet50

@@ -104,22 +100,22 @@ def run_imagenet_with_keras(flags_obj):
   # pylint: disable=protected-access
   if flags_obj.use_synthetic_data:
     synth_input_fn = resnet_run_loop.get_synth_input_fn(
-        imagenet_main._DEFAULT_IMAGE_SIZE, imagenet_main._DEFAULT_IMAGE_SIZE,
-        imagenet_main._NUM_CHANNELS, imagenet_main._NUM_CLASSES,
+        imagenet_main.DEFAULT_IMAGE_SIZE, imagenet_main.DEFAULT_IMAGE_SIZE,
+        imagenet_main.NUM_CHANNELS, imagenet_main.NUM_CLASSES,
         dtype=flags_core.get_tf_dtype(flags_obj))
     train_input_dataset = synth_input_fn(
         batch_size=per_device_batch_size,
-        height=imagenet_main._DEFAULT_IMAGE_SIZE,
-        width=imagenet_main._DEFAULT_IMAGE_SIZE,
-        num_channels=imagenet_main._NUM_CHANNELS,
-        num_classes=imagenet_main._NUM_CLASSES,
+        height=imagenet_main.DEFAULT_IMAGE_SIZE,
+        width=imagenet_main.DEFAULT_IMAGE_SIZE,
+        num_channels=imagenet_main.NUM_CHANNELS,
+        num_classes=imagenet_main.NUM_CLASSES,
         dtype=dtype)
     eval_input_dataset = synth_input_fn(
         batch_size=per_device_batch_size,
-        height=imagenet_main._DEFAULT_IMAGE_SIZE,
-        width=imagenet_main._DEFAULT_IMAGE_SIZE,
-        num_channels=imagenet_main._NUM_CHANNELS,
-        num_classes=imagenet_main._NUM_CLASSES,
+        height=imagenet_main.DEFAULT_IMAGE_SIZE,
+        width=imagenet_main.DEFAULT_IMAGE_SIZE,
+        num_channels=imagenet_main.NUM_CHANNELS,
+        num_classes=imagenet_main.NUM_CLASSES,
         dtype=dtype)
   # pylint: enable=protected-access

@@ -140,20 +136,21 @@ def run_imagenet_with_keras(flags_obj):
   optimizer = keras_common.get_optimizer()
-  strategy = keras_common.get_dist_strategy()
+  strategy = distribution_utils.get_distribution_strategy(
+      flags_obj.num_gpus, flags_obj.use_one_device_strategy)
 
-  model = resnet50.ResNet50(num_classes=imagenet_main._NUM_CLASSES)
+  model = resnet50.ResNet50(num_classes=imagenet_main.NUM_CLASSES)
 
   model.compile(loss='categorical_crossentropy',
                 optimizer=optimizer,
                 metrics=['categorical_accuracy'],
                 distribute=strategy)
 
-  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks(
-      learning_rate_schedule)
+  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
+      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
 
-  steps_per_epoch = imagenet_main._NUM_IMAGES['train'] // flags_obj.batch_size
-  num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] //
+  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
+  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                     flags_obj.batch_size)
 
   history = model.fit(train_input_dataset,

@@ -172,7 +169,6 @@ def run_imagenet_with_keras(flags_obj):
                               steps=num_eval_steps,
                               verbose=1)
-  print('Test loss:', eval_output[0])
   stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
 
   return stats
official/resnet/keras/resnet56.py  (+1 -1)

-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.