Unverified Commit 8384b05d authored by Hongkun Yu's avatar Hongkun Yu Committed by GitHub
Browse files

Merged commit includes the following changes: (#7385)

261786323  by yanhuasun<yanhuasun@google.com>:

    Replace set, dict with ObjectIdentityDict/Set to prepare for eq implementation

--

PiperOrigin-RevId: 261786323
parent 97622ffc
......@@ -23,6 +23,7 @@ import os
from absl import logging
import tensorflow as tf
from tensorflow.python.util import object_identity
from official.utils.misc import distribution_utils
_SUMMARY_TXT = 'training_summary.txt'
......@@ -248,7 +249,8 @@ def run_customized_training_loop(
scaled_loss = optimizer.get_scaled_loss(loss)
# De-dupes variables due to keras tracking issues.
tvars = list(set(model.trainable_variables))
tvars = list(
object_identity.ObjectIdentitySet(model.trainable_variables))
if use_float16:
scaled_grads = tape.gradient(scaled_loss, tvars)
grads = optimizer.get_unscaled_gradients(scaled_grads)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment