Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9a0986d1
Commit
9a0986d1
authored
Jun 03, 2020
by
Chen Chen
Committed by
A. Unique TensorFlower
Jun 03, 2020
Browse files
Fix a bug that ncf_keras model cannot be serialized as JSON.
PiperOrigin-RevId: 314664026
parent
869a4806
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
7 deletions
+22
-7
official/recommendation/ncf_keras_main.py
official/recommendation/ncf_keras_main.py
+22
-7
No files found.
official/recommendation/ncf_keras_main.py
View file @
9a0986d1
...
@@ -37,21 +37,22 @@ from official.recommendation import movielens
...
@@ -37,21 +37,22 @@ from official.recommendation import movielens
from
official.recommendation
import
ncf_common
from
official.recommendation
import
ncf_common
from
official.recommendation
import
ncf_input_pipeline
from
official.recommendation
import
ncf_input_pipeline
from
official.recommendation
import
neumf_model
from
official.recommendation
import
neumf_model
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
model_helpers
from
official.utils.misc
import
model_helpers
from
official.utils.flags
import
core
as
flags_core
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
def
metric_fn
(
logits
,
dup_mask
,
params
):
def
metric_fn
(
logits
,
dup_mask
,
match_mlperf
):
dup_mask
=
tf
.
cast
(
dup_mask
,
tf
.
float32
)
dup_mask
=
tf
.
cast
(
dup_mask
,
tf
.
float32
)
logits
=
tf
.
slice
(
logits
,
[
0
,
1
],
[
-
1
,
-
1
])
logits
=
tf
.
slice
(
logits
,
[
0
,
1
],
[
-
1
,
-
1
])
in_top_k
,
_
,
metric_weights
,
_
=
neumf_model
.
compute_top_k_and_ndcg
(
in_top_k
,
_
,
metric_weights
,
_
=
neumf_model
.
compute_top_k_and_ndcg
(
logits
,
logits
,
dup_mask
,
dup_mask
,
params
[
"
match_mlperf
"
]
)
match_mlperf
)
metric_weights
=
tf
.
cast
(
metric_weights
,
tf
.
float32
)
metric_weights
=
tf
.
cast
(
metric_weights
,
tf
.
float32
)
return
in_top_k
,
metric_weights
return
in_top_k
,
metric_weights
...
@@ -59,9 +60,16 @@ def metric_fn(logits, dup_mask, params):
...
@@ -59,9 +60,16 @@ def metric_fn(logits, dup_mask, params):
class
MetricLayer
(
tf
.
keras
.
layers
.
Layer
):
class
MetricLayer
(
tf
.
keras
.
layers
.
Layer
):
"""Custom layer of metrics for NCF model."""
"""Custom layer of metrics for NCF model."""
def
__init__
(
self
,
params
):
def
__init__
(
self
,
match_mlperf
):
super
(
MetricLayer
,
self
).
__init__
()
super
(
MetricLayer
,
self
).
__init__
()
self
.
params
=
params
self
.
match_mlperf
=
match_mlperf
def
get_config
(
self
):
return
{
"match_mlperf"
:
self
.
match_mlperf
}
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
return
cls
(
**
config
)
def
call
(
self
,
inputs
,
training
=
False
):
def
call
(
self
,
inputs
,
training
=
False
):
logits
,
dup_mask
=
inputs
logits
,
dup_mask
=
inputs
...
@@ -70,7 +78,7 @@ class MetricLayer(tf.keras.layers.Layer):
...
@@ -70,7 +78,7 @@ class MetricLayer(tf.keras.layers.Layer):
hr_sum
=
0.0
hr_sum
=
0.0
hr_count
=
0.0
hr_count
=
0.0
else
:
else
:
metric
,
metric_weights
=
metric_fn
(
logits
,
dup_mask
,
self
.
params
)
metric
,
metric_weights
=
metric_fn
(
logits
,
dup_mask
,
self
.
match_mlperf
)
hr_sum
=
tf
.
reduce_sum
(
metric
*
metric_weights
)
hr_sum
=
tf
.
reduce_sum
(
metric
*
metric_weights
)
hr_count
=
tf
.
reduce_sum
(
metric_weights
)
hr_count
=
tf
.
reduce_sum
(
metric_weights
)
...
@@ -89,6 +97,13 @@ class LossLayer(tf.keras.layers.Layer):
...
@@ -89,6 +97,13 @@ class LossLayer(tf.keras.layers.Layer):
self
.
loss
=
tf
.
keras
.
losses
.
SparseCategoricalCrossentropy
(
self
.
loss
=
tf
.
keras
.
losses
.
SparseCategoricalCrossentropy
(
from_logits
=
True
,
reduction
=
"sum"
)
from_logits
=
True
,
reduction
=
"sum"
)
def
get_config
(
self
):
return
{
"loss_normalization_factor"
:
self
.
loss_normalization_factor
}
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
return
cls
(
**
config
)
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
logits
,
labels
,
valid_pt_mask_input
=
inputs
logits
,
labels
,
valid_pt_mask_input
=
inputs
loss
=
self
.
loss
(
loss
=
self
.
loss
(
...
@@ -409,7 +424,7 @@ def run_ncf_custom_training(params,
...
@@ -409,7 +424,7 @@ def run_ncf_custom_training(params,
softmax_logits
=
keras_model
(
features
)
softmax_logits
=
keras_model
(
features
)
in_top_k
,
metric_weights
=
metric_fn
(
softmax_logits
,
in_top_k
,
metric_weights
=
metric_fn
(
softmax_logits
,
features
[
rconst
.
DUPLICATE_MASK
],
features
[
rconst
.
DUPLICATE_MASK
],
params
)
params
[
"match_mlperf"
]
)
hr_sum
=
tf
.
reduce_sum
(
in_top_k
*
metric_weights
)
hr_sum
=
tf
.
reduce_sum
(
in_top_k
*
metric_weights
)
hr_count
=
tf
.
reduce_sum
(
metric_weights
)
hr_count
=
tf
.
reduce_sum
(
metric_weights
)
return
hr_sum
,
hr_count
return
hr_sum
,
hr_count
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment