Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d1718519
Commit
d1718519
authored
Apr 10, 2018
by
Asim Shankar
Browse files
official/mnist: Linter fixes
parent
4e0ca759
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
20 deletions
+29
-20
official/mnist/mnist.py
official/mnist/mnist.py
+29
-20
No files found.
official/mnist/mnist.py
View file @
d1718519
...
@@ -41,33 +41,43 @@ def create_model(data_format):
...
@@ -41,33 +41,43 @@ def create_model(data_format):
But uses the tf.keras API.
But uses the tf.keras API.
Args:
Args:
data_format: Either 'channels_first' or 'channels_last'.
data_format: Either 'channels_first' or 'channels_last'.
'channels_first' is
'channels_first' is
typically faster on GPUs while 'channels_last' is
typically faster on GPUs while 'channels_last' is
typically faster on
typically faster on
CPUs. See
CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
https://www.tensorflow.org/performance/performance_guide#data_formats
Returns:
Returns:
A tf.keras.Model.
A tf.keras.Model.
"""
"""
input_shape
=
None
if
data_format
==
'channels_first'
:
if
data_format
==
'channels_first'
:
input_shape
=
[
1
,
28
,
28
]
input_shape
=
[
1
,
28
,
28
]
else
:
else
:
assert
data_format
==
'channels_last'
assert
data_format
==
'channels_last'
input_shape
=
[
28
,
28
,
1
]
input_shape
=
[
28
,
28
,
1
]
L
=
tf
.
keras
.
layers
l
=
tf
.
keras
.
layers
max_pool
=
L
.
MaxPooling2D
((
2
,
2
),
(
2
,
2
),
padding
=
'same'
,
data_format
=
data_format
)
max_pool
=
l
.
MaxPooling2D
(
return
tf
.
keras
.
Sequential
([
(
2
,
2
),
(
2
,
2
),
padding
=
'same'
,
data_format
=
data_format
)
L
.
Reshape
(
input_shape
),
return
tf
.
keras
.
Sequential
(
L
.
Conv2D
(
32
,
5
,
padding
=
'same'
,
data_format
=
data_format
,
activation
=
tf
.
nn
.
relu
),
[
max_pool
,
l
.
Reshape
(
input_shape
),
L
.
Conv2D
(
64
,
5
,
padding
=
'same'
,
data_format
=
data_format
,
activation
=
tf
.
nn
.
relu
),
l
.
Conv2D
(
max_pool
,
32
,
L
.
Flatten
(),
5
,
L
.
Dense
(
1024
,
activation
=
tf
.
nn
.
relu
),
padding
=
'same'
,
L
.
Dropout
(
0.4
),
data_format
=
data_format
,
L
.
Dense
(
10
)])
activation
=
tf
.
nn
.
relu
),
max_pool
,
l
.
Conv2D
(
64
,
5
,
padding
=
'same'
,
data_format
=
data_format
,
activation
=
tf
.
nn
.
relu
),
max_pool
,
l
.
Flatten
(),
l
.
Dense
(
1024
,
activation
=
tf
.
nn
.
relu
),
l
.
Dropout
(
0.4
),
l
.
Dense
(
10
)
])
def
model_fn
(
features
,
labels
,
mode
,
params
):
def
model_fn
(
features
,
labels
,
mode
,
params
):
...
@@ -122,8 +132,7 @@ def model_fn(features, labels, mode, params):
...
@@ -122,8 +132,7 @@ def model_fn(features, labels, mode, params):
eval_metric_ops
=
{
eval_metric_ops
=
{
'accuracy'
:
'accuracy'
:
tf
.
metrics
.
accuracy
(
tf
.
metrics
.
accuracy
(
labels
=
labels
,
labels
=
labels
,
predictions
=
tf
.
argmax
(
logits
,
axis
=
1
)),
predictions
=
tf
.
argmax
(
logits
,
axis
=
1
)),
})
})
...
@@ -213,8 +222,8 @@ def main(argv):
...
@@ -213,8 +222,8 @@ def main(argv):
eval_results
=
mnist_classifier
.
evaluate
(
input_fn
=
eval_input_fn
)
eval_results
=
mnist_classifier
.
evaluate
(
input_fn
=
eval_input_fn
)
print
(
'
\n
Evaluation results:
\n\t
%s
\n
'
%
eval_results
)
print
(
'
\n
Evaluation results:
\n\t
%s
\n
'
%
eval_results
)
if
model_helpers
.
past_stop_threshold
(
if
model_helpers
.
past_stop_threshold
(
flags
.
stop_threshold
,
flags
.
stop_threshold
,
eval_results
[
'accuracy'
]):
eval_results
[
'accuracy'
]):
break
break
# Export the model
# Export the model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment