Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f0a8be5d
"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "9f8691c6c83b4b231a2086417a0eeaedd2deeaba"
Commit
f0a8be5d
authored
Jun 03, 2019
by
guptapriya
Committed by
guptapriya
Jun 03, 2019
Browse files
try
#1
to fix CTL
parent
70704b94
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
6 deletions
+4
-6
official/recommendation/ncf_keras_main.py
official/recommendation/ncf_keras_main.py
+4
-6
No files found.
official/recommendation/ncf_keras_main.py
View file @
f0a8be5d
...
...
@@ -68,7 +68,7 @@ class MetricLayer(tf.keras.layers.Layer):
return
inputs
[
0
]
def
_get_train_and_eval_data
(
producer
,
params
):
def
_get_train_and_eval_data
(
producer
,
params
):
"""Returns the datasets for training and evalutating."""
def
preprocess_train_input
(
features
,
labels
):
...
...
@@ -313,8 +313,7 @@ def run_ncf(_):
"""Computes loss and applied gradient per replica."""
features
,
labels
=
inputs
with
tf
.
GradientTape
()
as
tape
:
softmax_logits
=
keras_model
([
features
[
movielens
.
USER_COLUMN
],
features
[
movielens
.
ITEM_COLUMN
]])
softmax_logits
=
keras_model
(
features
)
loss
=
loss_object
(
labels
,
softmax_logits
,
sample_weight
=
features
[
rconst
.
VALID_POINT_MASK
])
loss
*=
(
1.0
/
(
batch_size
*
strategy
.
num_replicas_in_sync
))
...
...
@@ -336,8 +335,7 @@ def run_ncf(_):
def
step_fn
(
inputs
):
"""Computes eval metrics per replica."""
features
,
_
=
inputs
softmax_logits
=
keras_model
([
features
[
movielens
.
USER_COLUMN
],
features
[
movielens
.
ITEM_COLUMN
]])
softmax_logits
=
keras_model
(
features
)
logits
=
tf
.
slice
(
softmax_logits
,
[
0
,
0
,
1
],
[
-
1
,
-
1
,
-
1
])
dup_mask
=
features
[
rconst
.
DUPLICATE_MASK
]
in_top_k
,
_
,
metric_weights
,
_
=
neumf_model
.
compute_top_k_and_ndcg
(
...
...
@@ -412,7 +410,7 @@ def run_ncf(_):
train_history
=
history
.
history
train_loss
=
train_history
[
"loss"
][
-
1
]
stats
=
build_stats
(
train_loss
,
eval_results
,
time_callback
)
stats
=
build_stats
(
train_loss
,
eval_results
,
None
)
#,
time_callback)
return
stats
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment