"mmdet3d/vscode:/vscode.git/clone" did not exist on "6aa820eec8bd141c00fd4bd881be0736bacf4ad2"
Commit bd86e960 authored by Toby Boyd's avatar Toby Boyd
Browse files

perf_args piped in and add back top_1 and top_5

parent 2894bb53
......@@ -109,8 +109,9 @@ def preprocess_image(image, is_training):
return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
dtype=tf.float32):
def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype=tf.float32, datasets_num_private_threads=None,
num_parallel_batches=1):
"""Input function which provides batches for train or eval.
Args:
......@@ -118,8 +119,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
dtype: Data type to use for images/features
datasets_num_private_threads: Number of private threads for tf.data.
num_parallel_batches: Number of parallel batches for tf.data.
Returns:
A dataset that can be used for iteration.
......@@ -134,9 +136,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
shuffle_buffer=_NUM_IMAGES['train'],
parse_record_fn=parse_record,
num_epochs=num_epochs,
num_gpus=num_gpus,
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None,
dtype=dtype
dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads,
num_parallel_batches=num_parallel_batches
)
......
......@@ -431,25 +431,25 @@ def resnet_model_fn(features, labels, mode, model_class,
train_op = None
accuracy = tf.metrics.accuracy(labels, predictions['classes'])
#accuracy_top_5 = tf.metrics.mean(tf.nn.in_top_k(predictions=logits,
# targets=labels,
# k=5,
# name='top_5_op'))
metrics = {'accuracy': accuracy}
# 'accuracy_top_5': accuracy_top_5}
accuracy_top_5 = tf.metrics.mean(tf.nn.in_top_k(predictions=logits,
targets=labels,
k=5,
name='top_5_op'))
metrics = {'accuracy': accuracy,
'accuracy_top_5': accuracy_top_5}
# Create a tensor named train_accuracy for logging purposes
tf.identity(accuracy[1], name='train_accuracy')
#tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
tf.summary.scalar('train_accuracy', accuracy[1])
#tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])
tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op)
#eval_metric_ops=metrics)
train_op=train_op,
eval_metric_ops=metrics)
def resnet_main(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment