Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
722d9e57
Commit
722d9e57
authored
Dec 14, 2019
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Dec 14, 2019
Browse files
Clearly demarcate contrib symbols from standard tf symbols by importing them directly.
PiperOrigin-RevId: 285618209
parent
e5c71d51
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
19 deletions
+2
-19
official/recommendation/neumf_model.py
official/recommendation/neumf_model.py
+2
-3
official/utils/misc/distribution_utils.py
official/utils/misc/distribution_utils.py
+0
-16
No files found.
official/recommendation/neumf_model.py
View file @
722d9e57
...
@@ -115,8 +115,7 @@ def neumf_model_fn(features, labels, mode, params):
...
@@ -115,8 +115,7 @@ def neumf_model_fn(features, labels, mode, params):
beta2
=
params
[
"beta2"
],
beta2
=
params
[
"beta2"
],
epsilon
=
params
[
"epsilon"
])
epsilon
=
params
[
"epsilon"
])
if
params
[
"use_tpu"
]:
if
params
[
"use_tpu"
]:
# TODO(seemuch): remove this contrib import
optimizer
=
tf
.
compat
.
v1
.
tpu
.
CrossShardOptimizer
(
optimizer
)
optimizer
=
tf
.
contrib
.
tpu
.
CrossShardOptimizer
(
optimizer
)
mlperf_helper
.
ncf_print
(
key
=
mlperf_helper
.
TAGS
.
MODEL_HP_LOSS_FN
,
mlperf_helper
.
ncf_print
(
key
=
mlperf_helper
.
TAGS
.
MODEL_HP_LOSS_FN
,
value
=
mlperf_helper
.
TAGS
.
BCE
)
value
=
mlperf_helper
.
TAGS
.
BCE
)
...
@@ -274,7 +273,7 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
...
@@ -274,7 +273,7 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
use_tpu_spec
)
use_tpu_spec
)
if
use_tpu_spec
:
if
use_tpu_spec
:
return
tf
.
contrib
.
tpu
.
TPUEstimatorSpec
(
return
tf
.
estimator
.
tpu
.
TPUEstimatorSpec
(
mode
=
tf
.
estimator
.
ModeKeys
.
EVAL
,
mode
=
tf
.
estimator
.
ModeKeys
.
EVAL
,
loss
=
cross_entropy
,
loss
=
cross_entropy
,
eval_metrics
=
(
metric_fn
,
[
in_top_k
,
ndcg
,
metric_weights
]))
eval_metrics
=
(
metric_fn
,
[
in_top_k
,
ndcg
,
metric_weights
]))
...
...
official/utils/misc/distribution_utils.py
View file @
722d9e57
...
@@ -283,14 +283,6 @@ def set_up_synthetic_data():
...
@@ -283,14 +283,6 @@ def set_up_synthetic_data():
_monkey_patch_dataset_method
(
tf
.
distribute
.
MirroredStrategy
)
_monkey_patch_dataset_method
(
tf
.
distribute
.
MirroredStrategy
)
_monkey_patch_dataset_method
(
_monkey_patch_dataset_method
(
tf
.
distribute
.
experimental
.
MultiWorkerMirroredStrategy
)
tf
.
distribute
.
experimental
.
MultiWorkerMirroredStrategy
)
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
if
hasattr
(
tf
,
'contrib'
):
_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
MirroredStrategy
)
_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
OneDeviceStrategy
)
_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
CollectiveAllReduceStrategy
)
else
:
print
(
'Contrib missing: Skip monkey patch tf.contrib.distribute.*'
)
def
undo_set_up_synthetic_data
():
def
undo_set_up_synthetic_data
():
...
@@ -298,14 +290,6 @@ def undo_set_up_synthetic_data():
...
@@ -298,14 +290,6 @@ def undo_set_up_synthetic_data():
_undo_monkey_patch_dataset_method
(
tf
.
distribute
.
MirroredStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
distribute
.
MirroredStrategy
)
_undo_monkey_patch_dataset_method
(
_undo_monkey_patch_dataset_method
(
tf
.
distribute
.
experimental
.
MultiWorkerMirroredStrategy
)
tf
.
distribute
.
experimental
.
MultiWorkerMirroredStrategy
)
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
if
hasattr
(
tf
,
'contrib'
):
_undo_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
MirroredStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
OneDeviceStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
CollectiveAllReduceStrategy
)
else
:
print
(
'Contrib missing: Skip remove monkey patch tf.contrib.distribute.*'
)
def
configure_cluster
(
worker_hosts
=
None
,
task_index
=-
1
):
def
configure_cluster
(
worker_hosts
=
None
,
task_index
=-
1
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment