"src/runtime/vscode:/vscode.git/clone" did not exist on "cb2327d4a93c042f3c6fe1c42fe8f4c31f087e3b"
Unverified commit 91b2debd, authored by Taylor Robie, committed by GitHub

Make flagfile sharing robust to distributed filesystems and multi-worker setups. (#5521)

* move flagfile into the cache_dir

* remove duplicate code

* delint
parent 0c5c3a77
@@ -440,10 +440,8 @@ def _generation_loop(num_workers,  # type: int
   gc.collect()


-def _parse_flagfile():
+def _parse_flagfile(flagfile):
   """Fill flags with flagfile written by the main process."""
-  flagfile = os.path.join(flags.FLAGS.data_dir,
-                          rconst.FLAGFILE)
   tf.logging.info("Waiting for flagfile to appear at {}..."
                   .format(flagfile))
   start_time = time.time()
@@ -455,18 +453,26 @@ def _parse_flagfile():
       sys.exit()
     time.sleep(1)
   tf.logging.info("flagfile found.")
-  # This overrides FLAGS with flags from flagfile.
-  flags.FLAGS([__file__, "--flagfile", flagfile])
+
+  # `flags` module opens `flagfile` with `open`, which does not work on
+  # google cloud storage etc.
+  _, flagfile_temp = tempfile.mkstemp()
+  tf.gfile.Copy(flagfile, flagfile_temp, overwrite=True)
+
+  flags.FLAGS([__file__, "--flagfile", flagfile_temp])
+  tf.gfile.Remove(flagfile_temp)


 def main(_):
   global _log_file
-  _parse_flagfile()
-  redirect_logs = flags.FLAGS.redirect_logs
   cache_paths = rconst.Paths(
       data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
+  flagfile = os.path.join(cache_paths.cache_root, rconst.FLAGFILE)
+  _parse_flagfile(flagfile)
+
+  redirect_logs = flags.FLAGS.redirect_logs
   log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
   log_path = os.path.join(cache_paths.data_dir, log_file_name)
   if log_path.startswith("gs://") and redirect_logs:
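For readers following along, here is a self-contained sketch of the pattern the new `_parse_flagfile` relies on: absl reads a `--flagfile` with the builtin `open`, which cannot handle `gs://` URIs, so the file is first staged to a local temp path via `tf.gfile`. The function name, timeout parameter, and exit message below are illustrative, not part of the commit:

import os
import sys
import tempfile
import time

import tensorflow as tf
from absl import flags


def parse_remote_flagfile(flagfile, timeout_sec=300):
  """Wait for `flagfile` to appear, then parse it via a local temp copy."""
  start_time = time.time()
  while not tf.gfile.Exists(flagfile):
    if time.time() - start_time > timeout_sec:
      sys.exit("Timed out waiting for flagfile: {}".format(flagfile))
    time.sleep(1)

  # Stage the file locally: tf.gfile understands gs:// paths, `open` does not.
  fd, local_copy = tempfile.mkstemp()
  os.close(fd)
  tf.gfile.Copy(flagfile, local_copy, overwrite=True)
  try:
    # This overrides already-parsed FLAGS with the values from the flagfile.
    flags.FLAGS([__file__, "--flagfile", local_copy])
  finally:
    tf.gfile.Remove(local_copy)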
@@ -518,7 +524,6 @@ def define_flags():
                        help="Size of the negative generation worker pool.")
   flags.DEFINE_string(name="data_dir", default=None,
                       help="The data root. (used to construct cache paths.)")
-  flags.mark_flags_as_required(["data_dir"])
   flags.DEFINE_string(name="cache_id", default=None,
                       help="The cache_id generated in the main process.")
   flags.DEFINE_integer(name="num_readers", default=4,
@@ -554,6 +559,7 @@ def define_flags():
                        help="NumPy random seed to set at startup. If not "
                             "specified, a seed will not be set.")
+  flags.mark_flags_as_required(["data_dir", "cache_id"])


 if __name__ == "__main__":
   define_flags()
@@ -357,8 +357,8 @@ def generate_train_eval_data(df, approx_num_shards, num_items, cache_paths,

 def construct_cache(dataset, data_dir, num_data_readers, match_mlperf,
-                    deterministic):
-  # type: (str, str, int, bool) -> NCFDataset
+                    deterministic, cache_id=None):
+  # type: (str, str, int, bool, typing.Optional[int]) -> NCFDataset
   """Load and digest data CSV into a usable form.

   Args:
@@ -371,7 +371,7 @@ def construct_cache(dataset, data_dir, num_data_readers, match_mlperf,
     deterministic: Try to enforce repeatable behavior, even at the cost of
       performance.
   """
-  cache_paths = rconst.Paths(data_dir=data_dir)
+  cache_paths = rconst.Paths(data_dir=data_dir, cache_id=cache_id)
   num_data_readers = (num_data_readers or int(multiprocessing.cpu_count() / 2)
                       or 1)
   approx_num_shards = int(movielens.NUM_RATINGS[dataset]
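The diff never shows `rconst.Paths` itself; the change only depends on one property: any two processes that construct `Paths` from the same `data_dir` and `cache_id` must resolve the same `cache_root`. A hypothetical sketch of such a helper, with an invented directory layout (the real `rconst.Paths` may differ):

import os
import time


class CachePaths(object):
  """Hypothetical stand-in for rconst.Paths; real layout may differ."""

  def __init__(self, data_dir, cache_id=None):
    self.data_dir = data_dir
    # Default to a timestamp so a single-process run needs no coordination;
    # multi-worker runs pass an explicit cache_id so every process agrees.
    self.cache_id = cache_id if cache_id is not None else int(time.time())
    self.cache_root = os.path.join(
        self.data_dir, "cache_{}".format(self.cache_id))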
@@ -436,7 +436,7 @@ def _shutdown(proc):

 def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
                          num_data_readers=None, num_neg=4, epochs_per_cycle=1,
                          match_mlperf=False, deterministic=False,
-                         use_subprocess=True):
+                         use_subprocess=True, cache_id=None):
   # type: (...) -> (NCFDataset, typing.Callable)
   """Preprocess data and start negative generation subprocess."""
@@ -444,7 +444,8 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
   ncf_dataset = construct_cache(dataset=dataset, data_dir=data_dir,
                                 num_data_readers=num_data_readers,
                                 match_mlperf=match_mlperf,
-                                deterministic=deterministic)
+                                deterministic=deterministic,
+                                cache_id=cache_id)

   # By limiting the number of workers we guarantee that the worker
   # pool underlying the training generation doesn't starve other processes.
   num_workers = int(multiprocessing.cpu_count() * 0.75) or 1
@@ -473,13 +474,14 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
   # We write to a temp file then atomically rename it to the final file,
   # because writing directly to the final file can cause the data generation
   # async process to read a partially written JSON file.
-  flagfile_temp = os.path.join(data_dir, rconst.FLAGFILE_TEMP)
+  flagfile_temp = os.path.join(ncf_dataset.cache_paths.cache_root,
+                               rconst.FLAGFILE_TEMP)

   tf.logging.info("Preparing flagfile for async data generation in {} ..."
                   .format(flagfile_temp))
   with tf.gfile.Open(flagfile_temp, "w") as f:
     for k, v in six.iteritems(flags_):
       f.write("--{}={}\n".format(k, v))
-  flagfile = os.path.join(data_dir, rconst.FLAGFILE)
+  flagfile = os.path.join(ncf_dataset.cache_paths.cache_root, rconst.FLAGFILE)
   tf.gfile.Rename(flagfile_temp, flagfile)
   tf.logging.info(
       "Wrote flagfile for async data generation in {}."
@@ -493,7 +495,8 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
     # contention with the main training process.
     subproc_env["CUDA_VISIBLE_DEVICES"] = ""
     subproc_args = popen_helper.INVOCATION + [
-        "--data_dir", data_dir]
+        "--data_dir", data_dir,
+        "--cache_id", str(ncf_dataset.cache_paths.cache_id)]
     tf.logging.info(
         "Generation subprocess command: {}".format(" ".join(subproc_args)))
     proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
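A sketch of this launch step in isolation, with `invocation` standing in for `popen_helper.INVOCATION` (which the diff does not show; something like `["python", "data_async_generation.py"]` is assumed):

import os
import subprocess


def start_async_generation(invocation, data_dir, cache_id):
  """Launch the data generation worker with GPUs hidden from it."""
  env = os.environ.copy()
  # An empty CUDA_VISIBLE_DEVICES keeps the subprocess off the GPUs, so it
  # cannot contend with the main training process for device memory.
  env["CUDA_VISIBLE_DEVICES"] = ""
  args = invocation + ["--data_dir", data_dir,
                       "--cache_id", str(cache_id)]
  return subprocess.Popen(args=args, shell=False, env=env)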
@@ -152,7 +152,8 @@ def run_ncf(_):
       epochs_per_cycle=FLAGS.epochs_between_evals,
       match_mlperf=FLAGS.ml_perf,
       deterministic=FLAGS.seed is not None,
-      use_subprocess=FLAGS.use_subprocess)
+      use_subprocess=FLAGS.use_subprocess,
+      cache_id=FLAGS.cache_id)

   num_users = ncf_dataset.num_users
   num_items = ncf_dataset.num_items
   approx_train_steps = int(ncf_dataset.num_train_positives
@@ -387,6 +388,12 @@ def define_ncf_flags():
           "subprocess. If set to False, ncf_main.py will assume the async data "
           "generation process has already been started by the user."))

+  flags.DEFINE_integer(name="cache_id", default=None, help=flags_core.help_wrap(
+      "Use a specified cache_id rather than using a timestamp. This is only "
+      "needed to synchronize across multiple workers. Generally this flag will "
+      "not need to be set."
+  ))
+

 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
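Tying it together, a hypothetical multi-worker use of the new flag. The module name `data_preprocessing`, the dataset name, the batch sizes, and the cache_id value are all illustrative; the call signature matches the `instantiate_pipeline` shown in the diff:

import data_preprocessing  # assumed module name for the pipeline above

# Every process that should share one cache passes the same cache_id, so
# all of them derive the same cache_root and thus the same flagfile path.
SHARED_CACHE_ID = 12345

# Trainer process: skip the built-in subprocess launch and let an
# externally started worker pick up the flagfile instead.
ncf_dataset, cleanup_fn = data_preprocessing.instantiate_pipeline(
    dataset="ml-20m", data_dir="gs://bucket/movielens",
    batch_size=2048, eval_batch_size=4096,
    use_subprocess=False, cache_id=SHARED_CACHE_ID)

# The data generation worker is then started by hand with matching flags:
#   --data_dir gs://bucket/movielens --cache_id 12345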