Commit bd211e3e authored by Gaurav Jain's avatar Gaurav Jain Committed by A. Unique TensorFlower
Browse files

Avoid importing private ObjectIdentitySet class

PiperOrigin-RevId: 266848625
parent b9ef963d
...@@ -23,7 +23,6 @@ import os ...@@ -23,7 +23,6 @@ import os
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow.python.util import object_identity
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import tpu_lib from official.utils.misc import tpu_lib
...@@ -243,8 +242,7 @@ def run_customized_training_loop( ...@@ -243,8 +242,7 @@ def run_customized_training_loop(
scaled_loss = optimizer.get_scaled_loss(loss) scaled_loss = optimizer.get_scaled_loss(loss)
# De-dupes variables due to keras tracking issues. # De-dupes variables due to keras tracking issues.
tvars = list( tvars = list({id(v): v for v in model.trainable_variables}.values())
object_identity.ObjectIdentitySet(model.trainable_variables))
if use_float16: if use_float16:
scaled_grads = tape.gradient(scaled_loss, tvars) scaled_grads = tape.gradient(scaled_loss, tvars)
grads = optimizer.get_unscaled_gradients(scaled_grads) grads = optimizer.get_unscaled_gradients(scaled_grads)
......
...@@ -30,8 +30,6 @@ from absl import flags ...@@ -30,8 +30,6 @@ from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow.python.util import object_identity
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from official.transformer import compute_bleu from official.transformer import compute_bleu
from official.transformer.utils import tokenizer from official.transformer.utils import tokenizer
...@@ -271,8 +269,7 @@ class TransformerTask(object): ...@@ -271,8 +269,7 @@ class TransformerTask(object):
scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync
# De-dupes variables due to keras tracking issues. # De-dupes variables due to keras tracking issues.
tvars = list( tvars = list({id(v): v for v in model.trainable_variables}.values())
object_identity.ObjectIdentitySet(model.trainable_variables))
grads = tape.gradient(scaled_loss, tvars) grads = tape.gradient(scaled_loss, tvars)
opt.apply_gradients(zip(grads, tvars)) opt.apply_gradients(zip(grads, tvars))
# For reporting, the metric takes the mean of losses. # For reporting, the metric takes the mean of losses.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment