Commit bed2745d authored by guptapriya, committed by guptapriya

fix non strategy case; clean up documentation

parent 9214db1d
@@ -77,19 +77,21 @@ def _get_train_and_eval_data(producer, params):
   def preprocess_train_input(features, labels):
     """Pre-process the training data.
 
     This is needed because:
-    - Distributed training with keras fit does not support extra inputs. The
-      current implementation for fit does not use the VALID_POINT_MASK in the
-      input, which makes it extra, so it needs to be removed when using keras
-      fit.
     - The label needs to be extended to be used in the loss fn
+    - We need the same inputs for training and eval so adding fake inputs
+      for DUPLICATE_MASK in training data.
     """
     labels = tf.expand_dims(labels, -1)
     fake_dup_mask = tf.zeros_like(features[movielens.USER_COLUMN])
     features[rconst.DUPLICATE_MASK] = fake_dup_mask
     features[rconst.TRAIN_LABEL_KEY] = labels
-    #return (features,)
-    return features, labels
+    if params["distribute_strategy"]:
+      return features
+    else:
+      # b/134708104
+      return (features,)
 
   train_input_fn = producer.make_input_fn(is_training=True)
   train_input_dataset = train_input_fn(params).map(
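As a side note on the pattern this hunk adopts: the label is folded into the features dict (TRAIN_LABEL_KEY in the real code) and a fake mask is added so training and eval expose the same input signature to Keras. Below is a minimal sketch of that idea using stock tf.data only; the names toy_preprocess, "user_id", "duplicate_mask", and "train_labels" are illustrative stand-ins, not the repo's movielens/rconst constants.

import tensorflow as tf

def toy_preprocess(features, labels):
  # Give the label the rank the loss fn expects, then carry it inside the
  # features dict, mirroring what preprocess_train_input does above.
  labels = tf.expand_dims(labels, -1)
  features["duplicate_mask"] = tf.zeros_like(features["user_id"])  # fake input
  features["train_labels"] = labels
  return features  # strategy case: a single dict, no separate label tensor

dataset = tf.data.Dataset.from_tensor_slices(
    ({"user_id": tf.range(8)}, tf.ones(8))).batch(4).map(toy_preprocess)

for batch in dataset:
  print(batch["train_labels"].shape)  # (4, 1)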
@@ -98,21 +100,23 @@ def _get_train_and_eval_data(producer, params):
   def preprocess_eval_input(features):
     """Pre-process the eval data.
 
     This is needed because:
-    - Distributed training with keras fit does not support extra inputs. The
-      current implementation for fit does not use the DUPLICATE_MASK in the
-      input, which makes it extra, so it needs to be removed when using keras
-      fit.
     - The label needs to be extended to be used in the loss fn
+    - We need the same inputs for training and eval so adding fake inputs
+      for VALID_PT_MASK in eval data.
     """
     labels = tf.cast(tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
     labels = tf.expand_dims(labels, -1)
-    fake_valit_pt_mask = tf.cast(
+    fake_valid_pt_mask = tf.cast(
         tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
-    features[rconst.VALID_POINT_MASK] = fake_valit_pt_mask
+    features[rconst.VALID_POINT_MASK] = fake_valid_pt_mask
     features[rconst.TRAIN_LABEL_KEY] = labels
-    #return (features,)
-    return features, labels
+    if params["distribute_strategy"]:
+      return features
+    else:
+      # b/134708104
+      return (features,)
 
   eval_input_fn = producer.make_input_fn(is_training=False)
   eval_input_dataset = eval_input_fn(params).map(
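The b/134708104 workaround in both preprocessing functions changes only the element structure the dataset yields: a bare features dict when a distribution strategy is set, versus a 1-tuple wrapping that dict otherwise. The sketch below shows that structural difference with stock tf.data; why the non-strategy path needs the tuple is tracked in the referenced bug, which is not visible here.

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices({"user_id": tf.range(4)}).batch(2)

as_dict = ds.map(lambda f: f)      # element: a plain dict of tensors
as_tuple = ds.map(lambda f: (f,))  # element: a 1-tuple wrapping the dict

print(as_dict.element_spec)   # {'user_id': TensorSpec(shape=(None,), ...)}
print(as_tuple.element_spec)  # ({'user_id': TensorSpec(shape=(None,), ...)},)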
@@ -245,11 +249,12 @@ def _get_keras_model(params):
 def run_ncf(_):
+  """Run NCF training and eval with Keras."""
   if FLAGS.seed is not None:
     print("Setting tf seed")
     tf.random.set_seed(FLAGS.seed)
 
-  """Run NCF training and eval with Keras."""
   # TODO(seemuch): Support different train and eval batch sizes
   if FLAGS.eval_batch_size != FLAGS.batch_size:
     logging.warning(
@@ -261,6 +266,11 @@ def run_ncf(_):
   params = ncf_common.parse_flags(FLAGS)
 
+  strategy = distribution_utils.get_distribution_strategy(
+      distribution_strategy=FLAGS.distribution_strategy,
+      num_gpus=FLAGS.num_gpus)
+  params["distribute_strategy"] = strategy
+
   if params["keras_use_ctl"] and int(tf.__version__.split(".")[0]) == 1:
     logging.error(
         "Custom training loop only works with tensorflow 2.0 and above.")
@@ -297,9 +307,7 @@ def run_ncf(_):
         "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
     callbacks.append(early_stopping_callback)
 
-  strategy = distribution_utils.get_distribution_strategy(
-      distribution_strategy=FLAGS.distribution_strategy,
-      num_gpus=FLAGS.num_gpus)
   with distribution_utils.get_strategy_scope(strategy):
     keras_model = _get_keras_model(params)
     optimizer = tf.keras.optimizers.Adam(
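The last hunk drops the late strategy construction (it now happens right after flag parsing, as shown above) and keeps only the scope entry, so the model and optimizer are created under the strategy's variable scope. Below is a minimal sketch of that pattern; get_strategy_scope is a hypothetical stand-in assuming the repo helper yields strategy.scope() when a strategy exists and a no-op context otherwise.

import contextlib
import tensorflow as tf

def get_strategy_scope(strategy):
  # Hypothetical stand-in for distribution_utils.get_strategy_scope.
  return strategy.scope() if strategy else contextlib.nullcontext()

strategy = None  # or e.g. tf.distribute.MirroredStrategy()
with get_strategy_scope(strategy):
  # Variables created here are placed/mirrored according to the strategy.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

print(model.count_params())  # 5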