"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4e898560cefb86525a65d32a662d6b3d6b2b0b82"
Unverified commit 91b2debd, authored by Taylor Robie and committed by GitHub

Make flagfile sharing robust to distributed filesystems and multi-worker setups. (#5521)

* move flagfile into the cache_dir

* remove duplicate code

* delint
parent 0c5c3a77
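
The change addresses two coupled issues: the async generation worker previously derived the flagfile path from --data_dir alone, so concurrent runs and multiple workers could collide, and absl's --flagfile handling reads the file with the builtin open(), which cannot read gs:// paths. The fix is to write the flagfile under the run-specific cache_root and have the worker stage a local copy before parsing. A minimal sketch of the worker-side read, assuming TF 1.x tf.gfile and absl.flags (the helper name is illustrative, not a function in the repo):

```python
# Sketch only: mirrors the pattern this commit adopts.
import sys
import tempfile

import tensorflow as tf
from absl import flags

def read_remote_flagfile(flagfile):
  """Parse a flagfile that may live on GCS or another remote filesystem."""
  # absl opens --flagfile targets with the builtin `open`, which cannot
  # read gs:// paths, so copy the file somewhere local first.
  _, flagfile_temp = tempfile.mkstemp()
  tf.gfile.Copy(flagfile, flagfile_temp, overwrite=True)
  flags.FLAGS([sys.argv[0], "--flagfile", flagfile_temp])
  tf.gfile.Remove(flagfile_temp)
```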
@@ -440,10 +440,8 @@ def _generation_loop(num_workers,  # type: int
   gc.collect()
 
 
-def _parse_flagfile():
+def _parse_flagfile(flagfile):
   """Fill flags with flagfile written by the main process."""
-  flagfile = os.path.join(flags.FLAGS.data_dir,
-                          rconst.FLAGFILE)
   tf.logging.info("Waiting for flagfile to appear at {}..."
                   .format(flagfile))
   start_time = time.time()
@@ -455,18 +453,26 @@ def _parse_flagfile():
     sys.exit()
     time.sleep(1)
   tf.logging.info("flagfile found.")
-  # This overrides FLAGS with flags from flagfile.
-  flags.FLAGS([__file__, "--flagfile", flagfile])
+
+  # `flags` module opens `flagfile` with `open`, which does not work on
+  # google cloud storage etc.
+  _, flagfile_temp = tempfile.mkstemp()
+  tf.gfile.Copy(flagfile, flagfile_temp, overwrite=True)
+
+  flags.FLAGS([__file__, "--flagfile", flagfile_temp])
+  tf.gfile.Remove(flagfile_temp)
 
 
 def main(_):
   global _log_file
-  _parse_flagfile()
-
-  redirect_logs = flags.FLAGS.redirect_logs
   cache_paths = rconst.Paths(
       data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
+
+  flagfile = os.path.join(cache_paths.cache_root, rconst.FLAGFILE)
+  _parse_flagfile(flagfile)
+
+  redirect_logs = flags.FLAGS.redirect_logs
+
   log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
   log_path = os.path.join(cache_paths.data_dir, log_file_name)
   if log_path.startswith("gs://") and redirect_logs:
@@ -518,7 +524,6 @@ def define_flags():
                        help="Size of the negative generation worker pool.")
   flags.DEFINE_string(name="data_dir", default=None,
                       help="The data root. (used to construct cache paths.)")
-  flags.mark_flags_as_required(["data_dir"])
   flags.DEFINE_string(name="cache_id", default=None,
                       help="The cache_id generated in the main process.")
   flags.DEFINE_integer(name="num_readers", default=4,
@@ -554,6 +559,7 @@ def define_flags():
                        help="NumPy random seed to set at startup. If not "
                             "specified, a seed will not be set.")
+  flags.mark_flags_as_required(["data_dir", "cache_id"])
 
 
 if __name__ == "__main__":
   define_flags()
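
Note the relocated mark_flags_as_required: the worker can no longer locate the flagfile from data_dir alone, so both data_dir and cache_id must be supplied before any path resolution happens, and the requirement is now declared once after all flag definitions.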
...
@@ -357,8 +357,8 @@ def generate_train_eval_data(df, approx_num_shards, num_items, cache_paths,
 
 def construct_cache(dataset, data_dir, num_data_readers, match_mlperf,
-                    deterministic):
-  # type: (str, str, int, bool) -> NCFDataset
+                    deterministic, cache_id=None):
+  # type: (str, str, int, bool, typing.Optional[int]) -> NCFDataset
   """Load and digest data CSV into a usable form.
 
   Args:
@@ -371,7 +371,7 @@ def construct_cache(dataset, data_dir, num_data_readers, match_mlperf,
     deterministic: Try to enforce repeatable behavior, even at the cost of
       performance.
   """
-  cache_paths = rconst.Paths(data_dir=data_dir)
+  cache_paths = rconst.Paths(data_dir=data_dir, cache_id=cache_id)
   num_data_readers = (num_data_readers or int(multiprocessing.cpu_count() / 2)
                       or 1)
   approx_num_shards = int(movielens.NUM_RATINGS[dataset]
@@ -436,7 +436,7 @@ def _shutdown(proc):
 def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
                          num_data_readers=None, num_neg=4, epochs_per_cycle=1,
                          match_mlperf=False, deterministic=False,
-                         use_subprocess=True):
+                         use_subprocess=True, cache_id=None):
   # type: (...) -> (NCFDataset, typing.Callable)
   """Preprocess data and start negative generation subprocess."""
@@ -444,7 +444,8 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
   ncf_dataset = construct_cache(dataset=dataset, data_dir=data_dir,
                                 num_data_readers=num_data_readers,
                                 match_mlperf=match_mlperf,
-                                deterministic=deterministic)
+                                deterministic=deterministic,
+                                cache_id=cache_id)
 
   # By limiting the number of workers we guarantee that the worker
   # pool underlying the training generation doesn't starve other processes.
   num_workers = int(multiprocessing.cpu_count() * 0.75) or 1
@@ -473,13 +474,14 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
   # We write to a temp file then atomically rename it to the final file,
   # because writing directly to the final file can cause the data generation
   # async process to read a partially written JSON file.
-  flagfile_temp = os.path.join(data_dir, rconst.FLAGFILE_TEMP)
+  flagfile_temp = os.path.join(ncf_dataset.cache_paths.cache_root,
+                               rconst.FLAGFILE_TEMP)
   tf.logging.info("Preparing flagfile for async data generation in {} ..."
                   .format(flagfile_temp))
   with tf.gfile.Open(flagfile_temp, "w") as f:
     for k, v in six.iteritems(flags_):
       f.write("--{}={}\n".format(k, v))
-  flagfile = os.path.join(data_dir, rconst.FLAGFILE)
+  flagfile = os.path.join(ncf_dataset.cache_paths.cache_root, rconst.FLAGFILE)
   tf.gfile.Rename(flagfile_temp, flagfile)
   tf.logging.info(
       "Wrote flagfile for async data generation in {}."
@@ -493,7 +495,8 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
   # contention with the main training process.
   subproc_env["CUDA_VISIBLE_DEVICES"] = ""
   subproc_args = popen_helper.INVOCATION + [
-      "--data_dir", data_dir]
+      "--data_dir", data_dir,
+      "--cache_id", str(ncf_dataset.cache_paths.cache_id)]
   tf.logging.info(
       "Generation subprocess command: {}".format(" ".join(subproc_args)))
   proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
...
@@ -152,7 +152,8 @@ def run_ncf(_):
       epochs_per_cycle=FLAGS.epochs_between_evals,
       match_mlperf=FLAGS.ml_perf,
       deterministic=FLAGS.seed is not None,
-      use_subprocess=FLAGS.use_subprocess)
+      use_subprocess=FLAGS.use_subprocess,
+      cache_id=FLAGS.cache_id)
   num_users = ncf_dataset.num_users
   num_items = ncf_dataset.num_items
   approx_train_steps = int(ncf_dataset.num_train_positives
@@ -387,6 +388,12 @@ def define_ncf_flags():
           "subprocess. If set to False, ncf_main.py will assume the async data "
           "generation process has already been started by the user."))
 
+  flags.DEFINE_integer(name="cache_id", default=None, help=flags_core.help_wrap(
+      "Use a specified cache_id rather than using a timestamp. This is only "
+      "needed to synchronize across multiple workers. Generally this flag will "
+      "not need to be set."
+  ))
+
 
 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
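
End to end, the main process now pins a cache_id, builds the cache under the resulting cache_root, and forwards the same id to the generation subprocess. An illustrative launch with assumed script name, bucket path, and cache_id value (the repo drives this through popen_helper.INVOCATION instead):

```python
# Illustrative only: the worker script name, bucket path, and cache_id
# value here are assumptions for the example.
import os
import subprocess
import sys

cache_id = 1540000000  # chosen once (e.g. a timestamp) and shared by every worker
subproc_env = os.environ.copy()
subproc_env["CUDA_VISIBLE_DEVICES"] = ""  # keep the generator off the GPUs

subproc_args = [sys.executable, "data_async_generation.py",
                "--data_dir", "gs://my-bucket/movielens",
                "--cache_id", str(cache_id)]
proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
```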
...