Commit 722d9e57 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Clearly demarcate contrib symbols from standard tf symbols by importing them directly.

PiperOrigin-RevId: 285618209
parent e5c71d51
...@@ -115,8 +115,7 @@ def neumf_model_fn(features, labels, mode, params): ...@@ -115,8 +115,7 @@ def neumf_model_fn(features, labels, mode, params):
beta2=params["beta2"], beta2=params["beta2"],
epsilon=params["epsilon"]) epsilon=params["epsilon"])
if params["use_tpu"]: if params["use_tpu"]:
# TODO(seemuch): remove this contrib import optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.MODEL_HP_LOSS_FN, mlperf_helper.ncf_print(key=mlperf_helper.TAGS.MODEL_HP_LOSS_FN,
value=mlperf_helper.TAGS.BCE) value=mlperf_helper.TAGS.BCE)
...@@ -274,7 +273,7 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor ...@@ -274,7 +273,7 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
use_tpu_spec) use_tpu_spec)
if use_tpu_spec: if use_tpu_spec:
return tf.contrib.tpu.TPUEstimatorSpec( return tf.estimator.tpu.TPUEstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL, mode=tf.estimator.ModeKeys.EVAL,
loss=cross_entropy, loss=cross_entropy,
eval_metrics=(metric_fn, [in_top_k, ndcg, metric_weights])) eval_metrics=(metric_fn, [in_top_k, ndcg, metric_weights]))
......
...@@ -283,14 +283,6 @@ def set_up_synthetic_data(): ...@@ -283,14 +283,6 @@ def set_up_synthetic_data():
_monkey_patch_dataset_method(tf.distribute.MirroredStrategy) _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_monkey_patch_dataset_method( _monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy) tf.distribute.experimental.MultiWorkerMirroredStrategy)
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
if hasattr(tf, 'contrib'):
_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(
tf.contrib.distribute.CollectiveAllReduceStrategy)
else:
print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')
def undo_set_up_synthetic_data(): def undo_set_up_synthetic_data():
...@@ -298,14 +290,6 @@ def undo_set_up_synthetic_data(): ...@@ -298,14 +290,6 @@ def undo_set_up_synthetic_data():
_undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy) _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method( _undo_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy) tf.distribute.experimental.MultiWorkerMirroredStrategy)
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
if hasattr(tf, 'contrib'):
_undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(
tf.contrib.distribute.CollectiveAllReduceStrategy)
else:
print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
def configure_cluster(worker_hosts=None, task_index=-1): def configure_cluster(worker_hosts=None, task_index=-1):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment