Unverified commit 424fe9f6 authored by Reed, committed by GitHub

Have async process end when all data is written. (#5652)

I've noticed that the async process's pool processes sometimes do not die when ncf_main.py ends and kills the async process. This commit fixes that issue.
parent 31ae57eb
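The change in a nutshell: rather than running the generation loop forever and relying on the parent process to kill it with SIGINT, the loop is bounded by the number of training cycles the trainer will actually consume, so the worker pool drains and the subprocess exits on its own once all data is written. A minimal sketch of that pattern follows; _generation_loop_sketch and produce_one_cycle are illustrative placeholders, not functions from the repository.

def _generation_loop_sketch(num_cycles, produce_one_cycle):
  """Bounded producer loop: exits cleanly instead of waiting for SIGINT."""
  train_cycle = 0
  while train_cycle < num_cycles:   # previously: while True
    produce_one_cycle(train_cycle)  # write one cycle's worth of training data
    train_cycle += 1
  # Falling out of the loop lets the multiprocessing pools be closed and
  # joined, so the subprocess terminates on its own after the last cycle.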
@@ -410,6 +410,7 @@ def _generation_loop(num_workers,      # type: int
                      num_items,        # type: int
                      num_users,        # type: int
                      epochs_per_cycle, # type: int
+                     num_cycles,       # type: int
                      train_batch_size, # type: int
                      eval_batch_size,  # type: int
                      deterministic,    # type: bool

@@ -449,7 +450,7 @@ def _generation_loop(num_workers,      # type: int
   wait_count = 0
   start_time = time.time()
-  while True:
+  while train_cycle < num_cycles:
     ready_epochs = tf.gfile.ListDirectory(cache_paths.train_epoch_dir)
     if len(ready_epochs) >= rconst.CYCLES_TO_BUFFER:
       wait_count += 1

@@ -567,6 +568,7 @@ def main(_):
       num_items=flags.FLAGS.num_items,
       num_users=flags.FLAGS.num_users,
       epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
+      num_cycles=flags.FLAGS.num_cycles,
       train_batch_size=flags.FLAGS.train_batch_size,
       eval_batch_size=flags.FLAGS.eval_batch_size,
       deterministic=flags.FLAGS.seed is not None,

@@ -608,6 +610,9 @@ def define_flags():
   flags.DEFINE_integer(name="epochs_per_cycle", default=1,
                        help="The number of epochs of training data to produce"
                             "at a time.")
+  flags.DEFINE_integer(name="num_cycles", default=None,
+                       help="The number of cycles to produce training data "
+                            "for.")
   flags.DEFINE_integer(name="train_batch_size", default=None,
                        help="The batch size with which training TFRecords will "
                             "be chunked.")
...
@@ -394,7 +394,8 @@ def _shutdown(proc):
   try:
     proc.send_signal(signal.SIGINT)
     time.sleep(5)
-    if proc.returncode is not None:
+    if proc.poll() is not None:
+      tf.logging.info("Train data creation subprocess ended")
       return  # SIGINT was handled successfully within 5 seconds
   except socket.error:

@@ -403,6 +404,7 @@ def _shutdown(proc):
     # Otherwise another second of grace period and then force kill the process.
     time.sleep(1)
     proc.terminate()
+    tf.logging.info("Train data creation subprocess killed")
   except:  # pylint: disable=broad-except
     tf.logging.error("Data generation subprocess could not be killed.")
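A standalone illustration (not part of the patch) of why the check above moves from proc.returncode to proc.poll(): Popen.returncode stays None until the child has been reaped by wait() or poll(), so reading the attribute directly can report a still-running child even after it has exited.

import subprocess
import time

proc = subprocess.Popen(["sleep", "0.1"])
time.sleep(1)            # the child has certainly exited by now
print(proc.returncode)   # None: nothing has reaped the child yet
print(proc.poll())       # 0: poll() reaps the child and records its exit status
print(proc.returncode)   # 0: now populated as a side effect of poll()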
@@ -429,9 +431,10 @@ def write_flagfile(flags_, ncf_dataset):
       "Wrote flagfile for async data generation in {}.".format(flagfile))


 def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
-                         num_data_readers=None, num_neg=4, epochs_per_cycle=1,
-                         match_mlperf=False, deterministic=False,
-                         use_subprocess=True, cache_id=None):
+                         num_cycles, num_data_readers=None, num_neg=4,
+                         epochs_per_cycle=1, match_mlperf=False,
+                         deterministic=False, use_subprocess=True,
+                         cache_id=None):
   # type: (...) -> (NCFDataset, typing.Callable)
   """Preprocess data and start negative generation subprocess."""

@@ -455,6 +458,7 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
       "num_users": ncf_dataset.num_users,
       "num_readers": ncf_dataset.num_data_readers,
       "epochs_per_cycle": epochs_per_cycle,
+      "num_cycles": num_cycles,
       "train_batch_size": batch_size,
       "eval_batch_size": eval_batch_size,
       "num_workers": num_workers,

@@ -656,6 +660,16 @@ def make_input_fn(
   return input_fn, record_dir, batch_count


+def _check_subprocess_alive(ncf_dataset, directory):
+  if (not tf.gfile.Exists(ncf_dataset.cache_paths.subproc_alive) and
+      not tf.gfile.Exists(directory)):
+    # The generation subprocess must have been alive at some point, because we
+    # earlier checked that the subproc_alive file existed.
+    raise ValueError("Generation subprocess unexpectedly died. Data will not "
+                     "be available; exiting to avoid waiting forever.")
+
+
 def get_epoch_info(is_training, ncf_dataset):
   """Wait for the epoch input data to be ready and return various info about it.

@@ -669,14 +683,10 @@ def get_epoch_info(is_training, ncf_dataset):
     template: A string template of the files in `record_dir`.
       `template.format('*')` is a glob that matches all the record files.
   """
-  if not tf.gfile.Exists(ncf_dataset.cache_paths.subproc_alive):
-    # The generation subprocess must have been alive at some point, because we
-    # earlier checked that the subproc_alive file existed.
-    raise ValueError("Generation subprocess unexpectedly died. Data will not "
-                     "be available; exiting to avoid waiting forever.")
-
   if is_training:
     train_epoch_dir = ncf_dataset.cache_paths.train_epoch_dir
+    _check_subprocess_alive(ncf_dataset, train_epoch_dir)
+
     while not tf.gfile.Exists(train_epoch_dir):
       tf.logging.info("Waiting for {} to exist.".format(train_epoch_dir))
       time.sleep(1)

@@ -692,6 +702,7 @@ def get_epoch_info(is_training, ncf_dataset):
     template = rconst.TRAIN_RECORD_TEMPLATE
   else:
     record_dir = ncf_dataset.cache_paths.eval_data_subdir
+    _check_subprocess_alive(ncf_dataset, record_dir)
     template = rconst.EVAL_RECORD_TEMPLATE

   ready_file = os.path.join(record_dir, rconst.READY_FILE)
...
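The refactor above turns the aliveness check into a per-directory helper so it can run for both the training and eval data paths. A hedged sketch of how a consumer-side wait might use it so a crashed generator cannot make the trainer block forever; _wait_for_ready and the "ready" filename default are hypothetical, while _check_subprocess_alive is the helper added in this diff.

import os
import time

import tensorflow as tf

def _wait_for_ready(ncf_dataset, record_dir, ready_file_name="ready"):
  """Block until record_dir is marked ready, failing fast if the generation
  subprocess died without producing any data."""
  ready_file = os.path.join(record_dir, ready_file_name)
  while not tf.gfile.Exists(ready_file):
    # No alive marker and no data directory means the producer crashed;
    # raise instead of polling forever.
    _check_subprocess_alive(ncf_dataset, record_dir)
    tf.logging.info("Waiting for {} to exist.".format(ready_file))
    time.sleep(1)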
@@ -118,7 +118,7 @@ class BaseTest(tf.test.TestCase):
     ncf_dataset, _ = data_preprocessing.instantiate_pipeline(
         dataset=DATASET, data_dir=self.temp_data_dir,
         batch_size=BATCH_SIZE, eval_batch_size=EVAL_BATCH_SIZE,
-        num_data_readers=2, num_neg=NUM_NEG)
+        num_cycles=1, num_data_readers=2, num_neg=NUM_NEG)

     g = tf.Graph()
     with g.as_default():
...
@@ -143,6 +143,7 @@ def run_ncf(_):
   num_gpus = flags_core.get_num_gpus(FLAGS)
   batch_size = distribution_utils.per_device_batch_size(
       int(FLAGS.batch_size), num_gpus)
+  total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals

   eval_per_user = rconst.NUM_EVAL_NEGATIVES + 1
   eval_batch_size = int(FLAGS.eval_batch_size or

@@ -167,6 +168,7 @@ def run_ncf(_):
       eval_batch_size=eval_batch_size,
       num_neg=FLAGS.num_neg,
       epochs_per_cycle=FLAGS.epochs_between_evals,
+      num_cycles=total_training_cycle,
       match_mlperf=FLAGS.ml_perf,
       deterministic=FLAGS.seed is not None,
       use_subprocess=FLAGS.use_subprocess,

@@ -237,7 +239,6 @@ def run_ncf(_):
     eval_input_fn = None

-  total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals
   target_reached = False
   mlperf_helper.ncf_print(key=mlperf_helper.TAGS.TRAIN_LOOP)
   for cycle_index in range(total_training_cycle):
...
@@ -257,10 +257,13 @@ class NcfTest(tf.test.TestCase):
     flags.FLAGS.ml_perf = True
     ncf_main.main(None)

-  @flagsaver.flagsaver(use_estimator=False, use_while_loop=True,
-                       **_BASE_END_TO_END_FLAGS)
+  @flagsaver.flagsaver(use_estimator=False, **_BASE_END_TO_END_FLAGS)
   @mock.patch.object(data_preprocessing, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
   def test_end_to_end_while_loop(self):
+    # We cannot set use_while_loop = True in the flagsaver constructor, because
+    # if the flagsaver sets it to True before setting use_estimator to False,
+    # the flag validator will throw an error.
+    flags.FLAGS.use_while_loop = True
     ncf_main.main(None)

     flags.FLAGS.ml_perf = True
     ncf_main.main(None)
...
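The test change above works around flag-validation ordering in flagsaver. The sketch below shows the kind of cross-flag validator that produces this behavior (it is illustrative, not copied from the repository): if use_while_loop is flipped to True while use_estimator is still True, validation fails, which is why the test now sets use_while_loop inside the test body after the decorator has already forced use_estimator=False.

from absl import flags

flags.DEFINE_boolean("use_estimator", True, "Train with tf.estimator.")
flags.DEFINE_boolean("use_while_loop", False, "Train with a tf.while_loop.")

# Reject the combination use_estimator=True, use_while_loop=True. With a
# validator like this, the order in which flagsaver assigns its keyword
# arguments determines whether an intermediate, invalid state trips the check.
flags.register_multi_flags_validator(
    ["use_estimator", "use_while_loop"],
    lambda values: not (values["use_estimator"] and values["use_while_loop"]),
    message="use_while_loop requires use_estimator to be False.")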