ModelZoo / ResNet50_tensorflow · Commits

Commit bd211e3e
Avoid importing private ObjectIdentitySet class

Authored Sep 02, 2019 by Gaurav Jain; committed Sep 02, 2019 by A. Unique TensorFlower
PiperOrigin-RevId: 266848625
Parent: b9ef963d

Showing 2 changed files, with 2 additions and 7 deletions:
  official/bert/model_training_utils.py        +1 -3
  official/transformer/v2/transformer_main.py  +1 -4
official/bert/model_training_utils.py  (+1, -3)

@@ -23,7 +23,6 @@ import os
 from absl import logging
 import tensorflow as tf
-from tensorflow.python.util import object_identity
 from official.utils.misc import distribution_utils
 from official.utils.misc import tpu_lib

@@ -243,8 +242,7 @@ def run_customized_training_loop(
           scaled_loss = optimizer.get_scaled_loss(loss)

       # De-dupes variables due to keras tracking issues.
-      tvars = list(
-          object_identity.ObjectIdentitySet(model.trainable_variables))
+      tvars = list({id(v): v for v in model.trainable_variables}.values())
       if use_float16:
         scaled_grads = tape.gradient(scaled_loss, tvars)
         grads = optimizer.get_unscaled_gradients(scaled_grads)
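For context, here is a minimal sketch (not taken from the repository) of the de-duplication idiom this commit switches to. Keying a dict on id(v) drops repeated references to the same object while keeping distinct-but-equal objects, preserves insertion order, and does not depend on how the objects implement __eq__ or __hash__, which is what the private object_identity.ObjectIdentitySet helper was being used for. The Item class below is hypothetical.

```python
# Sketch only: illustrates the dict-keyed-on-id() idiom adopted by this commit.
# `Item` is a hypothetical stand-in for an object (like a tf.Variable) whose
# value-based __eq__ makes it awkward to de-dupe with a plain set().

class Item:

  def __init__(self, value):
    self.value = value

  def __eq__(self, other):
    return isinstance(other, Item) and self.value == other.value

  __hash__ = None  # defining __eq__ without __hash__ makes Item unhashable


a, b = Item(1), Item(1)   # equal by value, but distinct objects
tracked = [a, b, a]       # `a` appears twice, as Keras tracking sometimes allows

# set(tracked) would raise TypeError (unhashable); the id() keying de-dupes
# strictly by object identity and preserves insertion order (Python 3.7+).
deduped = list({id(v): v for v in tracked}.values())
assert len(deduped) == 2
assert deduped[0] is a and deduped[1] is b
```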
official/transformer/v2/transformer_main.py  (+1, -4)

@@ -30,8 +30,6 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
-from tensorflow.python.util import object_identity
-
 # pylint: disable=g-bad-import-order
 from official.transformer import compute_bleu
 from official.transformer.utils import tokenizer

@@ -271,8 +269,7 @@ class TransformerTask(object):
         scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync

       # De-dupes variables due to keras tracking issues.
-      tvars = list(
-          object_identity.ObjectIdentitySet(model.trainable_variables))
+      tvars = list({id(v): v for v in model.trainable_variables}.values())
       grads = tape.gradient(scaled_loss, tvars)
       opt.apply_gradients(zip(grads, tvars))

       # For reporting, the metric takes the mean of losses.
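For reference, the surrounding training-step pattern looks roughly like the sketch below. It is not the repository's code: it assumes a toy Dense model, fixed input shapes, and num_replicas_in_sync hard-coded to 1 in place of a real tf.distribute strategy. It only shows where the identity-de-duped variable list feeds into tape.gradient and apply_gradients.

```python
# Sketch only, with an assumed toy model and shapes; not the repository's code.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
opt = tf.keras.optimizers.Adam()
num_replicas_in_sync = 1  # stand-in for self.distribution_strategy.num_replicas_in_sync


@tf.function
def train_step(inputs, targets):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(inputs) - targets))
    # Divide by the replica count so that summed per-replica gradients match
    # the global-batch gradient under a synchronous distribution strategy.
    scaled_loss = loss / num_replicas_in_sync

  # De-dupe by object identity, as in this commit, without relying on the
  # private object_identity.ObjectIdentitySet helper.
  tvars = list({id(v): v for v in model.trainable_variables}.values())
  grads = tape.gradient(scaled_loss, tvars)
  opt.apply_gradients(zip(grads, tvars))
  return loss


print(train_step(tf.ones((2, 8)), tf.zeros((2, 4))).numpy())
```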