Commit b8318fd3 authored by Tayo Oguntebi, committed by Taylor Robie

Merges TPU-TC optimizations into HEAD. (#5635)

* Merges TPU-TC optimizations into HEAD.

* Split a line that a tab had pushed over 80 characters.

* Remove trailing whitespace.
parent 0c0860ed
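
The core model change below fuses the separate MF and MLP embedding tables into one combined table per side (user and item), so each lookup costs a single gather on TPU instead of two. A minimal sketch of that layout, under the same column ordering the diff uses; fused_lookup is an illustrative name, not code from this commit:

import tensorflow as tf

def fused_lookup(table, ids, mlp_half, mf_dim, batch_size):
  # table: [vocab, mlp_half + mf_dim] -- MLP columns first, then MF columns,
  # matching the tf.slice offsets used in construct_model below.
  combined = tf.gather(table, ids)  # one gather per side instead of two
  mlp_latent = tf.slice(combined, [0, 0], [batch_size, mlp_half])
  mf_latent = tf.slice(combined, [0, mlp_half], [batch_size, mf_dim])
  return mlp_latent, mf_latent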
@@ -54,13 +54,14 @@ from official.utils.misc import model_helpers
 FLAGS = flags.FLAGS
 
-def construct_estimator(num_gpus, model_dir, params, batch_size,
+def construct_estimator(num_gpus, model_dir, iterations, params, batch_size,
                         eval_batch_size):
   """Construct either an Estimator or TPUEstimator for NCF.
 
   Args:
     num_gpus: The number of gpus (Used to select distribution strategy)
     model_dir: The model directory for the estimator
+    iterations: Estimator iterations
     params: The params dict for the estimator
     batch_size: The mini-batch size for training.
     eval_batch_size: The batch size used during evaluation.
@@ -79,12 +80,13 @@ def construct_estimator(num_gpus, model_dir, params, batch_size,
     tf.Session.reset(tpu_cluster_resolver.get_master())
 
     tpu_config = tf.contrib.tpu.TPUConfig(
-        iterations_per_loop=100,
+        iterations_per_loop=iterations,
         num_shards=8)
 
     run_config = tf.contrib.tpu.RunConfig(
         cluster=tpu_cluster_resolver,
         model_dir=model_dir,
+        save_checkpoints_secs=600,
         session_config=tf.ConfigProto(
             allow_soft_placement=True, log_device_placement=False),
         tpu_config=tpu_config)
@@ -95,12 +97,13 @@ def construct_estimator(num_gpus, model_dir, params, batch_size,
         model_fn=neumf_model.neumf_model_fn,
         use_tpu=True,
         train_batch_size=batch_size,
         eval_batch_size=eval_batch_size,
         params=tpu_params,
         config=run_config)
 
     eval_estimator = tf.contrib.tpu.TPUEstimator(
         model_fn=neumf_model.neumf_model_fn,
-        use_tpu=False,
+        use_tpu=True,
+        train_batch_size=1,
         eval_batch_size=eval_batch_size,
         params=tpu_params,
@@ -204,7 +207,8 @@ def run_ncf(_):
   }
 
   if FLAGS.use_estimator:
     train_estimator, eval_estimator = construct_estimator(
-        num_gpus=num_gpus, model_dir=FLAGS.model_dir, params=params,
+        num_gpus=num_gpus, model_dir=FLAGS.model_dir,
+        iterations=num_train_steps, params=params,
         batch_size=flags.FLAGS.batch_size, eval_batch_size=eval_batch_size)
   else:
     runner = model_runner.NcfModelRunner(ncf_dataset, params)
@@ -231,7 +235,7 @@ def run_ncf(_):
       test_id=FLAGS.benchmark_test_id)
 
-  pred_input_fn = None
+  eval_input_fn = None
 
   total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals
   target_reached = False
   mlperf_helper.ncf_print(key=mlperf_helper.TAGS.TRAIN_LOOP)
@@ -260,8 +264,8 @@ def run_ncf(_):
       tf.gfile.DeleteRecursively(train_record_dir)
 
     tf.logging.info("Beginning evaluation.")
-    if pred_input_fn is None:
-      pred_input_fn, _, eval_batch_count = data_preprocessing.make_input_fn(
+    if eval_input_fn is None:
+      eval_input_fn, _, eval_batch_count = data_preprocessing.make_input_fn(
           ncf_dataset=ncf_dataset, is_training=False)
 
     if eval_batch_count != num_eval_steps:
@@ -272,7 +276,7 @@ def run_ncf(_):
     mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_START,
                             value=cycle_index)
 
-    eval_results = eval_estimator.evaluate(pred_input_fn,
+    eval_results = eval_estimator.evaluate(eval_input_fn,
                                            steps=num_eval_steps)
     tf.logging.info("Evaluation complete.")
   else:
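
For context on the hunks above: iterations_per_loop is how many training steps the TPU runs before returning control to the host, so threading it through construct_estimator (wired to num_train_steps at the call site) lets a whole training cycle stay on-device. A consolidated sketch of the resulting configuration; make_tpu_run_config is a hypothetical helper, not a function from this diff:

import tensorflow as tf

def make_tpu_run_config(tpu_cluster_resolver, model_dir, iterations):
  # Run `iterations` steps per TPU loop rather than the old hardcoded 100.
  tpu_config = tf.contrib.tpu.TPUConfig(
      iterations_per_loop=iterations,
      num_shards=8)
  return tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=model_dir,
      save_checkpoints_secs=600,  # checkpoint every 10 minutes
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=tpu_config)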
@@ -171,6 +171,9 @@ def construct_model(users, items, params):
 
   Raises:
     ValueError: if the first model layer is not even.
+
+  Returns:
+    logits: network logits
   """
 
   num_users = params["num_users"]
@@ -193,7 +196,33 @@ def construct_model(users, items, params):
   # Input variables
   user_input = tf.keras.layers.Input(tensor=users)
   item_input = tf.keras.layers.Input(tensor=items)
 
+  batch_size = user_input.get_shape()[0]
+
+  if params["use_tpu"]:
+    with tf.variable_scope("embed_weights", reuse=tf.AUTO_REUSE):
+      cmb_embedding_user = tf.get_variable(
+          name="embeddings_mf_user",
+          shape=[num_users, mf_dim + model_layers[0] // 2],
+          initializer=tf.glorot_uniform_initializer())
+      cmb_embedding_item = tf.get_variable(
+          name="embeddings_mf_item",
+          shape=[num_items, mf_dim + model_layers[0] // 2],
+          initializer=tf.glorot_uniform_initializer())
+      cmb_user_latent = tf.gather(cmb_embedding_user, user_input)
+      cmb_item_latent = tf.gather(cmb_embedding_item, item_input)
+      mlp_user_latent = tf.slice(cmb_user_latent, [0, 0],
+                                 [batch_size, model_layers[0] // 2])
+      mlp_item_latent = tf.slice(cmb_item_latent, [0, 0],
+                                 [batch_size, model_layers[0] // 2])
+      mlp_vector = tf.keras.layers.concatenate([mlp_user_latent,
+                                                mlp_item_latent])
+      mf_user_latent = tf.slice(cmb_user_latent, [0, model_layers[0] // 2],
+                                [batch_size, mf_dim])
+      mf_item_latent = tf.slice(cmb_item_latent, [0, model_layers[0] // 2],
+                                [batch_size, mf_dim])
+  else:
     # Initializer for embedding layers
     embedding_initializer = "glorot_uniform"
@@ -227,12 +256,14 @@ def construct_model(users, items, params):
     # GMF part
     mf_user_latent = mf_embedding_user(user_input)
     mf_item_latent = mf_embedding_item(item_input)
 
-    # Element-wise multiply
-    mf_vector = tf.keras.layers.multiply([mf_user_latent, mf_item_latent])
-
     # MLP part
     mlp_user_latent = mlp_embedding_user(user_input)
     mlp_item_latent = mlp_embedding_item(item_input)
 
+  # Element-wise multiply
+  mf_vector = tf.keras.layers.multiply([mf_user_latent, mf_item_latent])
+
   # Concatenation of two latent features
   mlp_vector = tf.keras.layers.concatenate([mlp_user_latent, mlp_item_latent])
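
A quick way to check that the combined-table slicing in the TPU branch is equivalent to column-wise splits of separate embedding tables (TF 1.x session style; all sizes and values below are illustrative):

import numpy as np
import tensorflow as tf

mlp_half, mf_dim, vocab, batch = 16, 8, 100, 4
table = np.random.rand(vocab, mlp_half + mf_dim).astype(np.float32)
ids = np.array([3, 7, 42, 9], dtype=np.int32)

# One gather, then slice out the MLP and MF halves, as in construct_model.
combined = tf.gather(tf.constant(table), tf.constant(ids))
mlp_latent = tf.slice(combined, [0, 0], [batch, mlp_half])
mf_latent = tf.slice(combined, [0, mlp_half], [batch, mf_dim])

with tf.Session() as sess:
  mlp_v, mf_v = sess.run([mlp_latent, mf_latent])

# Matches direct lookups into the two column halves of the table.
assert np.allclose(mlp_v, table[ids, :mlp_half])
assert np.allclose(mf_v, table[ids, mlp_half:])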