Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bd86e960
"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "8fa4792b804dc2946b982834e2b249994fe9a009"
Commit
bd86e960
authored
Oct 12, 2018
by
Toby Boyd
Browse files
perf_args piped in and add back top_1 and top_5
parent
2894bb53
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
16 deletions
+18
-16
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+8
-6
official/resnet/resnet_run_loop.py
official/resnet/resnet_run_loop.py
+10
-10
No files found.
official/resnet/cifar10_main.py
View file @
bd86e960
...
@@ -109,8 +109,9 @@ def preprocess_image(image, is_training):
...
@@ -109,8 +109,9 @@ def preprocess_image(image, is_training):
return
image
return
image
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
,
num_gpus
=
None
,
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
,
dtype
=
tf
.
float32
):
dtype
=
tf
.
float32
,
datasets_num_private_threads
=
None
,
num_parallel_batches
=
1
):
"""Input function which provides batches for train or eval.
"""Input function which provides batches for train or eval.
Args:
Args:
...
@@ -118,8 +119,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
...
@@ -118,8 +119,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
data_dir: The directory containing the input data.
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
dtype: Data type to use for images/features
dtype: Data type to use for images/features
datasets_num_private_threads: Number of private threads for tf.data.
num_parallel_batches: Number of parallel batches for tf.data.
Returns:
Returns:
A dataset that can be used for iteration.
A dataset that can be used for iteration.
...
@@ -134,9 +136,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
...
@@ -134,9 +136,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
shuffle_buffer
=
_NUM_IMAGES
[
'train'
],
shuffle_buffer
=
_NUM_IMAGES
[
'train'
],
parse_record_fn
=
parse_record
,
parse_record_fn
=
parse_record
,
num_epochs
=
num_epochs
,
num_epochs
=
num_epochs
,
num_gpus
=
num_gpus
,
dtype
=
dtype
,
examples_per_epoch
=
_NUM_IMAGES
[
'train'
]
if
is_training
else
None
,
datasets_num_private_threads
=
datasets_num_private_threads
,
dtype
=
dtype
num_parallel_batches
=
num_parallel_batches
)
)
...
...
official/resnet/resnet_run_loop.py
View file @
bd86e960
...
@@ -431,25 +431,25 @@ def resnet_model_fn(features, labels, mode, model_class,
...
@@ -431,25 +431,25 @@ def resnet_model_fn(features, labels, mode, model_class,
train_op
=
None
train_op
=
None
accuracy
=
tf
.
metrics
.
accuracy
(
labels
,
predictions
[
'classes'
])
accuracy
=
tf
.
metrics
.
accuracy
(
labels
,
predictions
[
'classes'
])
#
accuracy_top_5 = tf.metrics.mean(tf.nn.in_top_k(predictions=logits,
accuracy_top_5
=
tf
.
metrics
.
mean
(
tf
.
nn
.
in_top_k
(
predictions
=
logits
,
#
targets=labels,
targets
=
labels
,
#
k=5,
k
=
5
,
#
name='top_5_op'))
name
=
'top_5_op'
))
metrics
=
{
'accuracy'
:
accuracy
}
metrics
=
{
'accuracy'
:
accuracy
,
#
'accuracy_top_5': accuracy_top_5}
'accuracy_top_5'
:
accuracy_top_5
}
# Create a tensor named train_accuracy for logging purposes
# Create a tensor named train_accuracy for logging purposes
tf
.
identity
(
accuracy
[
1
],
name
=
'train_accuracy'
)
tf
.
identity
(
accuracy
[
1
],
name
=
'train_accuracy'
)
#
tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
tf
.
identity
(
accuracy_top_5
[
1
],
name
=
'train_accuracy_top_5'
)
tf
.
summary
.
scalar
(
'train_accuracy'
,
accuracy
[
1
])
tf
.
summary
.
scalar
(
'train_accuracy'
,
accuracy
[
1
])
#
tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])
tf
.
summary
.
scalar
(
'train_accuracy_top_5'
,
accuracy_top_5
[
1
])
return
tf
.
estimator
.
EstimatorSpec
(
return
tf
.
estimator
.
EstimatorSpec
(
mode
=
mode
,
mode
=
mode
,
predictions
=
predictions
,
predictions
=
predictions
,
loss
=
loss
,
loss
=
loss
,
train_op
=
train_op
)
train_op
=
train_op
,
#
eval_metric_ops=metrics)
eval_metric_ops
=
metrics
)
def
resnet_main
(
def
resnet_main
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment