ModelZoo / ResNet50_tensorflow · Commits

Commit 70704b94
authored Jun 02, 2019 by guptapriya

Add custom loss and metrics to NCF compile/fit version

parent dcdc45bd

Showing 1 changed file with 69 additions and 57 deletions (+69, -57):
official/recommendation/ncf_keras_main.py
@@ -45,53 +45,30 @@ from official.utils.misc import model_helpers

 FLAGS = flags.FLAGS


-def _keras_loss(y_true, y_pred):
-  # Here we are using the exact same loss used by the estimator
-  loss = tf.keras.losses.sparse_categorical_crossentropy(
-      y_pred=y_pred,
-      y_true=tf.cast(y_true, tf.int32),
-      from_logits=True)
-  return loss
-
-
-def _get_metric_fn(params):
-  """Get the metric fn used by model compile."""
-  batch_size = params["batch_size"]
-
-  def metric_fn(y_true, y_pred):
-    """Returns the in_top_k metric."""
-    softmax_logits = y_pred[0, :]
-    logits = tf.slice(softmax_logits, [0, 1], [batch_size, 1])
-
-    # The dup mask should be obtained from input data, but we did not yet find
-    # a good way of getting it with keras, so we set it to zeros to neglect the
-    # repetition correction
-    dup_mask = tf.zeros([batch_size, 1])
-
-    _, _, in_top_k, _, _ = (
-        neumf_model.compute_eval_loss_and_metrics_helper(
-            logits,
-            softmax_logits,
-            dup_mask,
-            params["num_neg"],
-            params["match_mlperf"],
-            params["use_xla_for_gpu"]))
-
-    is_training = tf.keras.backend.learning_phase()
-    if isinstance(is_training, int):
-      is_training = tf.constant(bool(is_training), dtype=tf.bool)
-
-    in_top_k = tf.cond(
-        is_training,
-        lambda: tf.zeros(shape=in_top_k.shape, dtype=in_top_k.dtype),
-        lambda: in_top_k)
-
-    return in_top_k
-
-  return metric_fn
+class MetricLayer(tf.keras.layers.Layer):
+  """Custom layer of metrics for NCF model."""
+
+  def __init__(self, params):
+    super(MetricLayer, self).__init__()
+    self.params = params
+
+  def build(self, input_shape):
+    self.metric = tf.keras.metrics.Mean(name=rconst.HR_METRIC_NAME)
+
+  def call(self, inputs):
+    logits, dup_mask = inputs
+    dup_mask = tf.cast(dup_mask, tf.float32)
+    logits = tf.slice(logits, [0, 0, 1], [-1, -1, -1])
+    in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg(
+        logits,
+        dup_mask,
+        self.params["match_mlperf"])
+    metric_weights = tf.cast(metric_weights, tf.float32)
+    self.add_metric(self.metric(in_top_k, metric_weights))
+    return inputs[0]


 def _get_train_and_eval_data(producer, params):
   """Returns the datasets for training and evaluating."""

   def preprocess_train_input(features, labels):
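This hunk swaps the estimator-style `metric_fn`, which had no access to the real duplicate mask and zeroed it out as a workaround, for a Keras layer that computes the hit-rate metric inside the model graph and registers it with `add_metric`. A minimal, self-contained sketch of that layer pattern, assuming TF 2.x Keras (where `Layer.add_metric` is available); the layer and metric names here are illustrative, not from the commit:

```python
import tensorflow as tf


class WeightedMeanMetricLayer(tf.keras.layers.Layer):
  """Identity layer that reports a weighted mean of its first input.

  Sketch of the MetricLayer pattern: pass data through unchanged and
  attach a streaming metric as a side effect, so model.fit() reports it
  without any compile(metrics=...) argument.
  """

  def build(self, input_shape):
    # A stateful Mean metric accumulates sum(values * weights) / sum(weights).
    self.mean = tf.keras.metrics.Mean(name="weighted_mean")

  def call(self, inputs):
    values, weights = inputs
    self.add_metric(self.mean(values, tf.cast(weights, tf.float32)))
    return values  # data flows through untouched


# Usage inside a functional model:
v = tf.keras.layers.Input(shape=(8,))
w = tf.keras.layers.Input(shape=(8,))
model = tf.keras.Model([v, w], WeightedMeanMetricLayer()([v, w]))
```

Keeping the metric in the graph is what lets the new code read the real `DUPLICATE_MASK` tensor from the input pipeline instead of faking a zero mask, which the deleted comment admitted was a workaround.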

@@ -104,9 +81,10 @@ def _get_train_and_eval_data(producer, params):
     fit.
     - The label needs to be extended to be used in the loss fn
     """
-    if not params["keras_use_ctl"]:
-      features.pop(rconst.VALID_POINT_MASK)
     labels = tf.expand_dims(labels, -1)
+    fake_dup_mask = tf.zeros_like(features[movielens.USER_COLUMN])
+    features[rconst.DUPLICATE_MASK] = fake_dup_mask
+    features[rconst.TRAIN_LABEL_KEY] = labels
     return features, labels

   train_input_fn = producer.make_input_fn(is_training=True)

@@ -125,10 +103,12 @@ def _get_train_and_eval_data(producer, params):
     fit.
     - The label needs to be extended to be used in the loss fn
     """
-    if not params["keras_use_ctl"]:
-      labels = tf.cast(
-          tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
-      features.pop(rconst.DUPLICATE_MASK)
+    labels = tf.zeros_like(features[movielens.USER_COLUMN])
     labels = tf.expand_dims(labels, -1)
+    fake_valit_pt_mask = tf.cast(
+        tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
+    features[rconst.VALID_POINT_MASK] = fake_valit_pt_mask
+    features[rconst.TRAIN_LABEL_KEY] = labels
     return features, labels

   eval_input_fn = producer.make_input_fn(is_training=False)
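The two preprocess hunks above make the train and eval feature dictionaries structurally identical: each side adds zero-valued fakes for the keys only the other side naturally carries, and stashes the label under a feature key, so a single functional model can declare all of them as inputs. A toy sketch of the idea, with hypothetical key names standing in for the `rconst`/`movielens` constants:

```python
import tensorflow as tf

# Hypothetical key names standing in for the rconst./movielens. constants.
USER, DUP_MASK, VALID_PT_MASK, TRAIN_LABEL = (
    "user_id", "duplicate_mask", "valid_point_mask", "train_labels")


def preprocess_train(features, labels):
  # Train batches carry a real valid-point mask but no duplicate mask:
  # fake the mask and stash the label under a feature key.
  labels = tf.expand_dims(labels, -1)
  features[DUP_MASK] = tf.zeros_like(features[USER])
  features[TRAIN_LABEL] = labels
  return features, labels


def preprocess_eval(features):
  # Eval batches carry a real duplicate mask but no labels or valid-point
  # mask: fake both. The all-False mask also zeroes the training loss
  # (it is used as sample_weight below), so the fake labels are harmless.
  labels = tf.expand_dims(tf.zeros_like(features[USER]), -1)
  features[VALID_PT_MASK] = tf.cast(tf.zeros_like(features[USER]), tf.bool)
  features[TRAIN_LABEL] = labels
  return features, labels


# Both mapped datasets now yield dicts with the same key set.
train = {USER: tf.constant([1, 2]), VALID_PT_MASK: tf.constant([True, True])}
evals = {USER: tf.constant([3, 4]), DUP_MASK: tf.constant([0, 0])}
assert (set(preprocess_train(train, tf.constant([1, 0]))[0]) ==
        set(preprocess_eval(evals)[0]))
```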

@@ -202,6 +182,24 @@ def _get_keras_model(params):
       batch_size=params["batches_per_step"],
       name=movielens.ITEM_COLUMN,
       dtype=tf.int32)

+  valid_pt_mask_input = tf.keras.layers.Input(
+      shape=(batch_size,),
+      batch_size=params["batches_per_step"],
+      name=rconst.VALID_POINT_MASK,
+      dtype=tf.bool)
+
+  dup_mask_input = tf.keras.layers.Input(
+      shape=(batch_size,),
+      batch_size=params["batches_per_step"],
+      name=rconst.DUPLICATE_MASK,
+      dtype=tf.int32)
+
+  label_input = tf.keras.layers.Input(
+      shape=(batch_size, 1),
+      batch_size=params["batches_per_step"],
+      name=rconst.TRAIN_LABEL_KEY,
+      dtype=tf.bool)
+
   base_model = neumf_model.construct_model(
       user_input, item_input, params, need_strip=True)

@@ -219,10 +217,26 @@ def _get_keras_model(params):
       [zeros, logits],
       axis=-1)

+  softmax_logits = MetricLayer(params)([softmax_logits, dup_mask_input])
+
   keras_model = tf.keras.Model(
-      inputs=[user_input, item_input],
+      inputs=[user_input, item_input,
+              valid_pt_mask_input, dup_mask_input, label_input],
       outputs=softmax_logits)

+  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
+      from_logits=True,
+      reduction="sum")
+  keras_model.add_loss(loss_obj(
+      y_true=label_input,
+      y_pred=softmax_logits,
+      sample_weight=valid_pt_mask_input) * 1.0 / batch_size)
+
   keras_model.summary()
   return keras_model
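With the labels and valid-point mask now available as graph inputs, the loss moves out of `compile` and onto the model via `add_loss`: a `"sum"` reduction followed by an explicit division by the global batch size makes the scaling independent of how a distribution strategy shards each batch. A reduced sketch of the same pattern with toy shapes (all names here are illustrative):

```python
import tensorflow as tf

x = tf.keras.layers.Input(shape=(4,), name="x")
label_input = tf.keras.layers.Input(shape=(1,), name="label", dtype=tf.int32)
logits = tf.keras.layers.Dense(3)(x)  # stand-in for the NCF tower

model = tf.keras.Model(inputs=[x, label_input], outputs=logits)

# Sum per-example losses, then divide by the known global batch size:
# an explicit mean that distribution strategies cannot silently rescale.
batch_size = 32
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="sum")
model.add_loss(loss_obj(label_input, logits) * 1.0 / batch_size)

model.compile(optimizer="sgd")  # no loss= or metrics= needed any more
```

This is also why a later hunk can delete `loss=_keras_loss` and `metrics=[_get_metric_fn(params)]` from `keras_model.compile`.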

@@ -269,7 +283,7 @@ def run_ncf(_):
   time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
   per_epoch_callback = IncrementEpochCallback(producer)
-  callbacks = [per_epoch_callback, time_callback]
+  callbacks = [per_epoch_callback]  # , time_callback]

   if FLAGS.early_stopping:
     early_stopping_callback = CustomEarlyStopping(
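Note the callback list: `time_callback` is commented out rather than deleted, so only `IncrementEpochCallback` remains active and the timing callback stays easy to restore. That callback's job is to advance the data producer once per epoch; a hypothetical sketch of the pattern (the real class is defined elsewhere in this file, and `increment_request_epoch` is an assumed producer method):

```python
import tensorflow as tf


class IncrementEpochCallbackSketch(tf.keras.callbacks.Callback):
  """Hypothetical sketch: nudge an async data producer each epoch."""

  def __init__(self, producer):
    super(IncrementEpochCallbackSketch, self).__init__()
    self._producer = producer

  def on_epoch_begin(self, epoch, logs=None):
    self._producer.increment_request_epoch()  # assumed producer API
```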

@@ -374,8 +388,6 @@ def run_ncf(_):
   with distribution_utils.get_strategy_scope(strategy):
     keras_model.compile(
-        loss=_keras_loss,
-        metrics=[_get_metric_fn(params)],
         optimizer=optimizer,
         cloning=params["clone_model_in_keras_dist_strat"])

@@ -385,7 +397,7 @@ def run_ncf(_):
         callbacks=callbacks,
         validation_data=eval_input_dataset,
         validation_steps=num_eval_steps,
-        verbose=2)
+        verbose=1)

     logging.info("Training done. Start evaluating")