Commit 58037d2c authored by Reed's avatar Reed Committed by Taylor Robie
Browse files

Fix bug where data_async_generation.py would freeze. (#4989)

The data_async_generation.py process would print to stderr, but the main process would redirect its stderr to a pipe. The main process never read from the pipe, so when the pipe was full, data_async_generation.py would stall on a write to stderr. This change makes data_async_generation.py not write to stdout/stderr.
parent 4acdc508
...@@ -23,7 +23,6 @@ import contextlib ...@@ -23,7 +23,6 @@ import contextlib
import datetime import datetime
import gc import gc
import functools import functools
import logging
import multiprocessing import multiprocessing
import json import json
import os import os
...@@ -40,7 +39,6 @@ import numpy as np ...@@ -40,7 +39,6 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from absl import app as absl_app from absl import app as absl_app
from absl import logging as absl_logging
from absl import flags from absl import flags
from official.datasets import movielens from official.datasets import movielens
...@@ -48,15 +46,18 @@ from official.recommendation import constants as rconst ...@@ -48,15 +46,18 @@ from official.recommendation import constants as rconst
from official.recommendation import stat_utils from official.recommendation import stat_utils
_log_file = None
def log_msg(msg): def log_msg(msg):
"""Include timestamp info when logging messages to a file.""" """Include timestamp info when logging messages to a file."""
if flags.FLAGS.redirect_logs: if flags.FLAGS.redirect_logs:
timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
absl_logging.info("[{}] {}".format(timestamp, msg)) print("[{}] {}".format(timestamp, msg), file=_log_file)
else: else:
absl_logging.info(msg) print(msg, file=_log_file)
sys.stdout.flush() if _log_file:
sys.stderr.flush() _log_file.flush()
def get_cycle_folder_name(i): def get_cycle_folder_name(i):
...@@ -395,61 +396,54 @@ def _generation_loop( ...@@ -395,61 +396,54 @@ def _generation_loop(
def main(_): def main(_):
global _log_file
redirect_logs = flags.FLAGS.redirect_logs redirect_logs = flags.FLAGS.redirect_logs
cache_paths = rconst.Paths( cache_paths = rconst.Paths(
data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id) data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id) log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
log_file = os.path.join(cache_paths.data_dir, log_file_name) log_path = os.path.join(cache_paths.data_dir, log_file_name)
if log_file.startswith("gs://") and redirect_logs: if log_path.startswith("gs://") and redirect_logs:
fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name) fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
print("Unable to log to {}. Falling back to {}" print("Unable to log to {}. Falling back to {}"
.format(log_file, fallback_log_file)) .format(log_path, fallback_log_file))
log_file = fallback_log_file log_path = fallback_log_file
# This server is generally run in a subprocess. # This server is generally run in a subprocess.
if redirect_logs: if redirect_logs:
print("Redirecting stdout and stderr to {}".format(log_file)) print("Redirecting output of data_async_generation.py process to {}"
log_stream = open(log_file, "wt") # Note: not tf.gfile.Open(). .format(log_path))
stdout = log_stream _log_file = open(log_path, "wt") # Note: not tf.gfile.Open().
stderr = log_stream
try: try:
if redirect_logs: log_msg("sys.argv: {}".format(" ".join(sys.argv)))
absl_logging.get_absl_logger().addHandler(
hdlr=logging.StreamHandler(stream=stdout)) if flags.FLAGS.seed is not None:
sys.stdout = stdout np.random.seed(flags.FLAGS.seed)
sys.stderr = stderr
print("Logs redirected.") _generation_loop(
try: num_workers=flags.FLAGS.num_workers,
log_msg("sys.argv: {}".format(" ".join(sys.argv))) cache_paths=cache_paths,
num_readers=flags.FLAGS.num_readers,
if flags.FLAGS.seed is not None: num_neg=flags.FLAGS.num_neg,
np.random.seed(flags.FLAGS.seed) num_train_positives=flags.FLAGS.num_train_positives,
num_items=flags.FLAGS.num_items,
_generation_loop( spillover=flags.FLAGS.spillover,
num_workers=flags.FLAGS.num_workers, epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
cache_paths=cache_paths, train_batch_size=flags.FLAGS.train_batch_size,
num_readers=flags.FLAGS.num_readers, eval_batch_size=flags.FLAGS.eval_batch_size,
num_neg=flags.FLAGS.num_neg, )
num_train_positives=flags.FLAGS.num_train_positives, except KeyboardInterrupt:
num_items=flags.FLAGS.num_items, log_msg("KeyboardInterrupt registered.")
spillover=flags.FLAGS.spillover, except:
epochs_per_cycle=flags.FLAGS.epochs_per_cycle, traceback.print_exc(file=_log_file)
train_batch_size=flags.FLAGS.train_batch_size, raise
eval_batch_size=flags.FLAGS.eval_batch_size,
)
except KeyboardInterrupt:
log_msg("KeyboardInterrupt registered.")
except:
traceback.print_exc()
raise
finally: finally:
log_msg("Shutting down generation subprocess.") log_msg("Shutting down generation subprocess.")
sys.stdout.flush() sys.stdout.flush()
sys.stderr.flush() sys.stderr.flush()
if redirect_logs: if redirect_logs:
log_stream.close() _log_file.close()
def define_flags(): def define_flags():
......
...@@ -419,9 +419,7 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size, ...@@ -419,9 +419,7 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
tf.logging.info( tf.logging.info(
"Generation subprocess command: {}".format(" ".join(subproc_args))) "Generation subprocess command: {}".format(" ".join(subproc_args)))
proc = subprocess.Popen(args=subproc_args, stdin=subprocess.PIPE, proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
shell=False, env=subproc_env)
atexit.register(_shutdown, proc=proc) atexit.register(_shutdown, proc=proc)
atexit.register(tf.gfile.DeleteRecursively, atexit.register(tf.gfile.DeleteRecursively,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment