"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "8fa4792b804dc2946b982834e2b249994fe9a009"
Commit bd86e960 authored by Toby Boyd

perf_args piped in and add back top_1 and top_5

parent 2894bb53
@@ -109,8 +109,9 @@ def preprocess_image(image, is_training):
   return image
 
 
-def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
-             dtype=tf.float32):
+def input_fn(is_training, data_dir, batch_size, num_epochs=1,
+             dtype=tf.float32, datasets_num_private_threads=None,
+             num_parallel_batches=1):
   """Input function which provides batches for train or eval.
 
   Args:
@@ -118,8 +119,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
     data_dir: The directory containing the input data.
     batch_size: The number of samples per batch.
     num_epochs: The number of epochs to repeat the dataset.
-    num_gpus: The number of gpus used for training.
     dtype: Data type to use for images/features
+    datasets_num_private_threads: Number of private threads for tf.data.
+    num_parallel_batches: Number of parallel batches for tf.data.
 
   Returns:
     A dataset that can be used for iteration.
@@ -134,9 +136,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
       shuffle_buffer=_NUM_IMAGES['train'],
       parse_record_fn=parse_record,
       num_epochs=num_epochs,
-      num_gpus=num_gpus,
-      examples_per_epoch=_NUM_IMAGES['train'] if is_training else None,
-      dtype=dtype
+      dtype=dtype,
+      datasets_num_private_threads=datasets_num_private_threads,
+      num_parallel_batches=num_parallel_batches
   )
...
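For context, a minimal sketch of how the two new knobs could plug into a tf.data pipeline. This is not the repo's actual process_record_dataset; the helper name build_dataset, the assumed parse_record_fn(record, is_training, dtype) signature, and the TF 1.14-era map_and_batch / tf.data.Options usage are all assumptions made for illustration.

import tensorflow as tf

def build_dataset(dataset, parse_record_fn, is_training, batch_size,
                  dtype=tf.float32, datasets_num_private_threads=None,
                  num_parallel_batches=1):
  # Fused map+batch: num_parallel_batches controls how many batches are
  # decoded/preprocessed concurrently.
  dataset = dataset.apply(
      tf.data.experimental.map_and_batch(
          lambda record: parse_record_fn(record, is_training, dtype),
          batch_size=batch_size,
          num_parallel_batches=num_parallel_batches))
  dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

  # Give the input pipeline its own thread pool so preprocessing does not
  # compete with the model's compute threads (assumed tf.data.Options API).
  if datasets_num_private_threads:
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = (
        datasets_num_private_threads)
    dataset = dataset.with_options(options)
  return dataset

The intent matches the new docstring entries: num_parallel_batches overlaps preparation of several batches, while datasets_num_private_threads isolates the input pipeline's threads from training.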
@@ -431,25 +431,25 @@ def resnet_model_fn(features, labels, mode, model_class,
     train_op = None
 
   accuracy = tf.metrics.accuracy(labels, predictions['classes'])
-  #accuracy_top_5 = tf.metrics.mean(tf.nn.in_top_k(predictions=logits,
-  #                                                targets=labels,
-  #                                                k=5,
-  #                                                name='top_5_op'))
-  metrics = {'accuracy': accuracy}
-  #           'accuracy_top_5': accuracy_top_5}
+  accuracy_top_5 = tf.metrics.mean(tf.nn.in_top_k(predictions=logits,
+                                                  targets=labels,
+                                                  k=5,
+                                                  name='top_5_op'))
+  metrics = {'accuracy': accuracy,
+             'accuracy_top_5': accuracy_top_5}
 
   # Create a tensor named train_accuracy for logging purposes
   tf.identity(accuracy[1], name='train_accuracy')
-  #tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
+  tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
   tf.summary.scalar('train_accuracy', accuracy[1])
-  #tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])
+  tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])
 
   return tf.estimator.EstimatorSpec(
       mode=mode,
       predictions=predictions,
       loss=loss,
-      train_op=train_op)
-      #eval_metric_ops=metrics)
+      train_op=train_op,
+      eval_metric_ops=metrics)
 
 
 def resnet_main(
...
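To make the re-enabled metric concrete, here is a small, self-contained TF 1.x sketch of the top-5 pattern used above (toy placeholders, not the model's graph): tf.nn.in_top_k marks whether the true label is among the 5 highest logits, tf.metrics.mean turns that into a streaming (value_op, update_op) pair, and index [1], the update op, is what the hunk logs and summarizes.

import numpy as np
import tensorflow as tf

logits = tf.placeholder(tf.float32, shape=[None, 10])
labels = tf.placeholder(tf.int32, shape=[None])

# Boolean per example: is the true label within the top 5 logits?
in_top_5 = tf.nn.in_top_k(predictions=logits, targets=labels, k=5)
# Streaming mean over batches: returns (value_op, update_op).
accuracy_top_5 = tf.metrics.mean(in_top_5)

with tf.Session() as sess:
  # Metric counters are local variables and must be initialized.
  sess.run(tf.local_variables_initializer())
  batch_logits = np.random.rand(8, 10).astype(np.float32)
  batch_labels = np.random.randint(0, 10, size=8).astype(np.int32)
  top5 = sess.run(accuracy_top_5[1],
                  feed_dict={logits: batch_logits, labels: batch_labels})
  print('running top-5 accuracy:', top5)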