Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e48a403e
"vscode:/vscode.git/clone" did not exist on "3488f3142c734326b54badfba4b166173647d1b2"
Commit
e48a403e
authored
Apr 10, 2018
by
Asim Shankar
Browse files
official/mnist: Use tf.keras.Sequential to simplify network definition.
parent
aad56e4c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
35 additions
and
54 deletions
+35
-54
official/mnist/mnist.py
official/mnist/mnist.py
+31
-50
official/mnist/mnist_eager.py
official/mnist/mnist_eager.py
+1
-1
official/mnist/mnist_eager_test.py
official/mnist/mnist_eager_test.py
+2
-2
official/mnist/mnist_tpu.py
official/mnist/mnist_tpu.py
+1
-1
No files found.
official/mnist/mnist.py
View file @
e48a403e
...
...
@@ -29,7 +29,7 @@ from official.utils.logs import hooks_helper
LEARNING_RATE
=
1e-4
class
Model
(
tf
.
keras
.
Model
):
def
create_model
(
data_format
):
"""Model to recognize digits in the MNIST dataset.
Network structure is equivalent to:
...
...
@@ -37,60 +37,41 @@ class Model(tf.keras.Model):
and
https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
But written as a tf.keras.Model using the tf.layers API.
"""
def
__init__
(
self
,
data_format
):
"""Creates a model for classifying a hand-written digit.
But uses the tf.keras API.
Args:
data_format: Either 'channels_first' or 'channels_last'.
'channels_first' is typically faster on GPUs while 'channels_last' is
typically faster on CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
Returns:
A tf.keras.Model.
"""
super
(
Model
,
self
).
__init__
()
input_shape
=
None
if
data_format
==
'channels_first'
:
self
.
_
input_shape
=
[
-
1
,
1
,
28
,
28
]
input_shape
=
[
1
,
28
,
28
]
else
:
assert
data_format
==
'channels_last'
self
.
_input_shape
=
[
-
1
,
28
,
28
,
1
]
self
.
conv1
=
tf
.
layers
.
Conv2D
(
32
,
5
,
padding
=
'same'
,
data_format
=
data_format
,
activation
=
tf
.
nn
.
relu
)
self
.
conv2
=
tf
.
layers
.
Conv2D
(
64
,
5
,
padding
=
'same'
,
data_format
=
data_format
,
activation
=
tf
.
nn
.
relu
)
self
.
fc1
=
tf
.
layers
.
Dense
(
1024
,
activation
=
tf
.
nn
.
relu
)
self
.
fc2
=
tf
.
layers
.
Dense
(
10
)
self
.
dropout
=
tf
.
layers
.
Dropout
(
0.4
)
self
.
max_pool2d
=
tf
.
layers
.
MaxPooling2D
(
(
2
,
2
),
(
2
,
2
),
padding
=
'same'
,
data_format
=
data_format
)
def
__call__
(
self
,
inputs
,
training
):
"""Add operations to classify a batch of input images.
Args:
inputs: A Tensor representing a batch of input images.
training: A boolean. Set to True to add operations required only when
training the classifier.
Returns:
A logits Tensor with shape [<batch_size>, 10].
"""
y
=
tf
.
reshape
(
inputs
,
self
.
_input_shape
)
y
=
self
.
conv1
(
y
)
y
=
self
.
max_pool2d
(
y
)
y
=
self
.
conv2
(
y
)
y
=
self
.
max_pool2d
(
y
)
y
=
tf
.
layers
.
flatten
(
y
)
y
=
self
.
fc1
(
y
)
y
=
self
.
dropout
(
y
,
training
=
training
)
return
self
.
fc2
(
y
)
input_shape
=
[
28
,
28
,
1
]
L
=
tf
.
keras
.
layers
max_pool
=
L
.
MaxPooling2D
((
2
,
2
),
(
2
,
2
),
padding
=
'same'
,
data_format
=
data_format
)
return
tf
.
keras
.
Sequential
([
L
.
Reshape
(
input_shape
),
L
.
Conv2D
(
32
,
5
,
padding
=
'same'
,
data_format
=
data_format
,
activation
=
tf
.
nn
.
relu
),
max_pool
,
L
.
Conv2D
(
64
,
5
,
padding
=
'same'
,
data_format
=
data_format
,
activation
=
tf
.
nn
.
relu
),
max_pool
,
L
.
Flatten
(),
L
.
Dense
(
1024
,
activation
=
tf
.
nn
.
relu
),
L
.
Dropout
(
0.4
),
L
.
Dense
(
10
)])
def
model_fn
(
features
,
labels
,
mode
,
params
):
"""The model_fn argument for creating an Estimator."""
model
=
M
odel
(
params
[
'data_format'
])
model
=
create_m
odel
(
params
[
'data_format'
])
image
=
features
if
isinstance
(
image
,
dict
):
image
=
features
[
'image'
]
...
...
official/mnist/mnist_eager.py
View file @
e48a403e
...
...
@@ -116,7 +116,7 @@ def main(argv):
test_ds
=
mnist_dataset
.
test
(
flags
.
data_dir
).
batch
(
flags
.
batch_size
)
# Create the model and optimizer
model
=
mnist
.
M
odel
(
data_format
)
model
=
mnist
.
create_m
odel
(
data_format
)
optimizer
=
tf
.
train
.
MomentumOptimizer
(
flags
.
lr
,
flags
.
momentum
)
# Create file writers for writing TensorBoard summaries.
...
...
official/mnist/mnist_eager_test.py
View file @
e48a403e
...
...
@@ -40,7 +40,7 @@ def random_dataset():
def
train
(
defun
=
False
):
model
=
mnist
.
M
odel
(
data_format
())
model
=
mnist
.
create_m
odel
(
data_format
())
if
defun
:
model
.
call
=
tfe
.
defun
(
model
.
call
)
optimizer
=
tf
.
train
.
GradientDescentOptimizer
(
learning_rate
=
0.01
)
...
...
@@ -51,7 +51,7 @@ def train(defun=False):
def
evaluate
(
defun
=
False
):
model
=
mnist
.
M
odel
(
data_format
())
model
=
mnist
.
create_m
odel
(
data_format
())
dataset
=
random_dataset
()
if
defun
:
model
.
call
=
tfe
.
defun
(
model
.
call
)
...
...
official/mnist/mnist_tpu.py
View file @
e48a403e
...
...
@@ -86,7 +86,7 @@ def model_fn(features, labels, mode, params):
if
isinstance
(
image
,
dict
):
image
=
features
[
"image"
]
model
=
mnist
.
M
odel
(
"channels_last"
)
model
=
mnist
.
create_m
odel
(
"channels_last"
)
logits
=
model
(
image
,
training
=
(
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
))
loss
=
tf
.
losses
.
sparse_softmax_cross_entropy
(
labels
=
labels
,
logits
=
logits
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment