Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
37c12026
Commit
37c12026
authored
Nov 01, 2019
by
Will Cromar
Committed by
A. Unique TensorFlower
Nov 01, 2019
Browse files
Add 2.x version of MNIST model to model garden.
PiperOrigin-RevId: 277946653
parent
24c619ff
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
258 additions
and
0 deletions
+258
-0
official/vision/image_classification/mnist_main.py
official/vision/image_classification/mnist_main.py
+169
-0
official/vision/image_classification/mnist_test.py
official/vision/image_classification/mnist_test.py
+89
-0
No files found.
official/vision/image_classification/mnist_main.py
0 → 100644
View file @
37c12026
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a simple model on the MNIST dataset."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
model_helpers
from
official.vision.image_classification
import
common
FLAGS
=
flags
.
FLAGS
def
build_model
():
"""Constructs the ML model used to predict handwritten digits."""
image
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
28
,
28
,
1
))
y
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
32
,
kernel_size
=
5
,
padding
=
'same'
,
activation
=
'relu'
)(
image
)
y
=
tf
.
keras
.
layers
.
MaxPooling2D
(
pool_size
=
(
2
,
2
),
strides
=
(
2
,
2
),
padding
=
'same'
)(
y
)
y
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
32
,
kernel_size
=
5
,
padding
=
'same'
,
activation
=
'relu'
)(
y
)
y
=
tf
.
keras
.
layers
.
MaxPooling2D
(
pool_size
=
(
2
,
2
),
strides
=
(
2
,
2
),
padding
=
'same'
)(
y
)
y
=
tf
.
keras
.
layers
.
Flatten
()(
y
)
y
=
tf
.
keras
.
layers
.
Dense
(
1024
,
activation
=
'relu'
)(
y
)
y
=
tf
.
keras
.
layers
.
Dropout
(
0.4
)(
y
)
probs
=
tf
.
keras
.
layers
.
Dense
(
10
,
activation
=
'softmax'
)(
y
)
model
=
tf
.
keras
.
models
.
Model
(
image
,
probs
,
name
=
'mnist'
)
return
model
@
tfds
.
decode
.
make_decoder
(
output_dtype
=
tf
.
float32
)
def
decode_image
(
example
,
feature
):
"""Convert image to float32 and normalize from [0, 255] to [0.0, 1.0]."""
return
tf
.
cast
(
feature
.
decode_example
(
example
),
dtype
=
tf
.
float32
)
/
255
def
run
(
flags_obj
,
strategy_override
=
None
):
"""Run MNIST model training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
strategy_override: A `tf.distribute.Strategy` object to use for model.
Returns:
Dictionary of training and eval stats.
"""
strategy
=
strategy_override
or
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
flags_obj
.
num_gpus
,
tpu_address
=
flags_obj
.
tpu
)
strategy_scope
=
distribution_utils
.
get_strategy_scope
(
strategy
)
mnist
=
tfds
.
builder
(
'mnist'
,
data_dir
=
flags_obj
.
data_dir
)
if
flags_obj
.
download
:
mnist
.
download_and_prepare
()
mnist_train
,
mnist_test
=
mnist
.
as_dataset
(
split
=
[
'train'
,
'test'
],
decoders
=
{
'image'
:
decode_image
()},
# pylint: disable=no-value-for-parameter
as_supervised
=
True
)
train_input_dataset
=
mnist_train
.
cache
().
repeat
().
shuffle
(
buffer_size
=
50000
).
batch
(
flags_obj
.
batch_size
)
eval_input_dataset
=
mnist_test
.
cache
().
repeat
().
batch
(
flags_obj
.
batch_size
)
with
strategy_scope
:
lr_schedule
=
tf
.
keras
.
optimizers
.
schedules
.
ExponentialDecay
(
0.05
,
decay_steps
=
100000
,
decay_rate
=
0.96
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
lr_schedule
)
model
=
build_model
()
model
.
compile
(
optimizer
=
optimizer
,
loss
=
'sparse_categorical_crossentropy'
,
metrics
=
[
'sparse_categorical_accuracy'
])
num_train_examples
=
mnist
.
info
.
splits
[
'train'
].
num_examples
train_steps
=
num_train_examples
//
flags_obj
.
batch_size
train_epochs
=
flags_obj
.
train_epochs
ckpt_full_path
=
os
.
path
.
join
(
flags_obj
.
model_dir
,
'model.ckpt-{epoch:04d}'
)
callbacks
=
[
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
ckpt_full_path
,
save_weights_only
=
True
),
tf
.
keras
.
callbacks
.
TensorBoard
(
log_dir
=
flags_obj
.
model_dir
),
]
num_eval_examples
=
mnist
.
info
.
splits
[
'test'
].
num_examples
num_eval_steps
=
num_eval_examples
//
flags_obj
.
batch_size
history
=
model
.
fit
(
train_input_dataset
,
epochs
=
train_epochs
,
steps_per_epoch
=
train_steps
,
callbacks
=
callbacks
,
validation_steps
=
num_eval_steps
,
validation_data
=
eval_input_dataset
,
validation_freq
=
flags_obj
.
epochs_between_evals
)
export_path
=
os
.
path
.
join
(
flags_obj
.
model_dir
,
'saved_model'
)
model
.
save
(
export_path
,
include_optimizer
=
False
)
eval_output
=
model
.
evaluate
(
eval_input_dataset
,
steps
=
num_eval_steps
,
verbose
=
2
)
stats
=
common
.
build_stats
(
history
,
eval_output
,
callbacks
)
return
stats
def
define_mnist_flags
():
"""Define command line flags for MNIST model."""
flags_core
.
define_base
(
clean
=
True
,
num_gpu
=
True
,
train_epochs
=
True
,
epochs_between_evals
=
True
,
distribution_strategy
=
True
)
flags_core
.
define_device
()
flags_core
.
define_distribution
()
flags
.
DEFINE_bool
(
'download'
,
False
,
'Whether to download data to `--data_dir`.'
)
FLAGS
.
set_default
(
'batch_size'
,
1024
)
def
main
(
_
):
model_helpers
.
apply_clean
(
FLAGS
)
stats
=
run
(
flags
.
FLAGS
)
logging
.
info
(
'Run stats:
\n
%s'
,
stats
)
if
__name__
==
'__main__'
:
logging
.
set_verbosity
(
logging
.
INFO
)
define_mnist_flags
()
app
.
run
(
main
)
official/vision/image_classification/mnist_test.py
0 → 100644
View file @
37c12026
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the Keras MNIST model on GPU."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
functools
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.vision.image_classification
import
mnist_main
def
eager_strategy_combinations
():
return
combinations
.
combine
(
distribution
=
[
strategy_combinations
.
default_strategy
,
strategy_combinations
.
tpu_strategy
,
strategy_combinations
.
one_device_strategy_gpu
,
],
mode
=
"eager"
,
)
class
KerasMnistTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
"""Unit tests for sample Keras MNIST model."""
_tempdir
=
None
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
super
(
KerasMnistTest
,
cls
).
setUpClass
()
mnist_main
.
define_mnist_flags
()
def
tearDown
(
self
):
super
(
KerasMnistTest
,
self
).
tearDown
()
tf
.
io
.
gfile
.
rmtree
(
self
.
get_temp_dir
())
@
combinations
.
generate
(
eager_strategy_combinations
())
def
test_end_to_end
(
self
,
distribution
):
"""Test Keras MNIST model with `strategy`."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
extra_flags
=
[
"-train_epochs"
,
"1"
,
# Let TFDS find the metadata folder automatically
"--data_dir="
]
def
_mock_dataset
(
self
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
"""Generate mock dataset with TPU-compatible dtype (instead of uint8)."""
return
tf
.
data
.
Dataset
.
from_tensor_slices
({
"image"
:
tf
.
ones
(
shape
=
(
10
,
28
,
28
,
1
),
dtype
=
tf
.
int32
),
"label"
:
tf
.
range
(
10
),
})
run
=
functools
.
partial
(
mnist_main
.
run
,
strategy_override
=
distribution
)
with
tfds
.
testing
.
mock_data
(
as_dataset_fn
=
_mock_dataset
):
integration
.
run_synthetic
(
main
=
run
,
synth
=
False
,
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
)
if
__name__
==
"__main__"
:
tf
.
compat
.
v1
.
enable_v2_behavior
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment