Unverified Commit b5a69819 authored by Bruce Fontaine's avatar Bruce Fontaine Committed by GitHub
Browse files

Add flag to use custom training loop for keras NCF model. (#6905)

* Add flag to use custom training loop for keras NCF model.

* Add error check to NCF model for custom training loop + tf1.0.
parent 383c6e30
......@@ -108,7 +108,7 @@ def parse_flags(flags_obj):
flags_obj.clone_model_in_keras_dist_strat,
"epochs_between_evals": FLAGS.epochs_between_evals,
"turn_off_distribution_strategy": FLAGS.turn_off_distribution_strategy,
"tensorflow_v2": int(tf.__version__.split('.')[0]) > 1,
"keras_use_ctl": flags_obj.keras_use_ctl,
"hr_threshold": flags_obj.hr_threshold,
}
......@@ -329,6 +329,11 @@ def define_ncf_flags():
help=flags_core.help_wrap(
'If True, we stop the training when it reaches hr_threshold'))
flags.DEFINE_bool(
name="keras_use_ctl",
default=False,
help=flags_core.help_wrap(
'If True, we use a custom training loop for keras.'))
def convert_to_softmax_logits(logits):
'''Convert the logits returned by the base model to softmax logits.
......
......@@ -104,7 +104,7 @@ def _get_train_and_eval_data(producer, params):
fit.
- The label needs to be extended to be used in the loss fn
"""
if not params["tensorflow_v2"]:
if not params["keras_use_ctl"]:
features.pop(rconst.VALID_POINT_MASK)
labels = tf.expand_dims(labels, -1)
return features, labels
......@@ -112,7 +112,7 @@ def _get_train_and_eval_data(producer, params):
train_input_fn = producer.make_input_fn(is_training=True)
train_input_dataset = train_input_fn(params).map(
preprocess_train_input)
if not params["tensorflow_v2"]:
if not params["keras_use_ctl"]:
train_input_dataset = train_input_dataset.repeat(FLAGS.train_epochs)
def preprocess_eval_input(features):
......@@ -125,7 +125,7 @@ def _get_train_and_eval_data(producer, params):
fit.
- The label needs to be extended to be used in the loss fn
"""
if not params["tensorflow_v2"]:
if not params["keras_use_ctl"]:
features.pop(rconst.DUPLICATE_MASK)
labels = tf.zeros_like(features[movielens.USER_COLUMN])
labels = tf.expand_dims(labels, -1)
......@@ -240,6 +240,11 @@ def run_ncf(_):
params = ncf_common.parse_flags(FLAGS)
if params['keras_use_ctl'] and int(tf.__version__.split('.')[0]) == 1:
logging.error(
"Custom training loop only works with tensorflow 2.0 and above.")
return
# ncf_common rounds eval_batch_size (this is needed due to a reshape during
# eval). This carries over that rounding to batch_size as well. This is the
# per device batch size
......@@ -280,7 +285,7 @@ def run_ncf(_):
beta_2=params["beta2"],
epsilon=params["epsilon"])
if params['tensorflow_v2']:
if params['keras_use_ctl']:
loss_object = tf.losses.SparseCategoricalCrossentropy(
reduction=tf.keras.losses.Reduction.SUM,
from_logits=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment