Commit 1024f926 authored by Andrew M. Dai's avatar Andrew M. Dai
Browse files

Fix eval writing multiple summary files.

PiperOrigin-RevId: 171498799
parent f51da4bb
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
# Binaries # Binaries
# ============================================================================== # ==============================================================================
py_binary( py_binary(
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
# Dependency imports # Dependency imports
......
...@@ -25,7 +25,7 @@ import time ...@@ -25,7 +25,7 @@ import time
import tensorflow as tf import tensorflow as tf
import graphs from adversarial_text import graphs
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -75,7 +75,8 @@ def run_eval(eval_ops, summary_writer, saver): ...@@ -75,7 +75,8 @@ def run_eval(eval_ops, summary_writer, saver):
Returns: Returns:
dict<metric name, value>, with value being the average over all examples. dict<metric name, value>, with value being the average over all examples.
""" """
sv = tf.train.Supervisor(logdir=FLAGS.eval_dir, saver=None, summary_op=None) sv = tf.train.Supervisor(
logdir=FLAGS.eval_dir, saver=None, summary_op=None, summary_writer=None)
with sv.managed_session( with sv.managed_session(
master=FLAGS.master, start_standard_services=False) as sess: master=FLAGS.master, start_standard_services=False) as sess:
if not restore_from_checkpoint(sess, saver): if not restore_from_checkpoint(sess, saver):
...@@ -113,6 +114,7 @@ def _log_values(sess, value_ops, summary_writer=None): ...@@ -113,6 +114,7 @@ def _log_values(sess, value_ops, summary_writer=None):
if summary_writer is not None: if summary_writer is not None:
global_step_val = sess.run(tf.train.get_global_step()) global_step_val = sess.run(tf.train.get_global_step())
tf.logging.info('Finished eval for step ' + str(global_step_val))
summary_writer.add_summary(summary, global_step_val) summary_writer.add_summary(summary, global_step_val)
......
...@@ -24,9 +24,9 @@ import os ...@@ -24,9 +24,9 @@ import os
import tensorflow as tf import tensorflow as tf
import adversarial_losses as adv_lib from adversarial_text import adversarial_losses as adv_lib
import inputs as inputs_lib from adversarial_text import inputs as inputs_lib
import layers as layers_lib from adversarial_text import layers as layers_lib
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -116,7 +116,7 @@ class VatxtModel(object): ...@@ -116,7 +116,7 @@ class VatxtModel(object):
""" """
def __init__(self, cl_logits_input_dim=None): def __init__(self, cl_logits_input_dim=None):
self.global_step = tf.train.get_or_create_global_step() self.global_step = tf.contrib.framework.get_or_create_global_step()
self.vocab_freqs = _get_vocab_freqs() self.vocab_freqs = _get_vocab_freqs()
# Cache VatxtInput objects # Cache VatxtInput objects
......
...@@ -29,7 +29,7 @@ import tempfile ...@@ -29,7 +29,7 @@ import tempfile
import tensorflow as tf import tensorflow as tf
import graphs from adversarial_text import graphs
from adversarial_text.data import data_utils from adversarial_text.data import data_utils
flags = tf.app.flags flags = tf.app.flags
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
# Dependency imports # Dependency imports
......
...@@ -27,8 +27,8 @@ from __future__ import print_function ...@@ -27,8 +27,8 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
import graphs from adversarial_text import graphs
import train_utils from adversarial_text import train_utils
FLAGS = tf.app.flags.FLAGS FLAGS = tf.app.flags.FLAGS
......
...@@ -35,8 +35,8 @@ from __future__ import print_function ...@@ -35,8 +35,8 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
import graphs from adversarial_text import graphs
import train_utils from adversarial_text import train_utils
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment