Commit f6f04066 authored by guptapriya's avatar guptapriya Committed by guptapriya
Browse files

Fix the custom training loop (CTL) case; add a check requiring TensorFlow 2.0.

parent 8f44de85
...@@ -324,3 +324,8 @@ def convert_to_softmax_logits(logits): ...@@ -324,3 +324,8 @@ def convert_to_softmax_logits(logits):
''' '''
softmax_logits = tf.concat([logits * 0, logits], axis=1) softmax_logits = tf.concat([logits * 0, logits], axis=1)
return softmax_logits return softmax_logits
def is_tf_v2():
  """Return whether TensorFlow 2.x behavior is enabled in this process."""
  # Deferred import of a TF-internal module; avoids importing it unless asked.
  from tensorflow.python import tf2
  return tf2.enabled()
...@@ -87,7 +87,7 @@ def _get_train_and_eval_data(producer, params): ...@@ -87,7 +87,7 @@ def _get_train_and_eval_data(producer, params):
features[rconst.DUPLICATE_MASK] = fake_dup_mask features[rconst.DUPLICATE_MASK] = fake_dup_mask
features[rconst.TRAIN_LABEL_KEY] = labels features[rconst.TRAIN_LABEL_KEY] = labels
if params["distribute_strategy"]: if params["distribute_strategy"] or not ncf_common.is_tf_v2():
return features return features
else: else:
# b/134708104 # b/134708104
...@@ -112,7 +112,7 @@ def _get_train_and_eval_data(producer, params): ...@@ -112,7 +112,7 @@ def _get_train_and_eval_data(producer, params):
features[rconst.VALID_POINT_MASK] = fake_valid_pt_mask features[rconst.VALID_POINT_MASK] = fake_valid_pt_mask
features[rconst.TRAIN_LABEL_KEY] = labels features[rconst.TRAIN_LABEL_KEY] = labels
if params["distribute_strategy"]: if params["distribute_strategy"] or not ncf_common.is_tf_v2():
return features return features
else: else:
# b/134708104 # b/134708104
...@@ -271,9 +271,10 @@ def run_ncf(_): ...@@ -271,9 +271,10 @@ def run_ncf(_):
num_gpus=FLAGS.num_gpus) num_gpus=FLAGS.num_gpus)
params["distribute_strategy"] = strategy params["distribute_strategy"] = strategy
if params["keras_use_ctl"] and int(tf.__version__.split(".")[0]) == 1: if (params["keras_use_ctl"] and (
not ncf_common.is_tf_v2() or strategy is None)):
logging.error( logging.error(
"Custom training loop only works with tensorflow 2.0 and above.") "Custom training loop only works with tensorflow 2.0 and dist strat.")
return return
# ncf_common rounds eval_batch_size (this is needed due to a reshape during # ncf_common rounds eval_batch_size (this is needed due to a reshape during
...@@ -326,11 +327,11 @@ def run_ncf(_): ...@@ -326,11 +327,11 @@ def run_ncf(_):
@tf.function @tf.function
def train_step(): def train_step():
"""Called once per step to train the model.""" """Called once per step to train the model."""
def step_fn(inputs): def step_fn(features):
"""Computes loss and applied gradient per replica.""" """Computes loss and applied gradient per replica."""
features, labels = inputs
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
softmax_logits = keras_model(features) softmax_logits = keras_model(features)
labels = features[rconst.TRAIN_LABEL_KEY]
loss = loss_object(labels, softmax_logits, loss = loss_object(labels, softmax_logits,
sample_weight=features[rconst.VALID_POINT_MASK]) sample_weight=features[rconst.VALID_POINT_MASK])
loss *= (1.0 / (batch_size*strategy.num_replicas_in_sync)) loss *= (1.0 / (batch_size*strategy.num_replicas_in_sync))
...@@ -349,9 +350,8 @@ def run_ncf(_): ...@@ -349,9 +350,8 @@ def run_ncf(_):
@tf.function @tf.function
def eval_step(): def eval_step():
"""Called once per eval step to compute eval metrics.""" """Called once per eval step to compute eval metrics."""
def step_fn(inputs): def step_fn(features):
"""Computes eval metrics per replica.""" """Computes eval metrics per replica."""
features, _ = inputs
softmax_logits = keras_model(features) softmax_logits = keras_model(features)
in_top_k, metric_weights = metric_fn( in_top_k, metric_weights = metric_fn(
softmax_logits, features[rconst.DUPLICATE_MASK], params) softmax_logits, features[rconst.DUPLICATE_MASK], params)
......
...@@ -220,7 +220,7 @@ class NcfTest(tf.test.TestCase): ...@@ -220,7 +220,7 @@ class NcfTest(tf.test.TestCase):
['-keras_use_ctl', 'True']) ['-keras_use_ctl', 'True'])
integration.run_synthetic( integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None, ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0']) extra_flags=flags)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_1_gpu_dist_strat(self): def test_end_to_end_keras_1_gpu_dist_strat(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment