"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "8214a9779997ad3606a4830a7b718ba1bcfbf16d"
Unverified Commit 8384b05d authored by Hongkun Yu's avatar Hongkun Yu Committed by GitHub
Browse files

Merged commit includes the following changes: (#7385)

261786323  by yanhuasun<yanhuasun@google.com>:

    Replace set, dict with ObjectIdentityDict/Set to prepare for eq implementation

--

PiperOrigin-RevId: 261786323
parent 97622ffc
...@@ -23,6 +23,7 @@ import os ...@@ -23,6 +23,7 @@ import os
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow.python.util import object_identity
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
_SUMMARY_TXT = 'training_summary.txt' _SUMMARY_TXT = 'training_summary.txt'
...@@ -248,7 +249,8 @@ def run_customized_training_loop( ...@@ -248,7 +249,8 @@ def run_customized_training_loop(
scaled_loss = optimizer.get_scaled_loss(loss) scaled_loss = optimizer.get_scaled_loss(loss)
# De-dupes variables due to keras tracking issues. # De-dupes variables due to keras tracking issues.
tvars = list(set(model.trainable_variables)) tvars = list(
object_identity.ObjectIdentitySet(model.trainable_variables))
if use_float16: if use_float16:
scaled_grads = tape.gradient(scaled_loss, tvars) scaled_grads = tape.gradient(scaled_loss, tvars)
grads = optimizer.get_unscaled_gradients(scaled_grads) grads = optimizer.get_unscaled_gradients(scaled_grads)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment