Commit 1024f926 authored by Andrew M. Dai's avatar Andrew M. Dai
Browse files

Fix eval writing multiple summary files.

PiperOrigin-RevId: 171498799
parent f51da4bb
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
# Binaries
# ==============================================================================
py_binary(
......
......@@ -16,7 +16,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
# Dependency imports
......
......@@ -25,7 +25,7 @@ import time
import tensorflow as tf
import graphs
from adversarial_text import graphs
flags = tf.app.flags
FLAGS = flags.FLAGS
......@@ -75,7 +75,8 @@ def run_eval(eval_ops, summary_writer, saver):
Returns:
dict<metric name, value>, with value being the average over all examples.
"""
sv = tf.train.Supervisor(logdir=FLAGS.eval_dir, saver=None, summary_op=None)
sv = tf.train.Supervisor(
logdir=FLAGS.eval_dir, saver=None, summary_op=None, summary_writer=None)
with sv.managed_session(
master=FLAGS.master, start_standard_services=False) as sess:
if not restore_from_checkpoint(sess, saver):
......@@ -113,6 +114,7 @@ def _log_values(sess, value_ops, summary_writer=None):
if summary_writer is not None:
global_step_val = sess.run(tf.train.get_global_step())
tf.logging.info('Finished eval for step ' + str(global_step_val))
summary_writer.add_summary(summary, global_step_val)
......
......@@ -24,9 +24,9 @@ import os
import tensorflow as tf
import adversarial_losses as adv_lib
import inputs as inputs_lib
import layers as layers_lib
from adversarial_text import adversarial_losses as adv_lib
from adversarial_text import inputs as inputs_lib
from adversarial_text import layers as layers_lib
flags = tf.app.flags
FLAGS = flags.FLAGS
......@@ -116,7 +116,7 @@ class VatxtModel(object):
"""
def __init__(self, cl_logits_input_dim=None):
self.global_step = tf.train.get_or_create_global_step()
self.global_step = tf.contrib.framework.get_or_create_global_step()
self.vocab_freqs = _get_vocab_freqs()
# Cache VatxtInput objects
......
......@@ -29,7 +29,7 @@ import tempfile
import tensorflow as tf
import graphs
from adversarial_text import graphs
from adversarial_text.data import data_utils
flags = tf.app.flags
......
......@@ -16,7 +16,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
# Dependency imports
......
......@@ -27,8 +27,8 @@ from __future__ import print_function
import tensorflow as tf
import graphs
import train_utils
from adversarial_text import graphs
from adversarial_text import train_utils
FLAGS = tf.app.flags.FLAGS
......
......@@ -35,8 +35,8 @@ from __future__ import print_function
import tensorflow as tf
import graphs
import train_utils
from adversarial_text import graphs
from adversarial_text import train_utils
flags = tf.app.flags
FLAGS = flags.FLAGS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment