Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
15c1cd77
Commit
15c1cd77
authored
Aug 13, 2019
by
tf-models-copybara-bot
Committed by
Hongkun Yu
Aug 13, 2019
Browse files
Internal change (#7442)
PiperOrigin-RevId: 263204353
parent
161ae74d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
5 deletions
+34
-5
official/recommendation/ncf_input_pipeline.py
official/recommendation/ncf_input_pipeline.py
+18
-1
official/recommendation/ncf_keras_benchmark.py
official/recommendation/ncf_keras_benchmark.py
+6
-0
official/recommendation/ncf_keras_main.py
official/recommendation/ncf_keras_main.py
+10
-4
No files found.
official/recommendation/ncf_input_pipeline.py
View file @
15c1cd77
...
@@ -117,7 +117,10 @@ def create_dataset_from_data_producer(producer, params):
...
@@ -117,7 +117,10 @@ def create_dataset_from_data_producer(producer, params):
return
train_input_dataset
,
eval_input_dataset
return
train_input_dataset
,
eval_input_dataset
def
create_ncf_input_data
(
params
,
producer
=
None
,
input_meta_data
=
None
):
def
create_ncf_input_data
(
params
,
producer
=
None
,
input_meta_data
=
None
,
strategy
=
None
):
"""Creates NCF training/evaluation dataset.
"""Creates NCF training/evaluation dataset.
Args:
Args:
...
@@ -128,6 +131,9 @@ def create_ncf_input_data(params, producer=None, input_meta_data=None):
...
@@ -128,6 +131,9 @@ def create_ncf_input_data(params, producer=None, input_meta_data=None):
input_meta_data: A dictionary of input metadata to be used when reading data
input_meta_data: A dictionary of input metadata to be used when reading data
from tf record files. Must be specified when params["train_input_dataset"]
from tf record files. Must be specified when params["train_input_dataset"]
is specified.
is specified.
strategy: Distribution strategy used for distributed training. If specified,
used to assert that evaluation batch size is correctly a multiple of
total number of devices used.
Returns:
Returns:
(training dataset, evaluation dataset, train steps per epoch,
(training dataset, evaluation dataset, train steps per epoch,
...
@@ -136,6 +142,17 @@ def create_ncf_input_data(params, producer=None, input_meta_data=None):
...
@@ -136,6 +142,17 @@ def create_ncf_input_data(params, producer=None, input_meta_data=None):
Raises:
Raises:
ValueError: If data is being generated online for when using TPU's.
ValueError: If data is being generated online for when using TPU's.
"""
"""
# NCF evaluation metric calculation logic assumes that evaluation data
# sample size are in multiples of (1 + number of negative samples in
# evaluation) for each device. As so, evaluation batch size must be a
# multiple of (number of replicas * (1 + number of negative samples)).
num_devices
=
strategy
.
num_replicas_in_sync
if
strategy
else
1
if
(
params
[
"eval_batch_size"
]
%
(
num_devices
*
(
1
+
rconst
.
NUM_EVAL_NEGATIVES
))):
raise
ValueError
(
"Evaluation batch size must be divisible by {} "
"times {}"
.
format
(
num_devices
,
(
1
+
rconst
.
NUM_EVAL_NEGATIVES
)))
if
params
[
"train_dataset_path"
]:
if
params
[
"train_dataset_path"
]:
assert
params
[
"eval_dataset_path"
]
assert
params
[
"eval_dataset_path"
]
...
...
official/recommendation/ncf_keras_benchmark.py
View file @
15c1cd77
...
@@ -199,6 +199,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
...
@@ -199,6 +199,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
self
.
_setup
()
self
.
_setup
()
FLAGS
.
early_stopping
=
True
FLAGS
.
early_stopping
=
True
FLAGS
.
num_gpus
=
2
FLAGS
.
num_gpus
=
2
FLAGS
.
eval_batch_size
=
160000
self
.
_run_and_report_benchmark
()
self
.
_run_and_report_benchmark
()
def
benchmark_2_gpus_ctl_early_stop
(
self
):
def
benchmark_2_gpus_ctl_early_stop
(
self
):
...
@@ -207,6 +208,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
...
@@ -207,6 +208,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS
.
keras_use_ctl
=
True
FLAGS
.
keras_use_ctl
=
True
FLAGS
.
early_stopping
=
True
FLAGS
.
early_stopping
=
True
FLAGS
.
num_gpus
=
2
FLAGS
.
num_gpus
=
2
FLAGS
.
eval_batch_size
=
160000
self
.
_run_and_report_benchmark
()
self
.
_run_and_report_benchmark
()
#############################################
#############################################
...
@@ -283,6 +285,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
...
@@ -283,6 +285,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS
.
num_gpus
=
8
FLAGS
.
num_gpus
=
8
FLAGS
.
train_epochs
=
17
FLAGS
.
train_epochs
=
17
FLAGS
.
batch_size
=
1048576
FLAGS
.
batch_size
=
1048576
FLAGS
.
eval_batch_size
=
160000
FLAGS
.
learning_rate
=
0.0045
FLAGS
.
learning_rate
=
0.0045
FLAGS
.
beta1
=
0.25
FLAGS
.
beta1
=
0.25
FLAGS
.
beta2
=
0.5
FLAGS
.
beta2
=
0.5
...
@@ -295,6 +298,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
...
@@ -295,6 +298,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS
.
num_gpus
=
8
FLAGS
.
num_gpus
=
8
FLAGS
.
train_epochs
=
17
FLAGS
.
train_epochs
=
17
FLAGS
.
batch_size
=
1048576
FLAGS
.
batch_size
=
1048576
FLAGS
.
eval_batch_size
=
160000
FLAGS
.
learning_rate
=
0.0045
FLAGS
.
learning_rate
=
0.0045
FLAGS
.
beta1
=
0.25
FLAGS
.
beta1
=
0.25
FLAGS
.
beta2
=
0.5
FLAGS
.
beta2
=
0.5
...
@@ -309,6 +313,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
...
@@ -309,6 +313,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS
.
num_gpus
=
8
FLAGS
.
num_gpus
=
8
FLAGS
.
train_epochs
=
17
FLAGS
.
train_epochs
=
17
FLAGS
.
batch_size
=
1048576
FLAGS
.
batch_size
=
1048576
FLAGS
.
eval_batch_size
=
160000
FLAGS
.
learning_rate
=
0.0045
FLAGS
.
learning_rate
=
0.0045
FLAGS
.
beta1
=
0.25
FLAGS
.
beta1
=
0.25
FLAGS
.
beta2
=
0.5
FLAGS
.
beta2
=
0.5
...
@@ -329,6 +334,7 @@ class NCFKerasSynth(NCFKerasBenchmarkBase):
...
@@ -329,6 +334,7 @@ class NCFKerasSynth(NCFKerasBenchmarkBase):
default_flags
[
'num_gpus'
]
=
1
default_flags
[
'num_gpus'
]
=
1
default_flags
[
'train_epochs'
]
=
8
default_flags
[
'train_epochs'
]
=
8
default_flags
[
'batch_size'
]
=
99000
default_flags
[
'batch_size'
]
=
99000
default_flags
[
'eval_batch_size'
]
=
160000
default_flags
[
'learning_rate'
]
=
0.00382059
default_flags
[
'learning_rate'
]
=
0.00382059
default_flags
[
'beta1'
]
=
0.783529
default_flags
[
'beta1'
]
=
0.783529
default_flags
[
'beta2'
]
=
0.909003
default_flags
[
'beta2'
]
=
0.909003
...
...
official/recommendation/ncf_keras_main.py
View file @
15c1cd77
...
@@ -66,10 +66,16 @@ class MetricLayer(tf.keras.layers.Layer):
...
@@ -66,10 +66,16 @@ class MetricLayer(tf.keras.layers.Layer):
self
.
params
=
params
self
.
params
=
params
self
.
metric
=
tf
.
keras
.
metrics
.
Mean
(
name
=
rconst
.
HR_METRIC_NAME
)
self
.
metric
=
tf
.
keras
.
metrics
.
Mean
(
name
=
rconst
.
HR_METRIC_NAME
)
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
,
training
=
False
):
logits
,
dup_mask
=
inputs
logits
,
dup_mask
=
inputs
in_top_k
,
metric_weights
=
metric_fn
(
logits
,
dup_mask
,
self
.
params
)
self
.
add_metric
(
self
.
metric
(
in_top_k
,
sample_weight
=
metric_weights
))
if
not
training
:
in_top_k
,
metric_weights
=
metric_fn
(
logits
,
dup_mask
,
self
.
params
)
metric
=
self
.
metric
(
in_top_k
,
sample_weight
=
metric_weights
)
else
:
metric
=
0.0
self
.
add_metric
(
metric
,
name
=
"ncf_metric"
,
aggregation
=
"mean"
)
return
logits
return
logits
...
@@ -249,7 +255,7 @@ def run_ncf(_):
...
@@ -249,7 +255,7 @@ def run_ncf(_):
(
train_input_dataset
,
eval_input_dataset
,
(
train_input_dataset
,
eval_input_dataset
,
num_train_steps
,
num_eval_steps
)
=
\
num_train_steps
,
num_eval_steps
)
=
\
(
ncf_input_pipeline
.
create_ncf_input_data
(
(
ncf_input_pipeline
.
create_ncf_input_data
(
params
,
producer
,
input_meta_data
))
params
,
producer
,
input_meta_data
,
strategy
))
steps_per_epoch
=
None
if
generate_input_online
else
num_train_steps
steps_per_epoch
=
None
if
generate_input_online
else
num_train_steps
with
distribution_utils
.
get_strategy_scope
(
strategy
):
with
distribution_utils
.
get_strategy_scope
(
strategy
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment