Unverified commit 4c1d95cc, authored by Bruce Fontaine and committed by GitHub.

Add a custom training loop for NCF model with TF2.0 (#6899)

* Add a custom training loop for NCF model with TF2.0

* Fix long line in ncf_keras_main.py

* Remove dataset repeat when using custom training loop.
parent df523d91
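For context before the diff: the custom training loop added here follows the TF 2.0 distributed-training pattern of writing a per-replica step function, wrapping it in @tf.function, running it through the distribution strategy, and reducing the per-replica results. Below is a minimal, self-contained sketch of that pattern with a toy model and dataset (placeholders, not the NCF code); it deliberately uses the same strategy.make_dataset_iterator / strategy.experimental_run / strategy.reduce calls that this commit uses, which belong to the TF 2.0-era tf.distribute API and were renamed in later releases.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Model, optimizer and loss are created under the strategy scope so their
# variables are mirrored across replicas.
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.Adam()
  loss_object = tf.keras.losses.MeanSquaredError(
      reduction=tf.keras.losses.Reduction.SUM)

# Toy data; batch size 16 stands in for the per-device batch size.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([64, 4]), tf.random.uniform([64, 1]))).batch(16)
train_iterator = strategy.make_dataset_iterator(dataset)

@tf.function
def train_step():
  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      loss = loss_object(labels, model(features))
      # The SUM-reduced loss is rescaled by the global batch size, mirroring
      # the 1.0 / (batch_size * num_replicas_in_sync) scaling in the diff below.
      loss *= 1.0 / (16 * strategy.num_replicas_in_sync)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_losses = strategy.experimental_run(step_fn, train_iterator)
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

train_iterator.initialize()
for step in range(4):
  print("step", step, "loss", float(train_step()))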
@@ -108,6 +108,8 @@ def parse_flags(flags_obj):
           flags_obj.clone_model_in_keras_dist_strat,
       "epochs_between_evals": FLAGS.epochs_between_evals,
       "turn_off_distribution_strategy": FLAGS.turn_off_distribution_strategy,
+      "tensorflow_v2": int(tf.__version__.split('.')[0]) > 1,
+      "hr_threshold": flags_obj.hr_threshold,
   }
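The new tensorflow_v2 entry is just a major-version check on tf.__version__; a quick illustration with example version strings (the strings are illustrative only):

int("2.0.0-alpha0".split('.')[0]) > 1  # True  -> take the new custom-loop path
int("1.13.1".split('.')[0]) > 1        # False -> keep the existing keras fit path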
@@ -245,7 +247,7 @@ def define_ncf_flags():
           "optimizer."))

   flags.DEFINE_float(
-      name="hr_threshold", default=None,
+      name="hr_threshold", default=1.0,
       help=flags_core.help_wrap(
           "If passed, training will stop when the evaluation metric HR is "
           "greater than or equal to hr_threshold. For dataset ml-1m, the "
......
@@ -98,11 +98,13 @@ def _get_train_and_eval_data(producer, params):
     """Pre-process the training data.

     This is needed because:
-    - Distributed training does not support extra inputs. The current
-      implementation does not use the VALID_POINT_MASK in the input, which makes
-      it extra, so it needs to be removed.
+    - Distributed training with keras fit does not support extra inputs. The
+      current implementation for fit does not use the VALID_POINT_MASK in the
+      input, which makes it extra, so it needs to be removed when using keras
+      fit.
     - The label needs to be extended to be used in the loss fn
     """
+    if not params["tensorflow_v2"]:
+      features.pop(rconst.VALID_POINT_MASK)
     labels = tf.expand_dims(labels, -1)
     return features, labels
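The docstring's "label needs to be extended" point just means adding a trailing axis so the labels match the rank the loss function expects. A toy illustration (shapes are illustrative only):

import tensorflow as tf

labels = tf.constant([0, 1, 0])      # shape [3]
labels = tf.expand_dims(labels, -1)  # shape [3, 1], the rank the loss fn expects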
@@ -110,17 +112,20 @@ def _get_train_and_eval_data(producer, params):
   train_input_fn = producer.make_input_fn(is_training=True)
   train_input_dataset = train_input_fn(params).map(
       preprocess_train_input)
+  if not params["tensorflow_v2"]:
+    train_input_dataset = train_input_dataset.repeat(FLAGS.train_epochs)

   def preprocess_eval_input(features):
     """Pre-process the eval data.

     This is needed because:
-    - Distributed training does not support extra inputs. The current
-      implementation does not use the DUPLICATE_MASK in the input, which makes
-      it extra, so it needs to be removed.
+    - Distributed training with keras fit does not support extra inputs. The
+      current implementation for fit does not use the DUPLICATE_MASK in the
+      input, which makes it extra, so it needs to be removed when using keras
+      fit.
     - The label needs to be extended to be used in the loss fn
     """
+    if not params["tensorflow_v2"]:
+      features.pop(rconst.DUPLICATE_MASK)
     labels = tf.zeros_like(features[movielens.USER_COLUMN])
     labels = tf.expand_dims(labels, -1)
@@ -234,11 +239,12 @@ def run_ncf(_):
     FLAGS.eval_batch_size = FLAGS.batch_size

   params = ncf_common.parse_flags(FLAGS)
   batch_size = params["batch_size"]

   # ncf_common rounds eval_batch_size (this is needed due to a reshape during
-  # eval). This carries over that rounding to batch_size as well.
+  # eval). This carries over that rounding to batch_size as well. This is the
+  # per device batch size
   params['batch_size'] = params['eval_batch_size']
   batch_size = params["batch_size"]

   num_users, num_items, num_train_steps, num_eval_steps, producer = (
       ncf_common.get_inputs(params))
@@ -257,8 +263,8 @@ def run_ncf(_):
   eval_input_dataset = eval_input_dataset.batch(batches_per_step)

   time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
-  callbacks = [
-      IncrementEpochCallback(producer), time_callback]
+  per_epoch_callback = IncrementEpochCallback(producer)
+  callbacks = [per_epoch_callback, time_callback]

   if FLAGS.early_stopping:
     early_stopping_callback = CustomEarlyStopping(
@@ -274,6 +280,95 @@ def run_ncf(_):
       beta_2=params["beta2"],
       epsilon=params["epsilon"])

+  if params['tensorflow_v2']:
+    loss_object = tf.losses.SparseCategoricalCrossentropy(
+        reduction=tf.keras.losses.Reduction.SUM,
+        from_logits=True)
+    train_input_iterator = strategy.make_dataset_iterator(train_input_dataset)
+    eval_input_iterator = strategy.make_dataset_iterator(eval_input_dataset)
+
+    @tf.function
+    def train_step():
+      """Called once per step to train the model."""
+      def step_fn(inputs):
+        """Computes loss and applied gradient per replica."""
+        features, labels = inputs
+        with tf.GradientTape() as tape:
+          softmax_logits = keras_model([features[movielens.USER_COLUMN],
+                                        features[movielens.ITEM_COLUMN]])
+          loss = loss_object(labels, softmax_logits,
+                             sample_weight=features[rconst.VALID_POINT_MASK])
+          loss *= (1.0 / (batch_size*strategy.num_replicas_in_sync))
+
+        grads = tape.gradient(loss, keras_model.trainable_variables)
+        optimizer.apply_gradients(list(zip(grads,
+                                           keras_model.trainable_variables)))
+        return loss
+
+      per_replica_losses = strategy.experimental_run(step_fn,
+                                                     train_input_iterator)
+      mean_loss = strategy.reduce(
+          tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
+      return mean_loss
+
+    @tf.function
+    def eval_step():
+      """Called once per eval step to compute eval metrics."""
+      def step_fn(inputs):
+        """Computes eval metrics per replica."""
+        features, _ = inputs
+        softmax_logits = keras_model([features[movielens.USER_COLUMN],
+                                      features[movielens.ITEM_COLUMN]])
+        logits = tf.slice(softmax_logits, [0, 0, 1], [-1, -1, -1])
+        dup_mask = features[rconst.DUPLICATE_MASK]
+        in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg(
+            logits,
+            dup_mask,
+            params["match_mlperf"])
+        metric_weights = tf.cast(metric_weights, tf.float32)
+        hr_sum = tf.reduce_sum(in_top_k*metric_weights)
+        hr_count = tf.reduce_sum(metric_weights)
+        return hr_sum, hr_count
+
+      per_replica_hr_sum, per_replica_hr_count = (
+          strategy.experimental_run(step_fn, eval_input_iterator))
+      hr_sum = strategy.reduce(
+          tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
+      hr_count = strategy.reduce(
+          tf.distribute.ReduceOp.SUM, per_replica_hr_count, axis=None)
+      return hr_sum, hr_count
+
+    time_callback.on_train_begin()
+    for epoch in range(FLAGS.train_epochs):
+      per_epoch_callback.on_epoch_begin(epoch)
+      train_input_iterator.initialize()
+      train_loss = 0
+      for step in range(num_train_steps):
+        time_callback.on_batch_begin(step+epoch*num_train_steps)
+        train_loss += train_step()
+        time_callback.on_batch_end(step+epoch*num_train_steps)
+      logging.info("Done training epoch {}, epoch loss={}.".format(
+          epoch+1, train_loss/num_train_steps))
+
+      eval_input_iterator.initialize()
+      hr_sum = 0
+      hr_count = 0
+      for _ in range(num_eval_steps):
+        step_hr_sum, step_hr_count = eval_step()
+        hr_sum += step_hr_sum
+        hr_count += step_hr_count
+      logging.info("Done eval epoch {}, hr={}.".format(epoch+1,
+                                                       hr_sum/hr_count))
+
+      if (FLAGS.early_stopping and
+          float(hr_sum/hr_count) > params["hr_threshold"]):
+        break
+
+    time_callback.on_train_end()
+    eval_results = [None, hr_sum/hr_count]
+  else:
     with distribution_utils.get_strategy_scope(strategy):
       keras_model.compile(
           loss=_keras_loss,
           metrics=[_get_metric_fn(params)],
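The eval portion of the custom loop above accumulates a hit-ratio numerator (hr_sum) and denominator (hr_count) across replicas and steps, and the reported metric is their ratio. A toy illustration of the same arithmetic (the values are made up):

import tensorflow as tf

in_top_k = tf.constant([1., 0., 1., 1.])        # 1.0 if the user's positive item was ranked in the top k
metric_weights = tf.constant([1., 1., 1., 0.])  # padded/dummy users get weight 0
hr = tf.reduce_sum(in_top_k * metric_weights) / tf.reduce_sum(metric_weights)
# hr == 2.0 / 3.0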
@@ -297,16 +392,19 @@ def run_ncf(_):
     logging.info("Keras evaluation is done.")

-  stats = build_stats(history, eval_results, time_callback)
+    if history and history.history:
+      train_history = history.history
+      train_loss = train_history['loss'][-1]
+
+  stats = build_stats(train_loss, eval_results, time_callback)
   return stats


-def build_stats(history, eval_result, time_callback):
+def build_stats(loss, eval_result, time_callback):
   """Normalizes and returns dictionary of stats.

   Args:
-    history: Results of the training step. Supports both categorical_accuracy
-      and sparse_categorical_accuracy.
+    loss: The final loss at training time.
     eval_output: Output of the eval step. Assumes first value is eval_loss and
       second value is accuracy_top_1.
     time_callback: Time tracking callback likely used during keras.fit.
@@ -314,9 +412,8 @@ def build_stats(history, eval_result, time_callback):
     Dictionary of normalized results.
   """
   stats = {}
-  if history and history.history:
-    train_history = history.history
-    stats['loss'] = train_history['loss'][-1]
+  if loss:
+    stats['loss'] = loss

   if eval_result:
     stats['eval_loss'] = eval_result[0]
......
@@ -427,8 +427,9 @@ def compute_top_k_and_ndcg(logits, # type: tf.Tensor
         (num_users_in_batch, (rconst.NUM_EVAL_NEGATIVES + 1)).
   """
   logits_by_user = tf.reshape(logits, (-1, rconst.NUM_EVAL_NEGATIVES + 1))
-  duplicate_mask_by_user = tf.reshape(duplicate_mask,
-                                      (-1, rconst.NUM_EVAL_NEGATIVES + 1))
+  duplicate_mask_by_user = tf.cast(
+      tf.reshape(duplicate_mask, (-1, rconst.NUM_EVAL_NEGATIVES + 1)),
+      tf.float32)

   if match_mlperf:
     # Set duplicate logits to the min value for that dtype. The MLPerf
......
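The cast to tf.float32 in the hunk above lets the duplicate mask participate directly in float arithmetic against the logits. Roughly what the match_mlperf branch referenced in the comment does, shown as a toy illustration rather than the library code verbatim:

import tensorflow as tf

logits_by_user = tf.constant([[2.0, 1.5, 0.3]])
duplicate_mask_by_user = tf.cast(tf.constant([[0, 1, 0]]), tf.float32)

# Push duplicate items' logits to the dtype minimum so they cannot win the top-k.
logits_by_user = logits_by_user * (1.0 - duplicate_mask_by_user)
logits_by_user += duplicate_mask_by_user * logits_by_user.dtype.min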