Commit f6f04066, authored and committed by guptapriya
Browse files

fix ctl case; add check for 2.0

parent 8f44de85
......@@ -324,3 +324,8 @@ def convert_to_softmax_logits(logits):
'''
softmax_logits = tf.concat([logits * 0, logits], axis=1)
return softmax_logits
def is_tf_v2():
  """Reports whether TensorFlow v2 is enabled.

  Returns:
    The value returned by `tensorflow.python.tf2.enabled()`.
  """
  # The import is kept function-local, exactly as in the original block.
  from tensorflow.python import tf2 as v2_switch
  v2_enabled = v2_switch.enabled()
  return v2_enabled
......@@ -87,7 +87,7 @@ def _get_train_and_eval_data(producer, params):
features[rconst.DUPLICATE_MASK] = fake_dup_mask
features[rconst.TRAIN_LABEL_KEY] = labels
if params["distribute_strategy"]:
if params["distribute_strategy"] or not ncf_common.is_tf_v2():
return features
else:
# b/134708104
......@@ -112,7 +112,7 @@ def _get_train_and_eval_data(producer, params):
features[rconst.VALID_POINT_MASK] = fake_valid_pt_mask
features[rconst.TRAIN_LABEL_KEY] = labels
if params["distribute_strategy"]:
if params["distribute_strategy"] or not ncf_common.is_tf_v2():
return features
else:
# b/134708104
......@@ -271,9 +271,10 @@ def run_ncf(_):
num_gpus=FLAGS.num_gpus)
params["distribute_strategy"] = strategy
if params["keras_use_ctl"] and int(tf.__version__.split(".")[0]) == 1:
if (params["keras_use_ctl"] and (
not ncf_common.is_tf_v2() or strategy is None)):
logging.error(
"Custom training loop only works with tensorflow 2.0 and above.")
"Custom training loop only works with tensorflow 2.0 and dist strat.")
return
# ncf_common rounds eval_batch_size (this is needed due to a reshape during
......@@ -326,11 +327,11 @@ def run_ncf(_):
@tf.function
def train_step():
"""Called once per step to train the model."""
def step_fn(inputs):
def step_fn(features):
"""Computes loss and applied gradient per replica."""
features, labels = inputs
with tf.GradientTape() as tape:
softmax_logits = keras_model(features)
labels = features[rconst.TRAIN_LABEL_KEY]
loss = loss_object(labels, softmax_logits,
sample_weight=features[rconst.VALID_POINT_MASK])
loss *= (1.0 / (batch_size*strategy.num_replicas_in_sync))
......@@ -349,9 +350,8 @@ def run_ncf(_):
@tf.function
def eval_step():
"""Called once per eval step to compute eval metrics."""
def step_fn(inputs):
def step_fn(features):
"""Computes eval metrics per replica."""
features, _ = inputs
softmax_logits = keras_model(features)
in_top_k, metric_weights = metric_fn(
softmax_logits, features[rconst.DUPLICATE_MASK], params)
......
......@@ -220,7 +220,7 @@ class NcfTest(tf.test.TestCase):
['-keras_use_ctl', 'True'])
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
extra_flags=flags)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_1_gpu_dist_strat(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment