Unverified Commit 8a15a4df authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Keras-ify NCF TPU embedding lookup (#5641)

* Keras-ify TPU embedding lookup

* delint

* pull get_variable() out of keras lambda

* delint

* move get_variable under variable scope
parent b8318fd3
...@@ -204,24 +204,34 @@ def construct_model(users, items, params): ...@@ -204,24 +204,34 @@ def construct_model(users, items, params):
name="embeddings_mf_user", name="embeddings_mf_user",
shape=[num_users, mf_dim + model_layers[0] // 2], shape=[num_users, mf_dim + model_layers[0] // 2],
initializer=tf.glorot_uniform_initializer()) initializer=tf.glorot_uniform_initializer())
cmb_embedding_item = tf.get_variable( cmb_embedding_item = tf.get_variable(
name="embeddings_mf_item", name="embeddings_mf_item",
shape=[num_items, mf_dim + model_layers[0] // 2], shape=[num_items, mf_dim + model_layers[0] // 2],
initializer=tf.glorot_uniform_initializer()) initializer=tf.glorot_uniform_initializer())
cmb_user_latent = tf.gather(cmb_embedding_user, user_input) cmb_user_latent = tf.keras.layers.Lambda(lambda ids: tf.gather(
cmb_item_latent = tf.gather(cmb_embedding_item, item_input) cmb_embedding_user, ids))(user_input)
mlp_user_latent = tf.slice(cmb_user_latent, [0, 0], cmb_item_latent = tf.keras.layers.Lambda(lambda ids: tf.gather(
[batch_size, model_layers[0] // 2]) cmb_embedding_item, ids))(item_input)
mlp_item_latent = tf.slice(cmb_item_latent, [0, 0],
[batch_size, model_layers[0] // 2]) mlp_user_latent = tf.keras.layers.Lambda(
mlp_vector = tf.keras.layers.concatenate([mlp_user_latent, lambda x: tf.slice(x, [0, 0], [batch_size, model_layers[0] // 2])
mlp_item_latent]) )(cmb_user_latent)
mf_user_latent = tf.slice(cmb_user_latent, [0, model_layers[0] // 2],
[batch_size, mf_dim]) mlp_item_latent = tf.keras.layers.Lambda(
mf_item_latent = tf.slice(cmb_item_latent, [0, model_layers[0] // 2], lambda x: tf.slice(x, [0, 0], [batch_size, model_layers[0] // 2])
[batch_size, mf_dim]) )(cmb_item_latent)
mf_user_latent = tf.keras.layers.Lambda(
lambda x: tf.slice(x, [0, model_layers[0] // 2], [batch_size, mf_dim])
)(cmb_user_latent)
mf_item_latent = tf.keras.layers.Lambda(
lambda x: tf.slice(x, [0, model_layers[0] // 2], [batch_size, mf_dim])
)(cmb_item_latent)
else: else:
# Initializer for embedding layers # Initializer for embedding layers
embedding_initializer = "glorot_uniform" embedding_initializer = "glorot_uniform"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment