Commit d7aa51b4 authored by guptapriya's avatar guptapriya Committed by guptapriya
Browse files

fix model by making inputs a dict

parent 95220449
...@@ -225,12 +225,12 @@ def _get_keras_model(params): ...@@ -225,12 +225,12 @@ def _get_keras_model(params):
softmax_logits = MetricLayer(params)([softmax_logits, dup_mask_input]) softmax_logits = MetricLayer(params)([softmax_logits, dup_mask_input])
keras_model = tf.keras.Model( keras_model = tf.keras.Model(
inputs=[ inputs={
user_input, movielens.USER_COLUMN: user_input,
item_input, movielens.ITEM_COLUMN: item_input,
valid_pt_mask_input, rconst.VALID_POINT_MASK: valid_pt_mask_input,
dup_mask_input, rconst.DUPLICATE_MASK: dup_mask_input,
label_input], rconst.TRAIN_LABEL_KEY: label_input},
outputs=softmax_logits) outputs=softmax_logits)
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy( loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment