"...source/git@developer.sourcefind.cn:OpenDAS/lightx2v.git" did not exist on "a1ebc651ab830a381e8960029145b557990342d6"
Commit 7033c8a2 authored by Priya Gupta, committed by guptapriya

Add early stopping logic to ncf keras when desired threshold is met. Also change the default batch size to match the tuned hyperparams.
parent 7f9db598
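
For readers skimming the diff: the change adds a callback that halts model.fit() once a monitored metric crosses an absolute target, rather than when it merely stops improving. Below is a minimal, self-contained sketch of that idea; the names (ThresholdStopping, the toy model and data) are illustrative and not part of the commit, whose real implementation is CustomEarlyStopping in ncf_keras_main.py further down.

import numpy as np
import tensorflow as tf


class ThresholdStopping(tf.keras.callbacks.Callback):
  """Stops training once `monitor` reaches an absolute `desired_value`."""

  def __init__(self, monitor, desired_value):
    super(ThresholdStopping, self).__init__()
    self.monitor = monitor
    self.desired = desired_value
    self.stopped_epoch = 0

  def on_epoch_end(self, epoch, logs=None):
    value = (logs or {}).get(self.monitor)
    if value is not None and value >= self.desired:
      self.stopped_epoch = epoch
      # Keras checks this flag between epochs and ends fit() cleanly.
      self.model.stop_training = True


# Toy demo on a linearly separable problem the model masters quickly.
x = np.random.rand(512, 1).astype("float32")
y = (x > 0.5).astype("float32")
model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
# Note: the metric key is "accuracy" in TF2; some TF 1.x versions log "acc".
model.compile(optimizer=tf.keras.optimizers.Adam(0.1),
              loss="binary_crossentropy", metrics=["accuracy"])
stopper = ThresholdStopping(monitor="accuracy", desired_value=0.95)
model.fit(x, y, epochs=200, callbacks=[stopper], verbose=0)
print("stopped after epoch %d" % (stopper.stopped_epoch + 1))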
@@ -321,6 +321,12 @@ def define_ncf_flags():
          'If False, then the experimental code path is used that doesn\'t '
          "clone models for distribution."))

  flags.DEFINE_bool(
      name="early_stopping",
      default=False,
      help=flags_core.help_wrap(
          'If True, stop training early once the eval hit rate reaches '
          'hr_threshold.'))


def convert_to_softmax_logits(logits):
  '''Convert the logits returned by the base model to softmax logits.
...
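
An aside on the flag plumbing (illustrative, not from the commit): flags defined via flags.DEFINE_bool become attributes of flags.FLAGS once absl parses argv, which is how FLAGS.early_stopping is read later in ncf_keras_main.py. A standalone sketch:

from absl import app
from absl import flags

flags.DEFINE_bool(
    name="early_stopping",
    default=False,
    help="If True, stop training once the eval hit rate reaches hr_threshold.")

FLAGS = flags.FLAGS


def main(_):
  # After parsing, the flag reads as a plain Python bool.
  print("early_stopping =", FLAGS.early_stopping)


if __name__ == "__main__":
  # `python flag_demo.py --early_stopping` prints True;
  # `--noearly_stopping` is absl's auto-generated negative form.
  app.run(main)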
@@ -121,11 +121,22 @@ class KerasNCFRealData(KerasNCFBenchmarkBase):
    self._setup()
    self._run_and_report_benchmark()

  def benchmark_1_gpu_early_stop(self):
    self._setup()
    FLAGS.early_stopping = True
    self._run_and_report_benchmark()

  def benchmark_2_gpus(self):
    self._setup()
    FLAGS.num_gpus = 2
    self._run_and_report_benchmark()

  def benchmark_2_gpus_early_stop(self):
    self._setup()
    FLAGS.early_stopping = True
    FLAGS.num_gpus = 2
    self._run_and_report_benchmark()


class KerasNCFSyntheticData(KerasNCFBenchmarkBase):
  """Benchmark NCF model using synthetic data."""
...
@@ -148,6 +148,35 @@ class IncrementEpochCallback(tf.keras.callbacks.Callback):
    self._producer.increment_request_epoch()


class CustomEarlyStopping(tf.keras.callbacks.Callback):
  """Stop training once the monitored metric reaches a desired value."""

  def __init__(self, monitor, desired_value):
    super(CustomEarlyStopping, self).__init__()

    self.monitor = monitor
    self.desired = desired_value
    # Initialize so on_train_end is safe even if training never stops early.
    self.stopped_epoch = 0

  def on_epoch_end(self, epoch, logs=None):
    current = self.get_monitor_value(logs)
    if current is not None and current >= self.desired:
      self.stopped_epoch = epoch
      self.model.stop_training = True

  def on_train_end(self, logs=None):
    if self.stopped_epoch > 0:
      print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))

  def get_monitor_value(self, logs):
    logs = logs or {}
    monitor_value = logs.get(self.monitor)
    if monitor_value is None:
      logging.warning('Early stopping conditioned on metric `%s` '
                      'which is not available. Available metrics are: %s',
                      self.monitor, ','.join(list(logs.keys())))
    return monitor_value


def _get_keras_model(params):
  """Constructs and returns the model."""
  batch_size = params['batch_size']
@@ -226,6 +255,15 @@ def run_ncf(_):
  train_input_dataset = train_input_dataset.batch(batches_per_step)
  eval_input_dataset = eval_input_dataset.batch(batches_per_step)

  time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
  callbacks = [IncrementEpochCallback(producer), time_callback]
  if FLAGS.early_stopping:
    early_stopping_callback = CustomEarlyStopping(
        "val_metric_fn", desired_value=FLAGS.hr_threshold)
    callbacks.append(early_stopping_callback)

  strategy = ncf_common.get_distribution_strategy(params)
  with distribution_utils.get_strategy_scope(strategy):
    keras_model = _get_keras_model(params)
@@ -245,9 +283,7 @@ def run_ncf(_):
  history = keras_model.fit(train_input_dataset,
                            steps_per_epoch=num_train_steps,
                            epochs=FLAGS.train_epochs,
                            callbacks=callbacks,
                            validation_data=eval_input_dataset,
                            validation_steps=num_eval_steps,
                            verbose=2)
...
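
A hedged note on the design, since the commit message doesn't spell it out: the stock tf.keras.callbacks.EarlyStopping stops when the monitored metric stops improving, and its baseline argument stops training when the metric fails to beat a baseline, so neither mode means "stop as soon as an absolute target is reached". A custom callback is the direct way to get that behavior. For comparison:

import tensorflow as tf

# Built-in: stops after `patience` epochs WITHOUT improvement in the metric.
plateau_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_metric_fn", patience=3, mode="max")

# This commit: stops as soon as the metric reaches an absolute target
# (here FLAGS.hr_threshold, an existing flag in the NCF code).
# target_stop = CustomEarlyStopping("val_metric_fn",
#                                   desired_value=FLAGS.hr_threshold)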
@@ -42,7 +42,7 @@ python "${SCRIPT_DIR}/../datasets/movielens.py" --data_dir ${DATA_DIR} --dataset
if [ "$1" == "keras" ]
then
  MAIN_SCRIPT="ncf_keras_main.py"
  BATCH_SIZE=99000
  DEVICE_FLAG="--num_gpus 1"
else
  BATCH_SIZE=98340
...