Commit b8318fd3, authored by Tayo Oguntebi and committed by Taylor Robie
Browse files

Merges TPU-TC optimizations into HEAD. (#5635)

* Merges TPU-TC optimizations into HEAD.

* Split a line that went over 80 from a tab.

* Remove trailing whitespace.
parent 0c0860ed
...@@ -54,13 +54,14 @@ from official.utils.misc import model_helpers ...@@ -54,13 +54,14 @@ from official.utils.misc import model_helpers
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def construct_estimator(num_gpus, model_dir, params, batch_size, def construct_estimator(num_gpus, model_dir, iterations, params, batch_size,
eval_batch_size): eval_batch_size):
"""Construct either an Estimator or TPUEstimator for NCF. """Construct either an Estimator or TPUEstimator for NCF.
Args: Args:
num_gpus: The number of gpus (Used to select distribution strategy) num_gpus: The number of gpus (Used to select distribution strategy)
model_dir: The model directory for the estimator model_dir: The model directory for the estimator
iterations: Estimator iterations
params: The params dict for the estimator params: The params dict for the estimator
batch_size: The mini-batch size for training. batch_size: The mini-batch size for training.
eval_batch_size: The batch size used during evaluation. eval_batch_size: The batch size used during evaluation.
...@@ -79,12 +80,13 @@ def construct_estimator(num_gpus, model_dir, params, batch_size, ...@@ -79,12 +80,13 @@ def construct_estimator(num_gpus, model_dir, params, batch_size,
tf.Session.reset(tpu_cluster_resolver.get_master()) tf.Session.reset(tpu_cluster_resolver.get_master())
tpu_config = tf.contrib.tpu.TPUConfig( tpu_config = tf.contrib.tpu.TPUConfig(
iterations_per_loop=100, iterations_per_loop=iterations,
num_shards=8) num_shards=8)
run_config = tf.contrib.tpu.RunConfig( run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver, cluster=tpu_cluster_resolver,
model_dir=model_dir, model_dir=model_dir,
save_checkpoints_secs=600,
session_config=tf.ConfigProto( session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=False), allow_soft_placement=True, log_device_placement=False),
tpu_config=tpu_config) tpu_config=tpu_config)
...@@ -95,12 +97,13 @@ def construct_estimator(num_gpus, model_dir, params, batch_size, ...@@ -95,12 +97,13 @@ def construct_estimator(num_gpus, model_dir, params, batch_size,
model_fn=neumf_model.neumf_model_fn, model_fn=neumf_model.neumf_model_fn,
use_tpu=True, use_tpu=True,
train_batch_size=batch_size, train_batch_size=batch_size,
eval_batch_size=eval_batch_size,
params=tpu_params, params=tpu_params,
config=run_config) config=run_config)
eval_estimator = tf.contrib.tpu.TPUEstimator( eval_estimator = tf.contrib.tpu.TPUEstimator(
model_fn=neumf_model.neumf_model_fn, model_fn=neumf_model.neumf_model_fn,
use_tpu=False, use_tpu=True,
train_batch_size=1, train_batch_size=1,
eval_batch_size=eval_batch_size, eval_batch_size=eval_batch_size,
params=tpu_params, params=tpu_params,
...@@ -204,7 +207,8 @@ def run_ncf(_): ...@@ -204,7 +207,8 @@ def run_ncf(_):
} }
if FLAGS.use_estimator: if FLAGS.use_estimator:
train_estimator, eval_estimator = construct_estimator( train_estimator, eval_estimator = construct_estimator(
num_gpus=num_gpus, model_dir=FLAGS.model_dir, params=params, num_gpus=num_gpus, model_dir=FLAGS.model_dir,
iterations=num_train_steps, params=params,
batch_size=flags.FLAGS.batch_size, eval_batch_size=eval_batch_size) batch_size=flags.FLAGS.batch_size, eval_batch_size=eval_batch_size)
else: else:
runner = model_runner.NcfModelRunner(ncf_dataset, params) runner = model_runner.NcfModelRunner(ncf_dataset, params)
...@@ -231,7 +235,7 @@ def run_ncf(_): ...@@ -231,7 +235,7 @@ def run_ncf(_):
test_id=FLAGS.benchmark_test_id) test_id=FLAGS.benchmark_test_id)
pred_input_fn = None eval_input_fn = None
total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals
target_reached = False target_reached = False
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.TRAIN_LOOP) mlperf_helper.ncf_print(key=mlperf_helper.TAGS.TRAIN_LOOP)
...@@ -260,8 +264,8 @@ def run_ncf(_): ...@@ -260,8 +264,8 @@ def run_ncf(_):
tf.gfile.DeleteRecursively(train_record_dir) tf.gfile.DeleteRecursively(train_record_dir)
tf.logging.info("Beginning evaluation.") tf.logging.info("Beginning evaluation.")
if pred_input_fn is None: if eval_input_fn is None:
pred_input_fn, _, eval_batch_count = data_preprocessing.make_input_fn( eval_input_fn, _, eval_batch_count = data_preprocessing.make_input_fn(
ncf_dataset=ncf_dataset, is_training=False) ncf_dataset=ncf_dataset, is_training=False)
if eval_batch_count != num_eval_steps: if eval_batch_count != num_eval_steps:
...@@ -272,7 +276,7 @@ def run_ncf(_): ...@@ -272,7 +276,7 @@ def run_ncf(_):
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_START, mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_START,
value=cycle_index) value=cycle_index)
eval_results = eval_estimator.evaluate(pred_input_fn, eval_results = eval_estimator.evaluate(eval_input_fn,
steps=num_eval_steps) steps=num_eval_steps)
tf.logging.info("Evaluation complete.") tf.logging.info("Evaluation complete.")
else: else:
......
...@@ -171,6 +171,9 @@ def construct_model(users, items, params): ...@@ -171,6 +171,9 @@ def construct_model(users, items, params):
Raises: Raises:
ValueError: if the first model layer is not even. ValueError: if the first model layer is not even.
Returns:
logits: network logits
""" """
num_users = params["num_users"] num_users = params["num_users"]
...@@ -193,7 +196,33 @@ def construct_model(users, items, params): ...@@ -193,7 +196,33 @@ def construct_model(users, items, params):
# Input variables # Input variables
user_input = tf.keras.layers.Input(tensor=users) user_input = tf.keras.layers.Input(tensor=users)
item_input = tf.keras.layers.Input(tensor=items) item_input = tf.keras.layers.Input(tensor=items)
batch_size = user_input.get_shape()[0]
if params["use_tpu"]:
with tf.variable_scope("embed_weights", reuse=tf.AUTO_REUSE):
cmb_embedding_user = tf.get_variable(
name="embeddings_mf_user",
shape=[num_users, mf_dim + model_layers[0] // 2],
initializer=tf.glorot_uniform_initializer())
cmb_embedding_item = tf.get_variable(
name="embeddings_mf_item",
shape=[num_items, mf_dim + model_layers[0] // 2],
initializer=tf.glorot_uniform_initializer())
cmb_user_latent = tf.gather(cmb_embedding_user, user_input)
cmb_item_latent = tf.gather(cmb_embedding_item, item_input)
mlp_user_latent = tf.slice(cmb_user_latent, [0, 0],
[batch_size, model_layers[0] // 2])
mlp_item_latent = tf.slice(cmb_item_latent, [0, 0],
[batch_size, model_layers[0] // 2])
mlp_vector = tf.keras.layers.concatenate([mlp_user_latent,
mlp_item_latent])
mf_user_latent = tf.slice(cmb_user_latent, [0, model_layers[0] // 2],
[batch_size, mf_dim])
mf_item_latent = tf.slice(cmb_item_latent, [0, model_layers[0] // 2],
[batch_size, mf_dim])
else:
# Initializer for embedding layers # Initializer for embedding layers
embedding_initializer = "glorot_uniform" embedding_initializer = "glorot_uniform"
...@@ -227,12 +256,14 @@ def construct_model(users, items, params): ...@@ -227,12 +256,14 @@ def construct_model(users, items, params):
# GMF part # GMF part
mf_user_latent = mf_embedding_user(user_input) mf_user_latent = mf_embedding_user(user_input)
mf_item_latent = mf_embedding_item(item_input) mf_item_latent = mf_embedding_item(item_input)
# Element-wise multiply
mf_vector = tf.keras.layers.multiply([mf_user_latent, mf_item_latent])
# MLP part # MLP part
mlp_user_latent = mlp_embedding_user(user_input) mlp_user_latent = mlp_embedding_user(user_input)
mlp_item_latent = mlp_embedding_item(item_input) mlp_item_latent = mlp_embedding_item(item_input)
# Element-wise multiply
mf_vector = tf.keras.layers.multiply([mf_user_latent, mf_item_latent])
# Concatenation of two latent features # Concatenation of two latent features
mlp_vector = tf.keras.layers.concatenate([mlp_user_latent, mlp_item_latent]) mlp_vector = tf.keras.layers.concatenate([mlp_user_latent, mlp_item_latent])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment