Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
424c2045
Commit
424c2045
authored
Dec 19, 2018
by
Shining Sun
Browse files
Before all the data related change
parent
53ff5d90
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
77 additions
and
104 deletions
+77
-104
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+13
-13
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+11
-11
official/resnet/keras/keras_cifar_main.py
official/resnet/keras/keras_cifar_main.py
+22
-22
official/resnet/keras/keras_common.py
official/resnet/keras/keras_common.py
+12
-35
official/resnet/keras/keras_imagenet_main.py
official/resnet/keras/keras_imagenet_main.py
+18
-22
official/resnet/keras/resnet56.py
official/resnet/keras/resnet56.py
+1
-1
No files found.
official/resnet/cifar10_main.py
View file @
424c2045
...
...
@@ -29,13 +29,13 @@ from official.utils.logs import logger
from
official.resnet
import
resnet_model
from
official.resnet
import
resnet_run_loop
_
HEIGHT
=
32
_
WIDTH
=
32
_
NUM_CHANNELS
=
3
_DEFAULT_IMAGE_BYTES
=
_
HEIGHT
*
_
WIDTH
*
_
NUM_CHANNELS
HEIGHT
=
32
WIDTH
=
32
NUM_CHANNELS
=
3
_DEFAULT_IMAGE_BYTES
=
HEIGHT
*
WIDTH
*
NUM_CHANNELS
# The record is the image plus a one-byte label
_RECORD_BYTES
=
_DEFAULT_IMAGE_BYTES
+
1
_
NUM_CLASSES
=
10
NUM_CLASSES
=
10
_NUM_DATA_FILES
=
5
# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
...
...
@@ -79,7 +79,7 @@ def parse_record(raw_record, is_training, dtype):
# The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width].
depth_major
=
tf
.
reshape
(
record_vector
[
1
:
_RECORD_BYTES
],
[
_
NUM_CHANNELS
,
_
HEIGHT
,
_
WIDTH
])
[
NUM_CHANNELS
,
HEIGHT
,
WIDTH
])
# Convert from [depth, height, width] to [height, width, depth], and cast as
# float32.
...
...
@@ -96,10 +96,10 @@ def preprocess_image(image, is_training):
if
is_training
:
# Resize the image to add four extra pixels on each side.
image
=
tf
.
image
.
resize_image_with_crop_or_pad
(
image
,
_
HEIGHT
+
8
,
_
WIDTH
+
8
)
image
,
HEIGHT
+
8
,
WIDTH
+
8
)
# Randomly crop a [
_
HEIGHT,
_
WIDTH] section of the image.
image
=
tf
.
random_crop
(
image
,
[
_
HEIGHT
,
_
WIDTH
,
_
NUM_CHANNELS
])
# Randomly crop a [HEIGHT, WIDTH] section of the image.
image
=
tf
.
random_crop
(
image
,
[
HEIGHT
,
WIDTH
,
NUM_CHANNELS
])
# Randomly flip the image horizontally.
image
=
tf
.
image
.
random_flip_left_right
(
image
)
...
...
@@ -145,7 +145,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
def
get_synth_input_fn
(
dtype
):
return
resnet_run_loop
.
get_synth_input_fn
(
_
HEIGHT
,
_
WIDTH
,
_
NUM_CHANNELS
,
_
NUM_CLASSES
,
dtype
=
dtype
)
HEIGHT
,
WIDTH
,
NUM_CHANNELS
,
NUM_CLASSES
,
dtype
=
dtype
)
###############################################################################
...
...
@@ -154,7 +154,7 @@ def get_synth_input_fn(dtype):
class
Cifar10Model
(
resnet_model
.
Model
):
"""Model class with appropriate defaults for CIFAR-10 data."""
def
__init__
(
self
,
resnet_size
,
data_format
=
None
,
num_classes
=
_
NUM_CLASSES
,
def
__init__
(
self
,
resnet_size
,
data_format
=
None
,
num_classes
=
NUM_CLASSES
,
resnet_version
=
resnet_model
.
DEFAULT_VERSION
,
dtype
=
resnet_model
.
DEFAULT_DTYPE
):
"""These are the parameters that work for CIFAR-10 data.
...
...
@@ -196,7 +196,7 @@ class Cifar10Model(resnet_model.Model):
def
cifar10_model_fn
(
features
,
labels
,
mode
,
params
):
"""Model function for CIFAR-10."""
features
=
tf
.
reshape
(
features
,
[
-
1
,
_
HEIGHT
,
_
WIDTH
,
_
NUM_CHANNELS
])
features
=
tf
.
reshape
(
features
,
[
-
1
,
HEIGHT
,
WIDTH
,
NUM_CHANNELS
])
# Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
learning_rate_fn
=
resnet_run_loop
.
learning_rate_with_decay
(
batch_size
=
params
[
'batch_size'
],
batch_denom
=
128
,
...
...
@@ -261,7 +261,7 @@ def run_cifar(flags_obj):
input_fn
)
resnet_run_loop
.
resnet_main
(
flags_obj
,
cifar10_model_fn
,
input_function
,
DATASET_NAME
,
shape
=
[
_
HEIGHT
,
_
WIDTH
,
_
NUM_CHANNELS
])
shape
=
[
HEIGHT
,
WIDTH
,
NUM_CHANNELS
])
def
main
(
_
):
...
...
official/resnet/imagenet_main.py
View file @
424c2045
...
...
@@ -30,11 +30,11 @@ from official.resnet import imagenet_preprocessing
from
official.resnet
import
resnet_model
from
official.resnet
import
resnet_run_loop
_
DEFAULT_IMAGE_SIZE
=
224
_
NUM_CHANNELS
=
3
_
NUM_CLASSES
=
1001
DEFAULT_IMAGE_SIZE
=
224
NUM_CHANNELS
=
3
NUM_CLASSES
=
1001
_
NUM_IMAGES
=
{
NUM_IMAGES
=
{
'train'
:
1281167
,
'validation'
:
50000
,
}
...
...
@@ -149,9 +149,9 @@ def parse_record(raw_record, is_training, dtype):
image
=
imagenet_preprocessing
.
preprocess_image
(
image_buffer
=
image_buffer
,
bbox
=
bbox
,
output_height
=
_
DEFAULT_IMAGE_SIZE
,
output_width
=
_
DEFAULT_IMAGE_SIZE
,
num_channels
=
_
NUM_CHANNELS
,
output_height
=
DEFAULT_IMAGE_SIZE
,
output_width
=
DEFAULT_IMAGE_SIZE
,
num_channels
=
NUM_CHANNELS
,
is_training
=
is_training
)
image
=
tf
.
cast
(
image
,
dtype
)
...
...
@@ -206,7 +206,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
def
get_synth_input_fn
(
dtype
):
return
resnet_run_loop
.
get_synth_input_fn
(
_
DEFAULT_IMAGE_SIZE
,
_
DEFAULT_IMAGE_SIZE
,
_
NUM_CHANNELS
,
_
NUM_CLASSES
,
DEFAULT_IMAGE_SIZE
,
DEFAULT_IMAGE_SIZE
,
NUM_CHANNELS
,
NUM_CLASSES
,
dtype
=
dtype
)
...
...
@@ -216,7 +216,7 @@ def get_synth_input_fn(dtype):
class
ImagenetModel
(
resnet_model
.
Model
):
"""Model class with appropriate defaults for Imagenet data."""
def
__init__
(
self
,
resnet_size
,
data_format
=
None
,
num_classes
=
_
NUM_CLASSES
,
def
__init__
(
self
,
resnet_size
,
data_format
=
None
,
num_classes
=
NUM_CLASSES
,
resnet_version
=
resnet_model
.
DEFAULT_VERSION
,
dtype
=
resnet_model
.
DEFAULT_DTYPE
):
"""These are the parameters that work for Imagenet data.
...
...
@@ -303,7 +303,7 @@ def imagenet_model_fn(features, labels, mode, params):
learning_rate_fn
=
resnet_run_loop
.
learning_rate_with_decay
(
batch_size
=
params
[
'batch_size'
],
batch_denom
=
256
,
num_images
=
_
NUM_IMAGES
[
'train'
],
boundary_epochs
=
[
30
,
60
,
80
,
90
],
num_images
=
NUM_IMAGES
[
'train'
],
boundary_epochs
=
[
30
,
60
,
80
,
90
],
decay_rates
=
[
1
,
0.1
,
0.01
,
0.001
,
1e-4
],
warmup
=
warmup
,
base_lr
=
base_lr
)
return
resnet_run_loop
.
resnet_model_fn
(
...
...
@@ -343,7 +343,7 @@ def run_imagenet(flags_obj):
resnet_run_loop
.
resnet_main
(
flags_obj
,
imagenet_model_fn
,
input_function
,
DATASET_NAME
,
shape
=
[
_
DEFAULT_IMAGE_SIZE
,
_
DEFAULT_IMAGE_SIZE
,
_
NUM_CHANNELS
])
shape
=
[
DEFAULT_IMAGE_SIZE
,
DEFAULT_IMAGE_SIZE
,
NUM_CHANNELS
])
def
main
(
_
):
...
...
official/resnet/keras/keras_cifar_main.py
View file @
424c2045
...
...
@@ -18,11 +18,8 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
time
from
absl
import
app
as
absl_app
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.resnet
import
cifar10_main
as
cifar_main
...
...
@@ -68,6 +65,8 @@ def parse_record_keras(raw_record, is_training, dtype):
The input record is parsed into a label and image, and the image is passed
through preprocessing steps (cropping, flipping, and so on).
This method converts the label to onhot to fit the loss function.
Args:
raw_record: scalar Tensor tf.string containing a serialized
Example protocol buffer.
...
...
@@ -78,7 +77,7 @@ def parse_record_keras(raw_record, is_training, dtype):
Tuple with processed image tensor and one-hot-encoded label tensor.
"""
image
,
label
=
cifar_main
.
parse_record
(
raw_record
,
is_training
,
dtype
)
label
=
tf
.
sparse_to_dense
(
label
,
(
cifar_main
.
_
NUM_CLASSES
,),
1
)
label
=
tf
.
sparse_to_dense
(
label
,
(
cifar_main
.
NUM_CLASSES
,),
1
)
return
image
,
label
...
...
@@ -105,26 +104,26 @@ def run(flags_obj):
# pylint: disable=protected-access
if
flags_obj
.
use_synthetic_data
:
synth_input_fn
=
resnet_run_loop
.
get_synth_input_fn
(
cifar_main
.
_
HEIGHT
,
cifar_main
.
_
WIDTH
,
cifar_main
.
_
NUM_CHANNELS
,
cifar_main
.
_
NUM_CLASSES
,
cifar_main
.
HEIGHT
,
cifar_main
.
WIDTH
,
cifar_main
.
NUM_CHANNELS
,
cifar_main
.
NUM_CLASSES
,
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
))
train_input_dataset
=
synth_input_fn
(
True
,
flags_obj
.
data_dir
,
batch_size
=
per_device_batch_size
,
height
=
cifar_main
.
_
HEIGHT
,
width
=
cifar_main
.
_
WIDTH
,
num_channels
=
cifar_main
.
_
NUM_CHANNELS
,
num_classes
=
cifar_main
.
_
NUM_CLASSES
,
height
=
cifar_main
.
HEIGHT
,
width
=
cifar_main
.
WIDTH
,
num_channels
=
cifar_main
.
NUM_CHANNELS
,
num_classes
=
cifar_main
.
NUM_CLASSES
,
dtype
=
dtype
)
eval_input_dataset
=
synth_input_fn
(
False
,
flags_obj
.
data_dir
,
batch_size
=
per_device_batch_size
,
height
=
cifar_main
.
_
HEIGHT
,
width
=
cifar_main
.
_
WIDTH
,
num_channels
=
cifar_main
.
_
NUM_CHANNELS
,
num_classes
=
cifar_main
.
_
NUM_CLASSES
,
height
=
cifar_main
.
HEIGHT
,
width
=
cifar_main
.
WIDTH
,
num_channels
=
cifar_main
.
NUM_CHANNELS
,
num_classes
=
cifar_main
.
NUM_CLASSES
,
dtype
=
dtype
)
# pylint: enable=protected-access
...
...
@@ -144,20 +143,22 @@ def run(flags_obj):
parse_record_fn
=
parse_record_keras
)
optimizer
=
keras_common
.
get_optimizer
()
strategy
=
keras_common
.
get_dist_strategy
()
strategy
=
distribution_utils
.
get_distribution_strategy
(
flags_obj
.
num_gpus
,
flags_obj
.
use_one_device_strategy
)
model
=
resnet56
.
ResNet56
(
input_shape
=
(
32
,
32
,
3
),
classes
=
cifar_main
.
_
NUM_CLASSES
)
classes
=
cifar_main
.
NUM_CLASSES
)
model
.
compile
(
loss
=
'categorical_crossentropy'
,
optimizer
=
optimizer
,
metrics
=
[
'categorical_accuracy'
],
strategy
=
strategy
)
time_callback
,
tensorboard_callback
,
lr_callback
=
keras_common
.
get_
fit_
callbacks
(
learning_rate_schedule
)
time_callback
,
tensorboard_callback
,
lr_callback
=
keras_common
.
get_callbacks
(
learning_rate_schedule
,
cifar_main
.
NUM_IMAGES
[
'train'
]
)
steps_per_epoch
=
cifar_main
.
_
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
num_eval_steps
=
(
cifar_main
.
_
NUM_IMAGES
[
'validation'
]
//
steps_per_epoch
=
cifar_main
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
num_eval_steps
=
(
cifar_main
.
NUM_IMAGES
[
'validation'
]
//
flags_obj
.
batch_size
)
history
=
model
.
fit
(
train_input_dataset
,
...
...
@@ -176,7 +177,6 @@ def run(flags_obj):
steps
=
num_eval_steps
,
verbose
=
1
)
print
(
'Test loss:'
,
eval_output
[
0
])
stats
=
keras_common
.
analyze_fit_and_eval_result
(
history
,
eval_output
)
return
stats
...
...
@@ -188,6 +188,6 @@ def main(_):
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
DEBUG
)
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
cifar_main
.
define_cifar_flags
()
absl_app
.
run
(
main
)
official/resnet/keras/keras_common.py
View file @
424c2045
# Copyright 201
7
The TensorFlow Authors. All Rights Reserved.
# Copyright 201
8
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common util functions an classes used by both keras cifar and imagenet."""
"""Common util functions an
d
classes used by both keras cifar and imagenet."""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -20,13 +20,10 @@ from __future__ import print_function
import
time
from
absl
import
app
as
absl_app
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.resnet
import
imagenet_main
from
official.utils.misc
import
distribution_utils
from
tensorflow.python.keras.optimizer_v2
import
gradient_descent
as
gradient_descent_v2
...
...
@@ -37,7 +34,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
"""Callback for Keras models."""
def
__init__
(
self
,
batch_size
):
"""Callback for
Keras models
.
"""Callback for
logging performance (# image/second)
.
Args:
batch_size: Total batch size.
...
...
@@ -45,28 +42,22 @@ class TimeHistory(tf.keras.callbacks.Callback):
"""
self
.
_batch_size
=
batch_size
super
(
TimeHistory
,
self
).
__init__
()
self
.
log_batch_size
=
100
def
on_train_begin
(
self
,
logs
=
None
):
self
.
epoch_times_secs
=
[]
self
.
batch_times_secs
=
[]
self
.
record_batch
=
True
def
on_epoch_begin
(
self
,
epoch
,
logs
=
None
):
self
.
epoch_time_start
=
time
.
time
()
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
self
.
epoch_times_secs
.
append
(
time
.
time
()
-
self
.
epoch_time_start
)
def
on_batch_begin
(
self
,
batch
,
logs
=
None
):
if
self
.
record_batch
:
self
.
batch_time_start
=
time
.
time
()
self
.
record_batch
=
False
def
on_batch_end
(
self
,
batch
,
logs
=
None
):
n
=
100
if
batch
%
n
==
0
:
if
batch
%
self
.
log_batch_size
==
0
:
last_n_batches
=
time
.
time
()
-
self
.
batch_time_start
examples_per_second
=
(
self
.
_batch_size
*
n
)
/
last_n_batches
examples_per_second
=
(
self
.
_batch_size
*
self
.
log_batch_size
)
/
last_n_batches
self
.
batch_times_secs
.
append
(
last_n_batches
)
self
.
record_batch
=
True
# TODO(anjalisridhar): add timestamp as well.
...
...
@@ -95,8 +86,8 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
self
.
prev_lr
=
-
1
def
on_epoch_begin
(
self
,
epoch
,
logs
=
None
):
#
if not hasattr(self.model.optimizer, 'learning_rate'):
#
raise ValueError('Optimizer must have a "learning_rate" attribute.')
if
not
hasattr
(
self
.
model
.
optimizer
,
'learning_rate'
):
raise
ValueError
(
'Optimizer must have a "learning_rate" attribute.'
)
self
.
epochs
+=
1
def
on_batch_begin
(
self
,
batch
,
logs
=
None
):
...
...
@@ -120,31 +111,16 @@ def get_optimizer():
return
optimizer
def
get_dist_strategy
():
if
FLAGS
.
num_gpus
==
1
and
not
FLAGS
.
use_one_device_strategy
:
print
(
'Not using distribution strategies.'
)
strategy
=
None
elif
FLAGS
.
num_gpus
>
1
and
FLAGS
.
use_one_device_strategy
:
rase
ValueError
(
"When %d GPUs are specified, use_one_device_strategy'
'flag cannot be set to True."
)
else
:
strategy
=
distribution_utils
.
get_distribution_strategy
(
num_gpus
=
FLAGS
.
num_gpus
)
return
strategy
def
get_fit_callbacks
(
learning_rate_schedule_fn
):
def
get_callbacks
(
learning_rate_schedule_fn
,
num_images
):
time_callback
=
TimeHistory
(
FLAGS
.
batch_size
)
tensorboard_callback
=
tf
.
keras
.
callbacks
.
TensorBoard
(
log_dir
=
FLAGS
.
model_dir
)
#update_freq="batch") # Add this if want per batch logging.
lr_callback
=
LearningRateBatchScheduler
(
learning_rate_schedule_fn
,
batch_size
=
FLAGS
.
batch_size
,
num_images
=
image
net_main
.
_NUM_IMAGES
[
'train'
]
)
num_images
=
num_
image
s
)
return
time_callback
,
tensorboard_callback
,
lr_callback
...
...
@@ -155,6 +131,7 @@ def analyze_fit_and_eval_result(history, eval_output):
stats
[
'training_loss'
]
=
history
.
history
[
'loss'
][
-
1
]
stats
[
'training_accuracy_top_1'
]
=
history
.
history
[
'categorical_accuracy'
][
-
1
]
print
(
'Test loss:{}'
.
format
(
stats
[
''
]))
print
(
'top_1 accuracy:{}'
.
format
(
stats
[
'accuracy_top_1'
]))
print
(
'top_1_training_accuracy:{}'
.
format
(
stats
[
'training_accuracy_top_1'
]))
...
...
official/resnet/keras/keras_imagenet_main.py
View file @
424c2045
# Copyright 201
7
The TensorFlow Authors. All Rights Reserved.
# Copyright 201
8
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -18,15 +18,11 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
time
from
absl
import
app
as
absl_app
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.resnet
import
imagenet_main
from
official.resnet
import
imagenet_preprocessing
from
official.resnet
import
resnet_run_loop
from
official.resnet.keras
import
keras_common
from
official.resnet.keras
import
resnet50
...
...
@@ -104,22 +100,22 @@ def run_imagenet_with_keras(flags_obj):
# pylint: disable=protected-access
if
flags_obj
.
use_synthetic_data
:
synth_input_fn
=
resnet_run_loop
.
get_synth_input_fn
(
imagenet_main
.
_
DEFAULT_IMAGE_SIZE
,
imagenet_main
.
_
DEFAULT_IMAGE_SIZE
,
imagenet_main
.
_
NUM_CHANNELS
,
imagenet_main
.
_
NUM_CLASSES
,
imagenet_main
.
DEFAULT_IMAGE_SIZE
,
imagenet_main
.
DEFAULT_IMAGE_SIZE
,
imagenet_main
.
NUM_CHANNELS
,
imagenet_main
.
NUM_CLASSES
,
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
))
train_input_dataset
=
synth_input_fn
(
batch_size
=
per_device_batch_size
,
height
=
imagenet_main
.
_
DEFAULT_IMAGE_SIZE
,
width
=
imagenet_main
.
_
DEFAULT_IMAGE_SIZE
,
num_channels
=
imagenet_main
.
_
NUM_CHANNELS
,
num_classes
=
imagenet_main
.
_
NUM_CLASSES
,
height
=
imagenet_main
.
DEFAULT_IMAGE_SIZE
,
width
=
imagenet_main
.
DEFAULT_IMAGE_SIZE
,
num_channels
=
imagenet_main
.
NUM_CHANNELS
,
num_classes
=
imagenet_main
.
NUM_CLASSES
,
dtype
=
dtype
)
eval_input_dataset
=
synth_input_fn
(
batch_size
=
per_device_batch_size
,
height
=
imagenet_main
.
_
DEFAULT_IMAGE_SIZE
,
width
=
imagenet_main
.
_
DEFAULT_IMAGE_SIZE
,
num_channels
=
imagenet_main
.
_
NUM_CHANNELS
,
num_classes
=
imagenet_main
.
_
NUM_CLASSES
,
height
=
imagenet_main
.
DEFAULT_IMAGE_SIZE
,
width
=
imagenet_main
.
DEFAULT_IMAGE_SIZE
,
num_channels
=
imagenet_main
.
NUM_CHANNELS
,
num_classes
=
imagenet_main
.
NUM_CLASSES
,
dtype
=
dtype
)
# pylint: enable=protected-access
...
...
@@ -140,20 +136,21 @@ def run_imagenet_with_keras(flags_obj):
optimizer
=
keras_common
.
get_optimizer
()
strategy
=
keras_common
.
get_dist_strategy
()
strategy
=
distribution_utils
.
get_distribution_strategy
(
flags_obj
.
num_gpus
,
flags_obj
.
use_one_device_strategy
)
model
=
resnet50
.
ResNet50
(
num_classes
=
imagenet_main
.
_
NUM_CLASSES
)
model
=
resnet50
.
ResNet50
(
num_classes
=
imagenet_main
.
NUM_CLASSES
)
model
.
compile
(
loss
=
'categorical_crossentropy'
,
optimizer
=
optimizer
,
metrics
=
[
'categorical_accuracy'
],
distribute
=
strategy
)
time_callback
,
tensorboard_callback
,
lr_callback
=
keras_common
.
get_
fit_
callbacks
(
learning_rate_schedule
)
time_callback
,
tensorboard_callback
,
lr_callback
=
keras_common
.
get_callbacks
(
learning_rate_schedule
,
imagenet_main
.
NUM_IMAGES
[
'train'
]
)
steps_per_epoch
=
imagenet_main
.
_
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
num_eval_steps
=
(
imagenet_main
.
_
NUM_IMAGES
[
'validation'
]
//
steps_per_epoch
=
imagenet_main
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
num_eval_steps
=
(
imagenet_main
.
NUM_IMAGES
[
'validation'
]
//
flags_obj
.
batch_size
)
history
=
model
.
fit
(
train_input_dataset
,
...
...
@@ -172,7 +169,6 @@ def run_imagenet_with_keras(flags_obj):
steps
=
num_eval_steps
,
verbose
=
1
)
print
(
'Test loss:'
,
eval_output
[
0
])
stats
=
keras_common
.
analyze_fit_and_eval_result
(
history
,
eval_output
)
return
stats
...
...
official/resnet/keras/resnet56.py
View file @
424c2045
# Copyright 201
7
The TensorFlow Authors. All Rights Reserved.
# Copyright 201
8
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment