Commit 58037d2c authored by Reed's avatar Reed Committed by Taylor Robie
Browse files

Fix bug where data_async_generation.py would freeze. (#4989)

The data_async_generation.py process would print to stderr, but the main process would redirect it's stderr to a pipe. The main process never read from the pipe, so when the pipe was full, data_async_generation.py would stall on a write to stderr. This change makes data_async_generation.py not write to stdout/stderr.
parent 4acdc508
......@@ -23,7 +23,6 @@ import contextlib
import datetime
import gc
import functools
import logging
import multiprocessing
import json
import os
......@@ -40,7 +39,6 @@ import numpy as np
import tensorflow as tf
from absl import app as absl_app
from absl import logging as absl_logging
from absl import flags
from official.datasets import movielens
......@@ -48,15 +46,18 @@ from official.recommendation import constants as rconst
from official.recommendation import stat_utils
_log_file = None
def log_msg(msg):
"""Include timestamp info when logging messages to a file."""
if flags.FLAGS.redirect_logs:
timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
absl_logging.info("[{}] {}".format(timestamp, msg))
print("[{}] {}".format(timestamp, msg), file=_log_file)
else:
absl_logging.info(msg)
sys.stdout.flush()
sys.stderr.flush()
print(msg, file=_log_file)
if _log_file:
_log_file.flush()
def get_cycle_folder_name(i):
......@@ -395,61 +396,54 @@ def _generation_loop(
def main(_):
global _log_file
redirect_logs = flags.FLAGS.redirect_logs
cache_paths = rconst.Paths(
data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
log_file = os.path.join(cache_paths.data_dir, log_file_name)
if log_file.startswith("gs://") and redirect_logs:
log_path = os.path.join(cache_paths.data_dir, log_file_name)
if log_path.startswith("gs://") and redirect_logs:
fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
print("Unable to log to {}. Falling back to {}"
.format(log_file, fallback_log_file))
log_file = fallback_log_file
.format(log_path, fallback_log_file))
log_path = fallback_log_file
# This server is generally run in a subprocess.
if redirect_logs:
print("Redirecting stdout and stderr to {}".format(log_file))
log_stream = open(log_file, "wt") # Note: not tf.gfile.Open().
stdout = log_stream
stderr = log_stream
print("Redirecting output of data_async_generation.py process to {}"
.format(log_path))
_log_file = open(log_path, "wt") # Note: not tf.gfile.Open().
try:
if redirect_logs:
absl_logging.get_absl_logger().addHandler(
hdlr=logging.StreamHandler(stream=stdout))
sys.stdout = stdout
sys.stderr = stderr
print("Logs redirected.")
try:
log_msg("sys.argv: {}".format(" ".join(sys.argv)))
if flags.FLAGS.seed is not None:
np.random.seed(flags.FLAGS.seed)
_generation_loop(
num_workers=flags.FLAGS.num_workers,
cache_paths=cache_paths,
num_readers=flags.FLAGS.num_readers,
num_neg=flags.FLAGS.num_neg,
num_train_positives=flags.FLAGS.num_train_positives,
num_items=flags.FLAGS.num_items,
spillover=flags.FLAGS.spillover,
epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
train_batch_size=flags.FLAGS.train_batch_size,
eval_batch_size=flags.FLAGS.eval_batch_size,
)
except KeyboardInterrupt:
log_msg("KeyboardInterrupt registered.")
except:
traceback.print_exc()
raise
log_msg("sys.argv: {}".format(" ".join(sys.argv)))
if flags.FLAGS.seed is not None:
np.random.seed(flags.FLAGS.seed)
_generation_loop(
num_workers=flags.FLAGS.num_workers,
cache_paths=cache_paths,
num_readers=flags.FLAGS.num_readers,
num_neg=flags.FLAGS.num_neg,
num_train_positives=flags.FLAGS.num_train_positives,
num_items=flags.FLAGS.num_items,
spillover=flags.FLAGS.spillover,
epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
train_batch_size=flags.FLAGS.train_batch_size,
eval_batch_size=flags.FLAGS.eval_batch_size,
)
except KeyboardInterrupt:
log_msg("KeyboardInterrupt registered.")
except:
traceback.print_exc(file=_log_file)
raise
finally:
log_msg("Shutting down generation subprocess.")
sys.stdout.flush()
sys.stderr.flush()
if redirect_logs:
log_stream.close()
_log_file.close()
def define_flags():
......
......@@ -419,9 +419,7 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
tf.logging.info(
"Generation subprocess command: {}".format(" ".join(subproc_args)))
proc = subprocess.Popen(args=subproc_args, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
shell=False, env=subproc_env)
proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
atexit.register(_shutdown, proc=proc)
atexit.register(tf.gfile.DeleteRecursively,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment