Commit 58037d2c authored by Reed's avatar Reed Committed by Taylor Robie
Browse files

Fix bug where data_async_generation.py would freeze. (#4989)

The data_async_generation.py process would print to stderr, but the main process would redirect its stderr to a pipe. The main process never read from the pipe, so when the pipe was full, data_async_generation.py would stall on a write to stderr. This change makes data_async_generation.py not write to stdout/stderr.
parent 4acdc508
...@@ -23,7 +23,6 @@ import contextlib ...@@ -23,7 +23,6 @@ import contextlib
import datetime import datetime
import gc import gc
import functools import functools
import logging
import multiprocessing import multiprocessing
import json import json
import os import os
...@@ -40,7 +39,6 @@ import numpy as np ...@@ -40,7 +39,6 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from absl import app as absl_app from absl import app as absl_app
from absl import logging as absl_logging
from absl import flags from absl import flags
from official.datasets import movielens from official.datasets import movielens
...@@ -48,15 +46,18 @@ from official.recommendation import constants as rconst ...@@ -48,15 +46,18 @@ from official.recommendation import constants as rconst
from official.recommendation import stat_utils from official.recommendation import stat_utils
_log_file = None
def log_msg(msg): def log_msg(msg):
"""Include timestamp info when logging messages to a file.""" """Include timestamp info when logging messages to a file."""
if flags.FLAGS.redirect_logs: if flags.FLAGS.redirect_logs:
timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
absl_logging.info("[{}] {}".format(timestamp, msg)) print("[{}] {}".format(timestamp, msg), file=_log_file)
else: else:
absl_logging.info(msg) print(msg, file=_log_file)
sys.stdout.flush() if _log_file:
sys.stderr.flush() _log_file.flush()
def get_cycle_folder_name(i): def get_cycle_folder_name(i):
...@@ -395,32 +396,25 @@ def _generation_loop( ...@@ -395,32 +396,25 @@ def _generation_loop(
def main(_): def main(_):
global _log_file
redirect_logs = flags.FLAGS.redirect_logs redirect_logs = flags.FLAGS.redirect_logs
cache_paths = rconst.Paths( cache_paths = rconst.Paths(
data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id) data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id) log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
log_file = os.path.join(cache_paths.data_dir, log_file_name) log_path = os.path.join(cache_paths.data_dir, log_file_name)
if log_file.startswith("gs://") and redirect_logs: if log_path.startswith("gs://") and redirect_logs:
fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name) fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
print("Unable to log to {}. Falling back to {}" print("Unable to log to {}. Falling back to {}"
.format(log_file, fallback_log_file)) .format(log_path, fallback_log_file))
log_file = fallback_log_file log_path = fallback_log_file
# This server is generally run in a subprocess. # This server is generally run in a subprocess.
if redirect_logs: if redirect_logs:
print("Redirecting stdout and stderr to {}".format(log_file)) print("Redirecting output of data_async_generation.py process to {}"
log_stream = open(log_file, "wt") # Note: not tf.gfile.Open(). .format(log_path))
stdout = log_stream _log_file = open(log_path, "wt") # Note: not tf.gfile.Open().
stderr = log_stream
try:
if redirect_logs:
absl_logging.get_absl_logger().addHandler(
hdlr=logging.StreamHandler(stream=stdout))
sys.stdout = stdout
sys.stderr = stderr
print("Logs redirected.")
try: try:
log_msg("sys.argv: {}".format(" ".join(sys.argv))) log_msg("sys.argv: {}".format(" ".join(sys.argv)))
...@@ -442,14 +436,14 @@ def main(_): ...@@ -442,14 +436,14 @@ def main(_):
except KeyboardInterrupt: except KeyboardInterrupt:
log_msg("KeyboardInterrupt registered.") log_msg("KeyboardInterrupt registered.")
except: except:
traceback.print_exc() traceback.print_exc(file=_log_file)
raise raise
finally: finally:
log_msg("Shutting down generation subprocess.") log_msg("Shutting down generation subprocess.")
sys.stdout.flush() sys.stdout.flush()
sys.stderr.flush() sys.stderr.flush()
if redirect_logs: if redirect_logs:
log_stream.close() _log_file.close()
def define_flags(): def define_flags():
......
...@@ -419,9 +419,7 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size, ...@@ -419,9 +419,7 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
tf.logging.info( tf.logging.info(
"Generation subprocess command: {}".format(" ".join(subproc_args))) "Generation subprocess command: {}".format(" ".join(subproc_args)))
proc = subprocess.Popen(args=subproc_args, stdin=subprocess.PIPE, proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
shell=False, env=subproc_env)
atexit.register(_shutdown, proc=proc) atexit.register(_shutdown, proc=proc)
atexit.register(tf.gfile.DeleteRecursively, atexit.register(tf.gfile.DeleteRecursively,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment