Commit 9b7e4163 authored by Shawn Wang's avatar Shawn Wang
Browse files

Allow data async generation to be run as a separate job rather than as a subprocess.

parent 42f98218
...@@ -64,6 +64,8 @@ DUPLICATE_MASK = "duplicate_mask" ...@@ -64,6 +64,8 @@ DUPLICATE_MASK = "duplicate_mask"
CYCLES_TO_BUFFER = 3 # The number of train cycles worth of data to "run ahead" CYCLES_TO_BUFFER = 3 # The number of train cycles worth of data to "run ahead"
# of the main training loop. # of the main training loop.
# The training job writes the async-generation arguments as JSON to
# COMMAND_FILE_TEMP and then atomically renames it to COMMAND_FILE, so the
# separate generation job never observes a partially written command file.
COMMAND_FILE_TEMP = "command.json.temp"
COMMAND_FILE = "command.json"
READY_FILE_TEMP = "ready.json.temp" READY_FILE_TEMP = "ready.json.temp"
READY_FILE = "ready.json" READY_FILE = "ready.json"
TRAIN_RECORD_TEMPLATE = "train_{}.tfrecords" TRAIN_RECORD_TEMPLATE = "train_{}.tfrecords"
......
...@@ -50,6 +50,10 @@ _log_file = None ...@@ -50,6 +50,10 @@ _log_file = None
def log_msg(msg): def log_msg(msg):
"""Include timestamp info when logging messages to a file.""" """Include timestamp info when logging messages to a file."""
if flags.FLAGS.use_command_file:
tf.logging.info(msg)
return
if flags.FLAGS.redirect_logs: if flags.FLAGS.redirect_logs:
timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
print("[{}] {}".format(timestamp, msg), file=_log_file) print("[{}] {}".format(timestamp, msg), file=_log_file)
...@@ -207,8 +211,7 @@ def _construct_training_records( ...@@ -207,8 +211,7 @@ def _construct_training_records(
map_args = [(shard, num_items, num_neg, process_seeds[i]) map_args = [(shard, num_items, num_neg, process_seeds[i])
for i, shard in enumerate(training_shards * epochs_per_cycle)] for i, shard in enumerate(training_shards * epochs_per_cycle)]
with contextlib.closing(multiprocessing.Pool( with popen_helper.get_pool(num_workers, init_worker) as pool:
processes=num_workers, initializer=init_worker)) as pool:
map_fn = pool.imap if deterministic else pool.imap_unordered # pylint: disable=no-member map_fn = pool.imap if deterministic else pool.imap_unordered # pylint: disable=no-member
data_generator = map_fn(_process_shard, map_args) data_generator = map_fn(_process_shard, map_args)
data = [ data = [
...@@ -436,8 +439,39 @@ def _generation_loop(num_workers, # type: int ...@@ -436,8 +439,39 @@ def _generation_loop(num_workers, # type: int
gc.collect() gc.collect()
def _set_flags_with_command_file():
  """Overwrite flags with the arguments found in COMMAND_FILE.

  Used when the data generation runs as a separate job (use_command_file is
  True) instead of a subprocess with CLI arguments. Blocks (polling once per
  second) until the training job atomically renames the command JSON into
  place under data_dir, then copies each argument into flags.FLAGS.

  Raises:
    ValueError: If the command's data_dir does not match the directory being
      watched, or if the command requests log redirection (unsupported in
      command-file mode, where logging goes through tf.logging instead).
  """
  command_file = os.path.join(flags.FLAGS.data_dir,
                              rconst.COMMAND_FILE)
  tf.logging.info("Waiting for command file to appear at {}..."
                  .format(command_file))
  while not tf.gfile.Exists(command_file):
    time.sleep(1)
  tf.logging.info("Command file found.")
  with tf.gfile.Open(command_file, "r") as f:
    command = json.load(f)

  # data_dir is the rendezvous point between the training job and this job;
  # a mismatched value means the command was meant for a different run.
  # (Explicit raise rather than assert: asserts are stripped under -O.)
  if command["data_dir"] != flags.FLAGS.data_dir:
    raise ValueError("Command data_dir {} does not match flag value {}"
                     .format(command["data_dir"], flags.FLAGS.data_dir))
  if command["redirect_logs"]:
    raise ValueError("redirect_logs must be False in command file mode.")

  # Copy every remaining argument verbatim from the command payload.
  for key in ("num_workers", "cache_id", "num_readers", "num_neg",
              "num_train_positives", "num_items", "epochs_per_cycle",
              "train_batch_size", "eval_batch_size", "spillover",
              "redirect_logs"):
    setattr(flags.FLAGS, key, command[key])

  # The seed is only present for deterministic runs.
  if "seed" in command:
    flags.FLAGS.seed = command["seed"]
def main(_): def main(_):
global _log_file global _log_file
if flags.FLAGS.use_command_file is not None:
_set_flags_with_command_file()
redirect_logs = flags.FLAGS.redirect_logs redirect_logs = flags.FLAGS.redirect_logs
cache_paths = rconst.Paths( cache_paths = rconst.Paths(
data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id) data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
...@@ -489,16 +523,12 @@ def main(_): ...@@ -489,16 +523,12 @@ def main(_):
def define_flags(): def define_flags():
"""Construct flags for the server. """Construct flags for the server."""
This function does not use offical.utils.flags, as these flags are not meant
to be used by humans. Rather, they should be passed as part of a subprocess
call.
"""
flags.DEFINE_integer(name="num_workers", default=multiprocessing.cpu_count(), flags.DEFINE_integer(name="num_workers", default=multiprocessing.cpu_count(),
help="Size of the negative generation worker pool.") help="Size of the negative generation worker pool.")
flags.DEFINE_string(name="data_dir", default=None, flags.DEFINE_string(name="data_dir", default=None,
help="The data root. (used to construct cache paths.)") help="The data root. (used to construct cache paths.)")
flags.mark_flags_as_required(["data_dir"])
flags.DEFINE_string(name="cache_id", default=None, flags.DEFINE_string(name="cache_id", default=None,
help="The cache_id generated in the main process.") help="The cache_id generated in the main process.")
flags.DEFINE_integer(name="num_readers", default=4, flags.DEFINE_integer(name="num_readers", default=4,
...@@ -531,11 +561,9 @@ def define_flags(): ...@@ -531,11 +561,9 @@ def define_flags():
flags.DEFINE_integer(name="seed", default=None, flags.DEFINE_integer(name="seed", default=None,
help="NumPy random seed to set at startup. If not " help="NumPy random seed to set at startup. If not "
"specified, a seed will not be set.") "specified, a seed will not be set.")
flags.DEFINE_boolean(name="use_command_file", default=False,
flags.mark_flags_as_required( help="Use command arguments from json at command_path. "
["data_dir", "cache_id", "num_neg", "num_train_positives", "num_items", "All arguments other than data_dir will be ignored.")
"train_batch_size", "eval_batch_size"])
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -416,7 +416,8 @@ def _shutdown(proc): ...@@ -416,7 +416,8 @@ def _shutdown(proc):
def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size, def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
num_data_readers=None, num_neg=4, epochs_per_cycle=1, num_data_readers=None, num_neg=4, epochs_per_cycle=1,
match_mlperf=False, deterministic=False): match_mlperf=False, deterministic=False,
use_subprocess=True):
# type: (...) -> (NCFDataset, typing.Callable) # type: (...) -> (NCFDataset, typing.Callable)
"""Preprocess data and start negative generation subprocess.""" """Preprocess data and start negative generation subprocess."""
...@@ -425,7 +426,11 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size, ...@@ -425,7 +426,11 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
num_data_readers=num_data_readers, num_data_readers=num_data_readers,
match_mlperf=match_mlperf, match_mlperf=match_mlperf,
deterministic=deterministic) deterministic=deterministic)
# By limiting the number of workers we guarantee that the worker
# pool underlying the training generation doesn't starve other processes.
num_workers = int(multiprocessing.cpu_count() * 0.75) or 1
if use_subprocess:
tf.logging.info("Creating training file subprocess.") tf.logging.info("Creating training file subprocess.")
subproc_env = os.environ.copy() subproc_env = os.environ.copy()
...@@ -435,10 +440,6 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size, ...@@ -435,10 +440,6 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
# contention with the main training process. # contention with the main training process.
subproc_env["CUDA_VISIBLE_DEVICES"] = "" subproc_env["CUDA_VISIBLE_DEVICES"] = ""
# By limiting the number of workers we guarantee that the worker
# pool underlying the training generation doesn't starve other processes.
num_workers = int(multiprocessing.cpu_count() * 0.75) or 1
subproc_args = popen_helper.INVOCATION + [ subproc_args = popen_helper.INVOCATION + [
"--data_dir", data_dir, "--data_dir", data_dir,
"--cache_id", str(ncf_dataset.cache_paths.cache_id), "--cache_id", str(ncf_dataset.cache_paths.cache_id),
...@@ -450,10 +451,10 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size, ...@@ -450,10 +451,10 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
"--train_batch_size", str(batch_size), "--train_batch_size", str(batch_size),
"--eval_batch_size", str(eval_batch_size), "--eval_batch_size", str(eval_batch_size),
"--num_workers", str(num_workers), "--num_workers", str(num_workers),
"--spillover", "True", # This allows the training input function to # This allows the training input function to guarantee batch size and
# guarantee batch size and significantly improves # significantly improves performance. (~5% increase in examples/sec on
# performance. (~5% increase in examples/sec on
# GPU, and needed for TPU XLA.) # GPU, and needed for TPU XLA.)
"--spillover", "True",
"--redirect_logs", "True" "--redirect_logs", "True"
] ]
if ncf_dataset.deterministic: if ncf_dataset.deterministic:
...@@ -464,6 +465,42 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size, ...@@ -464,6 +465,42 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env) proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
else:
# We write to a temp file then atomically rename it to the final file,
# because writing directly to the final file can cause the data generation
# async process to read a partially written JSON file.
command_file_temp = os.path.join(data_dir, rconst.COMMAND_FILE_TEMP)
tf.logging.info("Generation subprocess command at {} ..."
.format(command_file_temp))
with tf.gfile.Open(command_file_temp, "w") as f:
command = {
"data_dir": data_dir,
"cache_id": ncf_dataset.cache_paths.cache_id,
"num_neg": num_neg,
"num_train_positives": ncf_dataset.num_train_positives,
"num_items": ncf_dataset.num_items,
"num_readers": ncf_dataset.num_data_readers,
"epochs_per_cycle": epochs_per_cycle,
"train_batch_size": batch_size,
"eval_batch_size": eval_batch_size,
"num_workers": num_workers,
# This allows the training input function to guarantee batch size and
# significantly improves performance. (~5% increase in examples/sec on
# GPU, and needed for TPU XLA.)
"spillover": True,
"redirect_logs": False
}
if ncf_dataset.deterministic:
command["seed"] = stat_utils.random_int32()
json.dump(command, f)
command_file = os.path.join(data_dir, rconst.COMMAND_FILE)
tf.gfile.Rename(command_file_temp, command_file)
tf.logging.info(
"Generation subprocess command saved to: {}"
.format(command_file))
cleanup_called = {"finished": False} cleanup_called = {"finished": False}
@atexit.register @atexit.register
def cleanup(): def cleanup():
...@@ -471,7 +508,9 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size, ...@@ -471,7 +508,9 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
if cleanup_called["finished"]: if cleanup_called["finished"]:
return return
if use_subprocess:
_shutdown(proc) _shutdown(proc)
try: try:
tf.gfile.DeleteRecursively(ncf_dataset.cache_paths.cache_root) tf.gfile.DeleteRecursively(ncf_dataset.cache_paths.cache_root)
except tf.errors.NotFoundError: except tf.errors.NotFoundError:
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# ============================================================================== # ==============================================================================
"""Helper file for running the async data generation process in OSS.""" """Helper file for running the async data generation process in OSS."""
import contextlib
import multiprocessing
import os import os
import sys import sys
...@@ -27,3 +29,8 @@ _ASYNC_GEN_PATH = os.path.join(os.path.dirname(__file__), ...@@ -27,3 +29,8 @@ _ASYNC_GEN_PATH = os.path.join(os.path.dirname(__file__),
"data_async_generation.py") "data_async_generation.py")
INVOCATION = [_PYTHON, _ASYNC_GEN_PATH] INVOCATION = [_PYTHON, _ASYNC_GEN_PATH]
def get_pool(num_workers, init_worker=None):
  """Construct a worker pool usable as a context manager.

  multiprocessing.Pool's own __exit__ calls terminate(), which kills workers
  abruptly; wrapping the pool in contextlib.closing() calls close() on exit
  instead, letting in-flight tasks drain cleanly.

  Args:
    num_workers: Number of worker processes in the pool.
    init_worker: Optional callable run once in each worker process at startup.

  Returns:
    A contextlib.closing wrapper yielding the multiprocessing.Pool.
  """
  return contextlib.closing(multiprocessing.Pool(
      processes=num_workers, initializer=init_worker))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment