Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8bd9aa11
Commit
8bd9aa11
authored
Dec 16, 2019
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Dec 16, 2019
Browse files
Make NCF not depend on tf.contrib.
Remove not maintained code path. PiperOrigin-RevId: 285869559
parent
8d9a16ce
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
37 deletions
+12
-37
official/recommendation/data_pipeline.py
official/recommendation/data_pipeline.py
+1
-4
official/recommendation/ncf_common.py
official/recommendation/ncf_common.py
+0
-11
official/recommendation/ncf_estimator_main.py
official/recommendation/ncf_estimator_main.py
+2
-8
official/recommendation/neumf_model.py
official/recommendation/neumf_model.py
+9
-14
No files found.
official/recommendation/data_pipeline.py
View file @
8bd9aa11
...
@@ -39,6 +39,7 @@ from official.recommendation import constants as rconst
...
@@ -39,6 +39,7 @@ from official.recommendation import constants as rconst
from
official.recommendation
import
movielens
from
official.recommendation
import
movielens
from
official.recommendation
import
popen_helper
from
official.recommendation
import
popen_helper
from
official.recommendation
import
stat_utils
from
official.recommendation
import
stat_utils
from
tensorflow.python.tpu.datasets
import
StreamingFilesDataset
SUMMARY_TEMPLATE
=
"""General:
SUMMARY_TEMPLATE
=
"""General:
...
@@ -286,10 +287,6 @@ class DatasetManager(object):
...
@@ -286,10 +287,6 @@ class DatasetManager(object):
file_pattern
=
os
.
path
.
join
(
file_pattern
=
os
.
path
.
join
(
epoch_data_dir
,
rconst
.
SHARD_TEMPLATE
.
format
(
"*"
))
epoch_data_dir
,
rconst
.
SHARD_TEMPLATE
.
format
(
"*"
))
# TODO(seemuch): remove this contrib import
# pylint: disable=line-too-long
from
tensorflow.contrib.tpu.python.tpu.datasets
import
StreamingFilesDataset
# pylint: enable=line-too-long
dataset
=
StreamingFilesDataset
(
dataset
=
StreamingFilesDataset
(
files
=
file_pattern
,
worker_job
=
popen_helper
.
worker_job
(),
files
=
file_pattern
,
worker_job
=
popen_helper
.
worker_job
(),
num_parallel_reads
=
rconst
.
NUM_FILE_SHARDS
,
num_epochs
=
1
,
num_parallel_reads
=
rconst
.
NUM_FILE_SHARDS
,
num_epochs
=
1
,
...
...
official/recommendation/ncf_common.py
View file @
8bd9aa11
...
@@ -94,7 +94,6 @@ def parse_flags(flags_obj):
...
@@ -94,7 +94,6 @@ def parse_flags(flags_obj):
"beta2"
:
flags_obj
.
beta2
,
"beta2"
:
flags_obj
.
beta2
,
"epsilon"
:
flags_obj
.
epsilon
,
"epsilon"
:
flags_obj
.
epsilon
,
"match_mlperf"
:
flags_obj
.
ml_perf
,
"match_mlperf"
:
flags_obj
.
ml_perf
,
"use_xla_for_gpu"
:
flags_obj
.
use_xla_for_gpu
,
"epochs_between_evals"
:
FLAGS
.
epochs_between_evals
,
"epochs_between_evals"
:
FLAGS
.
epochs_between_evals
,
"keras_use_ctl"
:
flags_obj
.
keras_use_ctl
,
"keras_use_ctl"
:
flags_obj
.
keras_use_ctl
,
"hr_threshold"
:
flags_obj
.
hr_threshold
,
"hr_threshold"
:
flags_obj
.
hr_threshold
,
...
@@ -307,16 +306,6 @@ def define_ncf_flags():
...
@@ -307,16 +306,6 @@ def define_ncf_flags():
return
(
eval_batch_size
is
None
or
return
(
eval_batch_size
is
None
or
int
(
eval_batch_size
)
>
rconst
.
NUM_EVAL_NEGATIVES
)
int
(
eval_batch_size
)
>
rconst
.
NUM_EVAL_NEGATIVES
)
flags
.
DEFINE_bool
(
name
=
"use_xla_for_gpu"
,
default
=
False
,
help
=
flags_core
.
help_wrap
(
"If True, use XLA for the model function. Only works when using a "
"GPU. On TPUs, XLA is always used"
))
xla_message
=
"--use_xla_for_gpu is incompatible with --tpu"
@
flags
.
multi_flags_validator
([
"use_xla_for_gpu"
,
"tpu"
],
message
=
xla_message
)
def
xla_validator
(
flag_dict
):
return
not
flag_dict
[
"use_xla_for_gpu"
]
or
not
flag_dict
[
"tpu"
]
flags
.
DEFINE_bool
(
flags
.
DEFINE_bool
(
name
=
"early_stopping"
,
name
=
"early_stopping"
,
default
=
False
,
default
=
False
,
...
...
official/recommendation/ncf_estimator_main.py
View file @
8bd9aa11
...
@@ -57,25 +57,19 @@ FLAGS = flags.FLAGS
...
@@ -57,25 +57,19 @@ FLAGS = flags.FLAGS
def
construct_estimator
(
model_dir
,
params
):
def
construct_estimator
(
model_dir
,
params
):
"""Construct either an Estimator
or TPUEstimator
for NCF.
"""Construct either an Estimator for NCF.
Args:
Args:
model_dir: The model directory for the estimator
model_dir: The model directory for the estimator
params: The params dict for the estimator
params: The params dict for the estimator
Returns:
Returns:
An Estimator
or TPUEstimator
.
An Estimator.
"""
"""
distribution
=
ncf_common
.
get_v1_distribution_strategy
(
params
)
distribution
=
ncf_common
.
get_v1_distribution_strategy
(
params
)
run_config
=
tf
.
estimator
.
RunConfig
(
train_distribute
=
distribution
,
run_config
=
tf
.
estimator
.
RunConfig
(
train_distribute
=
distribution
,
eval_distribute
=
distribution
)
eval_distribute
=
distribution
)
model_fn
=
neumf_model
.
neumf_model_fn
model_fn
=
neumf_model
.
neumf_model_fn
if
params
[
"use_xla_for_gpu"
]:
# TODO(seemuch): remove the contrib imput
from
tensorflow.contrib.compiler
import
xla
logging
.
info
(
"Using XLA for GPU for training and evaluation."
)
model_fn
=
xla
.
estimator_model_fn
(
model_fn
)
estimator
=
tf
.
estimator
.
Estimator
(
model_fn
=
model_fn
,
model_dir
=
model_dir
,
estimator
=
tf
.
estimator
.
Estimator
(
model_fn
=
model_fn
,
model_dir
=
model_dir
,
config
=
run_config
,
params
=
params
)
config
=
run_config
,
params
=
params
)
return
estimator
return
estimator
...
...
official/recommendation/neumf_model.py
View file @
8bd9aa11
...
@@ -93,7 +93,7 @@ def neumf_model_fn(features, labels, mode, params):
...
@@ -93,7 +93,7 @@ def neumf_model_fn(features, labels, mode, params):
duplicate_mask
,
duplicate_mask
,
params
[
"num_neg"
],
params
[
"num_neg"
],
params
[
"match_mlperf"
],
params
[
"match_mlperf"
],
use_tpu_spec
=
params
[
"use_
xla_for_g
pu"
])
use_tpu_spec
=
params
[
"use_
t
pu"
])
elif
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
elif
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
labels
=
tf
.
cast
(
labels
,
tf
.
int32
)
labels
=
tf
.
cast
(
labels
,
tf
.
int32
)
...
@@ -269,8 +269,7 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
...
@@ -269,8 +269,7 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
softmax_logits
,
softmax_logits
,
duplicate_mask
,
duplicate_mask
,
num_training_neg
,
num_training_neg
,
match_mlperf
,
match_mlperf
)
use_tpu_spec
)
if
use_tpu_spec
:
if
use_tpu_spec
:
return
tf
.
estimator
.
tpu
.
TPUEstimatorSpec
(
return
tf
.
estimator
.
tpu
.
TPUEstimatorSpec
(
...
@@ -285,13 +284,13 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
...
@@ -285,13 +284,13 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
)
)
def
compute_eval_loss_and_metrics_helper
(
logits
,
# type: tf.Tensor
def
compute_eval_loss_and_metrics_helper
(
logits
,
# type: tf.Tensor
softmax_logits
,
# type: tf.Tensor
softmax_logits
,
# type: tf.Tensor
duplicate_mask
,
# type: tf.Tensor
duplicate_mask
,
# type: tf.Tensor
num_training_neg
,
# type: int
num_training_neg
,
# type: int
match_mlperf
=
False
,
# type: bool
match_mlperf
=
False
# type: bool
use_tpu_spec
=
False
# type: bool
):
):
"""Model evaluation with HR and NDCG metrics.
"""Model evaluation with HR and NDCG metrics.
The evaluation protocol is to rank the test interacted item (truth items)
The evaluation protocol is to rank the test interacted item (truth items)
...
@@ -348,10 +347,6 @@ def compute_eval_loss_and_metrics_helper(logits, # type: tf.Tensor
...
@@ -348,10 +347,6 @@ def compute_eval_loss_and_metrics_helper(logits, # type: tf.Tensor
match_mlperf: Use the MLPerf reference convention for computing rank.
match_mlperf: Use the MLPerf reference convention for computing rank.
use_tpu_spec: Should a TPUEstimatorSpec be returned instead of an
EstimatorSpec. Required for TPUs and if XLA is done on a GPU. Despite its
name, TPUEstimatorSpecs work with GPUs
Returns:
Returns:
cross_entropy: the loss
cross_entropy: the loss
metric_fn: the metrics function
metric_fn: the metrics function
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment