Unverified Commit 2eeb85fe authored by Taylor Robie, committed by GitHub

First pass at a TPU loop for Transformer (#4296)

* port changes from previous branch now that transformer util changes are in master

fix incorrect count

correct (hopefully) treatment of batch_size

set eval_metrics to a dummy function for now

add some comments

start bringing metrics to transformer TPU

resolve logits shape

metrics are now working except for tf.py_func metrics

increase batch_size for tpu, and create summary host call

fix host call

reduce tpu default batch size

further tune batch sizes

add minibatch loss to summary

handle case of single_iteration_train_steps > the number of points in an epoch

begin to incorporate hooks

add sleep workarounds

disable hooks altogether

generalize host call function and move to newly created tpu utils module

remove all traces of params as an object

switch from  to

address some PR comments, and change the number of data points.

minor tweaks

add tpu dry run for testing, and use matmul for TPU embedding

infeed/outfeed queue issue is fixed. Sleeps are no longer necessary

add some documentation.

cleanup and address PR comments

delint

add accelerator __init__

fix embedding

missed PR comment

address PR comments

fix validator bug

rewrite cloud storage validator, and add oauth dependency to requirements.txt

* delint
parent bd56a06d
......@@ -2,4 +2,6 @@ numpy
pandas
psutil>=5.4.3
py-cpuinfo>=3.3.0
google-api-python-client>=1.6.7
google-cloud-bigquery>=0.31.0
oauth2client>=4.1.2
......@@ -19,6 +19,7 @@ The model also applies embeddings on the input and output tokens, and adds a con
* [Model training and evaluation](#model-training-and-evaluation)
* [Translate using the model](#translate-using-the-model)
* [Compute official BLEU score](#compute-official-bleu-score)
* [TPU](#tpu)
* [Implementation overview](#implementation-overview)
* [Model Definition](#model-definition)
* [Model Estimator](#model-estimator)
......@@ -200,6 +201,10 @@ big | 28.9
* `--reference`: Path to file containing reference translations.
* Use the `--help` or `-h` flag to get a full list of possible arguments.
5. ### TPU
TPU support for this version of Transformer is experimental. It is currently included
for demonstration purposes only, and will be optimized in the coming weeks.
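An illustrative invocation might look like the following. (This is a sketch: `$TPU_NAME` and the `gs://` bucket paths are placeholders, and `--data_dir`/`--model_dir` must be Cloud Storage paths when a TPU is specified, per the validator added in this change.)

```
python transformer_main.py \
    --tpu=$TPU_NAME \
    --data_dir=gs://$BUCKET/transformer_data \
    --model_dir=gs://$BUCKET/transformer_model \
    --param_set=base
```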
## Implementation overview
A brief look at each component in the code:
......
......@@ -21,15 +21,31 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import model_utils
from official.utils.accelerator import tpu as tpu_utils
class EmbeddingSharedWeights(tf.layers.Layer):
"""Calculates input embeddings and pre-softmax linear with shared weights."""
def __init__(self, vocab_size, hidden_size):
def __init__(self, vocab_size, hidden_size, method="gather"):
"""Specify characteristic parameters of embedding layer.
Args:
vocab_size: Number of tokens in the embedding. (Typically ~32,000)
hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
method: Strategy for performing embedding lookup. "gather" uses tf.gather,
which performs well on CPUs and GPUs but very poorly on TPUs. "matmul"
one-hot encodes the indices and formulates the embedding as a sparse
matrix multiplication. The matmul formulation is wasteful, as it does
extra work; however, matrix multiplication is very fast on TPUs, which
makes "matmul" considerably faster than "gather" on TPUs.
"""
super(EmbeddingSharedWeights, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
if method not in ("gather", "matmul"):
raise ValueError("method {} must be 'gather' or 'matmul'".format(method))
self.method = method
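For intuition, here is a minimal sketch (toy vocabulary, hypothetical values) showing that the two lookup strategies agree. The real TPU path uses tpu_utils.embedding_matmul, which appears later in this diff and builds the one-hot matrix via broadcast comparison rather than tf.one_hot:

```python
import tensorflow as tf

# Toy sizes: vocab of 6, hidden size of 4; id 0 is the padding token.
table = tf.reshape(tf.range(24, dtype=tf.float32), [6, 4])
ids = tf.constant([[5, 2, 0]])                # [batch=1, length=3]
mask = tf.to_float(tf.not_equal(ids, 0))      # 1.0 = real token, 0.0 = pad

# "gather" strategy: fast on CPUs/GPUs, slow on TPUs.
gathered = tf.gather(table, ids) * tf.expand_dims(mask, -1)

# "matmul" strategy: one-hot encode the ids and contract against the
# table; the mask zeroes the one-hot rows of padding tokens first.
one_hot = tf.one_hot(ids, depth=6) * tf.expand_dims(mask, -1)
matmuled = tf.tensordot(one_hot, table, 1)

with tf.Session() as sess:
    g, m = sess.run([gathered, matmuled])
    assert (g == m).all()  # both are [1, 3, 4] with a zeroed padding row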
def build(self, _):
with tf.variable_scope("embedding_and_softmax", reuse=tf.AUTO_REUSE):
......@@ -53,19 +69,25 @@ class EmbeddingSharedWeights(tf.layers.Layer):
locations of the padding tokens in x.
"""
with tf.name_scope("embedding"):
# Create binary mask of size [batch_size, length]
mask = tf.to_float(tf.not_equal(x, 0))
if self.method == "gather":
embeddings = tf.gather(self.shared_weights, x)
else: # matmul
embeddings = tpu_utils.embedding_matmul(
embedding_table=self.shared_weights,
values=tf.cast(x, dtype=tf.int32),
mask=mask
)
embeddings *= tf.expand_dims(mask, -1)
# Scale embedding by the sqrt of the hidden size
embeddings *= self.hidden_size ** 0.5
# Create binary array of size [batch_size, length]
# where 1 = padding, 0 = not padding
padding = model_utils.get_padding(x)
# Set all padding embedding values to 0
embeddings *= tf.expand_dims(1 - padding, -1)
return embeddings
def linear(self, x):
"""Computes logits by running x through a linear layer.
......
......@@ -24,12 +24,13 @@ import tensorflow as tf
class FeedFowardNetwork(tf.layers.Layer):
"""Fully connected feedforward network."""
def __init__(self, hidden_size, filter_size, relu_dropout, train):
def __init__(self, hidden_size, filter_size, relu_dropout, train, allow_pad):
super(FeedFowardNetwork, self).__init__()
self.hidden_size = hidden_size
self.filter_size = filter_size
self.relu_dropout = relu_dropout
self.train = train
self.allow_pad = allow_pad
self.filter_dense_layer = tf.layers.Dense(
filter_size, use_bias=True, activation=tf.nn.relu, name="filter_layer")
......@@ -42,13 +43,16 @@ class FeedFowardNetwork(tf.layers.Layer):
Args:
x: tensor with shape [batch_size, length, hidden_size]
padding: (optional) If set, the padding values are temporarily removed
from x. The padding values are placed back in the output tensor in the
same locations. shape [batch_size, length]
from x (provided self.allow_pad is set). The padding values are placed
back in the output tensor in the same locations.
shape [batch_size, length]
Returns:
Output of the feedforward network.
tensor with shape [batch_size, length, hidden_size]
"""
padding = None if not self.allow_pad else padding
# Retrieve dynamically known shapes
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
......
......@@ -15,45 +15,64 @@
"""Defines Transformer model parameters."""
class TransformerBaseParams(object):
"""Parameters for the base Transformer model."""
BASE_PARAMS = dict(
# Input params
batch_size = 2048 # Maximum number of tokens per batch of examples.
max_length = 256 # Maximum number of tokens per example.
default_batch_size=2048, # Maximum number of tokens per batch of examples.
default_batch_size_tpu=32768,
max_length=256, # Maximum number of tokens per example.
# Model params
initializer_gain = 1.0 # Used in trainable variable initialization.
vocab_size = 33708 # Number of tokens defined in the vocabulary file.
hidden_size = 512 # Model dimension in the hidden layers.
num_hidden_layers = 6 # Number of layers in the encoder and decoder stacks.
num_heads = 8 # Number of heads to use in multi-headed attention.
filter_size = 2048 # Inner layer dimensionality in the feedforward network.
initializer_gain=1.0, # Used in trainable variable initialization.
vocab_size=33708, # Number of tokens defined in the vocabulary file.
hidden_size=512, # Model dimension in the hidden layers.
num_hidden_layers=6, # Number of layers in the encoder and decoder stacks.
num_heads=8, # Number of heads to use in multi-headed attention.
filter_size=2048, # Inner layer dimension in the feedforward network.
# Dropout values (only used when training)
layer_postprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
layer_postprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,
# Training params
label_smoothing = 0.1
learning_rate = 2.0
learning_rate_decay_rate = 1.0
learning_rate_warmup_steps = 16000
label_smoothing=0.1,
learning_rate=2.0,
learning_rate_decay_rate=1.0,
learning_rate_warmup_steps=16000,
# Optimizer params
optimizer_adam_beta1 = 0.9
optimizer_adam_beta2 = 0.997
optimizer_adam_epsilon = 1e-09
optimizer_adam_beta1=0.9,
optimizer_adam_beta2=0.997,
optimizer_adam_epsilon=1e-09,
# Default prediction params
extra_decode_length = 50
beam_size = 4
alpha = 0.6 # used to calculate length normalization in beam search
extra_decode_length=50,
beam_size=4,
alpha=0.6, # used to calculate length normalization in beam search
# TPU specific parameters
use_tpu=False,
static_batch=False,
allow_ffn_pad=True,
)
class TransformerBigParams(TransformerBaseParams):
"""Parameters for the big Transformer model."""
batch_size = 4096
hidden_size = 1024
filter_size = 4096
num_heads = 16
BIG_PARAMS = dict(BASE_PARAMS)
BIG_PARAMS.update(dict(
default_batch_size=4096,
# The default TPU batch size is smaller than for BASE_PARAMS due to memory limits.
default_batch_size_tpu=16384,
hidden_size=1024,
filter_size=4096,
num_heads=16,
))
TINY_PARAMS = dict(BASE_PARAMS)
TINY_PARAMS.update(dict(
default_batch_size=1024,
default_batch_size_tpu=1024,
hidden_size=32,
num_heads=4,
filter_size=256,
))
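Because the parameter sets are now plain dicts rather than classes, run-specific overrides reduce to ordinary dict operations. A small sketch (copying first so the module-level defaults stay pristine):

```python
from official.transformer.model import model_params

params = dict(model_params.BASE_PARAMS)       # shallow copy of the defaults
params.update(batch_size=4096, use_tpu=True)  # layer on run-specific values
print(params["hidden_size"], params["use_tpu"])  # 512 True
```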
......@@ -57,7 +57,8 @@ class Transformer(object):
self.params = params
self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
params.vocab_size, params.hidden_size)
params["vocab_size"], params["hidden_size"],
method="matmul" if params["tpu"] else "gather")
self.encoder_stack = EncoderStack(params, train)
self.decoder_stack = DecoderStack(params, train)
......@@ -79,7 +80,7 @@ class Transformer(object):
# Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well.
initializer = tf.variance_scaling_initializer(
self.params.initializer_gain, mode="fan_avg", distribution="uniform")
self.params["initializer_gain"], mode="fan_avg", distribution="uniform")
with tf.variable_scope("Transformer", initializer=initializer):
# Calculate attention bias for encoder self-attention and decoder
# multi-headed attention layers.
......@@ -116,12 +117,12 @@ class Transformer(object):
with tf.name_scope("add_pos_encoding"):
length = tf.shape(embedded_inputs)[1]
pos_encoding = model_utils.get_position_encoding(
length, self.params.hidden_size)
length, self.params["hidden_size"])
encoder_inputs = embedded_inputs + pos_encoding
if self.train:
encoder_inputs = tf.nn.dropout(
encoder_inputs, 1 - self.params.layer_postprocess_dropout)
encoder_inputs, 1 - self.params["layer_postprocess_dropout"])
return self.encoder_stack(encoder_inputs, attention_bias, inputs_padding)
......@@ -149,10 +150,10 @@ class Transformer(object):
with tf.name_scope("add_pos_encoding"):
length = tf.shape(decoder_inputs)[1]
decoder_inputs += model_utils.get_position_encoding(
length, self.params.hidden_size)
length, self.params["hidden_size"])
if self.train:
decoder_inputs = tf.nn.dropout(
decoder_inputs, 1 - self.params.layer_postprocess_dropout)
decoder_inputs, 1 - self.params["layer_postprocess_dropout"])
# Run values
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
......@@ -167,7 +168,7 @@ class Transformer(object):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal = model_utils.get_position_encoding(
max_decode_length + 1, self.params.hidden_size)
max_decode_length + 1, self.params["hidden_size"])
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
max_decode_length)
......@@ -206,7 +207,7 @@ class Transformer(object):
"""Return predicted sequence."""
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
max_decode_length = input_length + self.params.extra_decode_length
max_decode_length = input_length + self.params["extra_decode_length"]
symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
......@@ -216,9 +217,9 @@ class Transformer(object):
# Create cache storing decoder attention values for each layer.
cache = {
"layer_%d" % layer: {
"k": tf.zeros([batch_size, 0, self.params.hidden_size]),
"v": tf.zeros([batch_size, 0, self.params.hidden_size]),
} for layer in range(self.params.num_hidden_layers)}
"k": tf.zeros([batch_size, 0, self.params["hidden_size"]]),
"v": tf.zeros([batch_size, 0, self.params["hidden_size"]]),
} for layer in range(self.params["num_hidden_layers"])}
# Add encoder output and attention bias to the cache.
cache["encoder_outputs"] = encoder_outputs
......@@ -229,9 +230,9 @@ class Transformer(object):
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
vocab_size=self.params.vocab_size,
beam_size=self.params.beam_size,
alpha=self.params.alpha,
vocab_size=self.params["vocab_size"],
beam_size=self.params["beam_size"],
alpha=self.params["alpha"],
max_decode_length=max_decode_length,
eos_id=EOS_ID)
......@@ -268,11 +269,11 @@ class PrePostProcessingWrapper(object):
def __init__(self, layer, params, train):
self.layer = layer
self.postprocess_dropout = params.layer_postprocess_dropout
self.postprocess_dropout = params["layer_postprocess_dropout"]
self.train = train
# Create normalization layer
self.layer_norm = LayerNormalization(params.hidden_size)
self.layer_norm = LayerNormalization(params["hidden_size"])
def __call__(self, x, *args, **kwargs):
# Preprocessing: apply layer normalization
......@@ -299,19 +300,21 @@ class EncoderStack(tf.layers.Layer):
def __init__(self, params, train):
super(EncoderStack, self).__init__()
self.layers = []
for _ in range(params.num_hidden_layers):
for _ in range(params["num_hidden_layers"]):
# Create sublayers for each layer.
self_attention_layer = attention_layer.SelfAttention(
params.hidden_size, params.num_heads, params.attention_dropout, train)
params["hidden_size"], params["num_heads"],
params["attention_dropout"], train)
feed_forward_network = ffn_layer.FeedFowardNetwork(
params.hidden_size, params.filter_size, params.relu_dropout, train)
params["hidden_size"], params["filter_size"],
params["relu_dropout"], train, params["allow_ffn_pad"])
self.layers.append([
PrePostProcessingWrapper(self_attention_layer, params, train),
PrePostProcessingWrapper(feed_forward_network, params, train)])
# Create final layer normalization layer.
self.output_normalization = LayerNormalization(params.hidden_size)
self.output_normalization = LayerNormalization(params["hidden_size"])
def call(self, encoder_inputs, attention_bias, inputs_padding):
"""Return the output of the encoder layer stacks.
......@@ -354,20 +357,23 @@ class DecoderStack(tf.layers.Layer):
def __init__(self, params, train):
super(DecoderStack, self).__init__()
self.layers = []
for _ in range(params.num_hidden_layers):
for _ in range(params["num_hidden_layers"]):
self_attention_layer = attention_layer.SelfAttention(
params.hidden_size, params.num_heads, params.attention_dropout, train)
params["hidden_size"], params["num_heads"],
params["attention_dropout"], train)
enc_dec_attention_layer = attention_layer.Attention(
params.hidden_size, params.num_heads, params.attention_dropout, train)
params["hidden_size"], params["num_heads"],
params["attention_dropout"], train)
feed_forward_network = ffn_layer.FeedFowardNetwork(
params.hidden_size, params.filter_size, params.relu_dropout, train)
params["hidden_size"], params["filter_size"],
params["relu_dropout"], train, params["allow_ffn_pad"])
self.layers.append([
PrePostProcessingWrapper(self_attention_layer, params, train),
PrePostProcessingWrapper(enc_dec_attention_layer, params, train),
PrePostProcessingWrapper(feed_forward_network, params, train)])
self.output_normalization = LayerNormalization(params.hidden_size)
self.output_normalization = LayerNormalization(params["hidden_size"])
def call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias,
attention_bias, cache=None):
......
......@@ -22,6 +22,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
import tempfile
......@@ -39,7 +40,9 @@ from official.transformer.model import model_params
from official.transformer.model import transformer
from official.transformer.utils import dataset
from official.transformer.utils import metrics
from official.transformer.utils import schedule
from official.transformer.utils import tokenizer
from official.utils.accelerator import tpu as tpu_util
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
......@@ -47,8 +50,9 @@ from official.utils.misc import model_helpers
PARAMS_MAP = {
"base": model_params.TransformerBaseParams,
"big": model_params.TransformerBigParams,
"tiny": model_params.TINY_PARAMS,
"base": model_params.BASE_PARAMS,
"big": model_params.BIG_PARAMS,
}
DEFAULT_TRAIN_EPOCHS = 10
BLEU_DIR = "bleu"
......@@ -69,37 +73,71 @@ def model_fn(features, labels, mode, params):
# Create model and get output logits.
model = transformer.Transformer(params, mode == tf.estimator.ModeKeys.TRAIN)
output = model(inputs, targets)
logits = model(inputs, targets)
# When in prediction mode, labels/targets is None, and the model output
# is the prediction.
if mode == tf.estimator.ModeKeys.PREDICT:
if params["use_tpu"]:
raise NotImplementedError("Prediction is not yet supported on TPUs.")
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT,
predictions=output)
predictions=logits)
logits = output
# Explicitly set the shape of the logits for XLA (TPU). This is needed
# because the logits are passed back to the host VM CPU for metric
# evaluation, and the shape of [?, ?, vocab_size] is too vague. However,
# it is known from the Transformer model that the first two dimensions of
# logits are the dimensions of targets. Note that the ambiguous shape of
# logits is not a problem when computing xentropy, because
# padded_cross_entropy_loss resolves the shape on the TPU.
logits.set_shape(targets.shape.as_list() + logits.shape.as_list()[2:])
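A minimal sketch of this shape pinning, with hypothetical batch and length values:

```python
import tensorflow as tf

targets = tf.zeros([16, 64], dtype=tf.int64)              # static [batch, length]
logits = tf.placeholder(tf.float32, [None, None, 33708])  # ambiguous shape

# Merge the statically known [batch, length] of targets into logits so
# XLA and the host-side metric_fn see a fully defined shape.
logits.set_shape(targets.shape.as_list() + logits.shape.as_list()[2:])
print(logits.shape)  # (16, 64, 33708)
```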
# Calculate model loss.
# xentropy contains the cross entropy loss of every nonpadding token in the
# targets.
xentropy, weights = metrics.padded_cross_entropy_loss(
logits, targets, params.label_smoothing, params.vocab_size)
# Compute the weighted mean of the cross entropy losses
logits, targets, params["label_smoothing"], params["vocab_size"])
loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
# Save loss as named tensor that will be logged with the logging hook.
tf.identity(loss, "cross_entropy")
if mode == tf.estimator.ModeKeys.EVAL:
if params["use_tpu"]:
# Host call functions should only have tensors as arguments.
# functools.partial() pre-populates params so that metric_fn is
# TPUEstimator compliant.
metric_fn = functools.partial(metrics.get_eval_metrics, params=params)
eval_metrics = (metric_fn, [logits, labels])
return tf.contrib.tpu.TPUEstimatorSpec(
mode=mode, loss=loss, predictions={"predictions": logits},
eval_metrics=eval_metrics)
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, predictions={"predictions": logits},
eval_metric_ops=metrics.get_eval_metrics(logits, labels, params))
else:
train_op = get_train_op(loss, params)
train_op, metric_dict = get_train_op_and_metrics(loss, params)
# Epochs can be quite long. This gives some intermediate information
# in TensorBoard.
metric_dict["minibatch_loss"] = loss
if params["use_tpu"]:
return tf.contrib.tpu.TPUEstimatorSpec(
mode=mode, loss=loss, train_op=train_op,
host_call=tpu_util.construct_scalar_host_call(
metric_dict=metric_dict, model_dir=params["model_dir"],
prefix="training/")
)
record_scalars(metric_dict)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def record_scalars(metric_dict):
for key, value in metric_dict.items():
tf.contrib.summary.scalar(name=key, tensor=value)
def get_learning_rate(learning_rate, hidden_size, learning_rate_warmup_steps):
"""Calculate learning rate with linear warmup and rsqrt decay."""
with tf.name_scope("learning_rate"):
......@@ -116,26 +154,28 @@ def get_learning_rate(learning_rate, hidden_size, learning_rate_warmup_steps):
# The full name includes variable and names scope. In this case, the name
# is model/get_train_op/learning_rate/learning_rate
tf.identity(learning_rate, "learning_rate")
# Save learning rate value to TensorBoard summary.
tf.summary.scalar("learning_rate", learning_rate)
return learning_rate
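The body of get_learning_rate is elided in this hunk; assuming it implements the standard Transformer schedule implied by the docstring and the defaults in model_params.py, a plain-Python sketch of the formula is:

```python
def learning_rate_at(step, learning_rate=2.0, hidden_size=512,
                     warmup_steps=16000):
    """Sketch of the linear-warmup, rsqrt-decay schedule (assumed form)."""
    rate = learning_rate * hidden_size ** -0.5
    rate *= min(1.0, float(step) / warmup_steps)  # linear warmup
    rate *= max(step, warmup_steps) ** -0.5       # rsqrt decay after warmup
    return rate

print(learning_rate_at(8000))   # mid-warmup: ~0.00035
print(learning_rate_at(16000))  # peak: 2.0 / (512**0.5 * 16000**0.5) ~ 0.0007
```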
def get_train_op(loss, params):
"""Generate training operation that updates variables based on loss."""
def get_train_op_and_metrics(loss, params):
"""Generate training op and metrics to save in TensorBoard."""
with tf.variable_scope("get_train_op"):
learning_rate = get_learning_rate(
params.learning_rate, params.hidden_size,
params.learning_rate_warmup_steps)
learning_rate=params["learning_rate"],
hidden_size=params["hidden_size"],
learning_rate_warmup_steps=params["learning_rate_warmup_steps"])
# Create optimizer. Use LazyAdamOptimizer from TF contrib, which is faster
# than the TF core Adam optimizer.
optimizer = tf.contrib.opt.LazyAdamOptimizer(
learning_rate,
beta1=params.optimizer_adam_beta1,
beta2=params.optimizer_adam_beta2,
epsilon=params.optimizer_adam_epsilon)
beta1=params["optimizer_adam_beta1"],
beta2=params["optimizer_adam_beta2"],
epsilon=params["optimizer_adam_epsilon"])
if params["use_tpu"] and params["tpu"] != tpu_util.LOCAL:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
# Calculate and apply gradients using LazyAdamOptimizer.
global_step = tf.train.get_global_step()
......@@ -145,11 +185,15 @@ def get_train_op(loss, params):
train_op = optimizer.apply_gradients(
gradients, global_step=global_step, name="train")
# Save gradient norm to Tensorboard
tf.summary.scalar("global_norm/gradient_norm",
tf.global_norm(list(zip(*gradients))[0]))
metrics = {"learning_rate": learning_rate}
if not params["use_tpu"]:
# gradient norm is not included as a summary when running on TPU, as
# it can cause instability between the TPU and the host controller.
gradient_norm = tf.global_norm(list(zip(*gradients))[0])
metrics["global_norm/gradient_norm"] = gradient_norm
return train_op
return train_op, metrics
def translate_and_compute_bleu(estimator, subtokenizer, bleu_source, bleu_ref):
......@@ -186,9 +230,8 @@ def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file_path):
return uncased_score, cased_score
def train_schedule(
estimator, train_eval_iterations, single_iteration_train_steps=None,
single_iteration_train_epochs=None, train_hooks=None, benchmark_logger=None,
def run_loop(
estimator, schedule_manager, train_hooks=None, benchmark_logger=None,
bleu_source=None, bleu_ref=None, bleu_threshold=None, vocab_file_path=None):
"""Train and evaluate model, and optionally compute model's BLEU score.
......@@ -215,41 +258,25 @@ def train_schedule(
Args:
estimator: tf.Estimator containing model to train.
train_eval_iterations: Number of times to repeat the train+eval iteration.
single_iteration_train_steps: Number of steps to train in one iteration.
single_iteration_train_epochs: Number of epochs to train in one iteration.
schedule_manager: A schedule.Manager object to guide the run loop.
train_hooks: List of hooks to pass to the estimator during training.
benchmark_logger: a BenchmarkLogger object that logs evaluation data
bleu_source: File containing text to be translated for BLEU calculation.
bleu_ref: File containing reference translations for BLEU calculation.
bleu_threshold: minimum BLEU score before training is stopped.
vocab_file_path: Path to vocabulary file used to subtokenize bleu_source.
Raises:
ValueError: if both or none of single_iteration_train_steps and
single_iteration_train_epochs were defined.
"""
# Ensure that exactly one of single_iteration_train_steps and
# single_iteration_train_epochs is defined.
if single_iteration_train_steps is None:
if single_iteration_train_epochs is None:
raise ValueError(
"Exactly one of single_iteration_train_steps or "
"single_iteration_train_epochs must be defined. Both were none.")
else:
if single_iteration_train_epochs is not None:
raise ValueError(
"Exactly one of single_iteration_train_steps or "
"single_iteration_train_epochs must be defined. Both were defined.")
evaluate_bleu = bleu_source is not None and bleu_ref is not None
if evaluate_bleu and schedule_manager.use_tpu:
raise ValueError("BLEU score can not be computed when training with a TPU, "
"as it requires estimator.predict which is not yet "
"supported.")
# Print details of training schedule.
tf.logging.info("Training schedule:")
if single_iteration_train_epochs is not None:
tf.logging.info("\t1. Train for %d epochs." % single_iteration_train_epochs)
else:
tf.logging.info("\t1. Train for %d steps." % single_iteration_train_steps)
tf.logging.info(
"\t1. Train for {}".format(schedule_manager.train_increment_str))
tf.logging.info("\t2. Evaluate model.")
if evaluate_bleu:
tf.logging.info("\t3. Compute BLEU score.")
......@@ -257,7 +284,8 @@ def train_schedule(
tf.logging.info("Repeat above steps until the BLEU score reaches %f" %
bleu_threshold)
if not evaluate_bleu or bleu_threshold is None:
tf.logging.info("Repeat above steps %d times." % train_eval_iterations)
tf.logging.info("Repeat above steps %d times." %
schedule_manager.train_eval_iterations)
if evaluate_bleu:
# Create summary writer to log bleu score (values can be displayed in
......@@ -266,21 +294,25 @@ def train_schedule(
os.path.join(estimator.model_dir, BLEU_DIR))
if bleu_threshold is not None:
# Change loop stopping condition if bleu_threshold is defined.
train_eval_iterations = INF
schedule_manager.train_eval_iterations = INF
# Loop training/evaluation/bleu cycles
for i in xrange(train_eval_iterations):
for i in xrange(schedule_manager.train_eval_iterations):
tf.logging.info("Starting iteration %d" % (i + 1))
# Train the model for single_iteration_train_steps or until the input fn
# runs out of examples (if single_iteration_train_steps is None).
estimator.train(
dataset.train_input_fn, steps=single_iteration_train_steps,
dataset.train_input_fn,
steps=schedule_manager.single_iteration_train_steps,
hooks=train_hooks)
eval_results = estimator.evaluate(dataset.eval_input_fn)
eval_results = estimator.evaluate(
input_fn=dataset.eval_input_fn,
steps=schedule_manager.single_iteration_eval_steps)
tf.logging.info("Evaluation results (iter %d/%d):" %
(i + 1, train_eval_iterations))
(i + 1, schedule_manager.train_eval_iterations))
tf.logging.info(eval_results)
benchmark_logger.log_evaluation_result(eval_results)
......@@ -325,6 +357,7 @@ def define_transformer_flags():
dtype=False
)
flags_core.define_benchmark()
flags_core.define_device(tpu=True)
# Set flags from the flags_core module as "key flags" so they're listed when
# the '-h' flag is used. Without this line, the flags defined above are
......@@ -334,7 +367,7 @@ def define_transformer_flags():
# Add transformer-specific flags
flags.DEFINE_enum(
name="param_set", short_name="mp", default="big",
enum_values=["base", "big"],
enum_values=["base", "big", "tiny"],
help=flags_core.help_wrap(
"Parameter set to use when creating and training the model. The "
"parameters define the input shape (batch size and max length), "
......@@ -343,6 +376,16 @@ def define_transformer_flags():
"default batch size, embedding/hidden size, and filter size. For a "
"complete list of parameters, please see model/model_params.py."))
flags.DEFINE_bool(
name="static_batch", default=False,
help=flags_core.help_wrap(
"Whether the batches in the dataset should have static shapes. In "
"general, this setting should be False. Dynamic shapes allow the "
"inputs to be grouped so that the number of padding tokens is "
"minimized, and helps model training. In cases where the input shape "
"must be static (e.g. running on TPU), this setting will be ignored "
"and static batching will always be used."))
# Flags for training with steps (may be used for debugging)
flags.DEFINE_integer(
name="train_steps", short_name="ts", default=None,
......@@ -403,6 +446,51 @@ def define_transformer_flags():
tf.gfile.Exists(flags_dict["bleu_ref"]),
tf.gfile.Exists(vocab_file_path)])
flags_core.require_cloud_storage(["data_dir", "model_dir"])
def construct_estimator(flags_obj, params, schedule_manager):
"""Construct an estimator from either Estimator or TPUEstimator.
Args:
flags_obj: The FLAGS object parsed from command line.
params: A dict of run specific parameters.
schedule_manager: A schedule.Manager object containing the run schedule.
Returns:
An estimator object to be used for training and eval.
"""
if not params["use_tpu"]:
return tf.estimator.Estimator(
model_fn=model_fn, model_dir=flags_obj.model_dir, params=params)
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
tpu=flags_obj.tpu,
zone=flags_obj.tpu_zone,
project=flags_obj.tpu_gcp_project
)
tpu_config = tf.contrib.tpu.TPUConfig(
iterations_per_loop=schedule_manager.single_iteration_train_steps,
num_shards=flags_obj.num_tpu_shards)
run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=flags_obj.model_dir,
session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True),
tpu_config=tpu_config)
return tf.contrib.tpu.TPUEstimator(
model_fn=model_fn,
use_tpu=params["use_tpu"] and flags_obj.tpu != tpu_util.LOCAL,
train_batch_size=schedule_manager.batch_size,
eval_batch_size=schedule_manager.batch_size,
params={
# TPUEstimator needs to populate batch_size itself due to sharding.
key: value for key, value in params.items() if key != "batch_size"},
config=run_config)
def run_transformer(flags_obj):
"""Create tf.Estimator to train and evaluate transformer model.
......@@ -410,49 +498,54 @@ def run_transformer(flags_obj):
Args:
flags_obj: Object containing parsed flag values.
"""
# Determine training schedule based on flags.
if flags_obj.train_steps is not None:
train_eval_iterations = (
flags_obj.train_steps // flags_obj.steps_between_evals)
single_iteration_train_steps = flags_obj.steps_between_evals
single_iteration_train_epochs = None
else:
train_epochs = flags_obj.train_epochs or DEFAULT_TRAIN_EPOCHS
train_eval_iterations = train_epochs // flags_obj.epochs_between_evals
single_iteration_train_steps = None
single_iteration_train_epochs = flags_obj.epochs_between_evals
# Add flag-defined parameters to params object
params = PARAMS_MAP[flags_obj.param_set]
params.data_dir = flags_obj.data_dir
params.num_parallel_calls = flags_obj.num_parallel_calls
params.epochs_between_evals = flags_obj.epochs_between_evals
params.repeat_dataset = single_iteration_train_epochs
params.batch_size = flags_obj.batch_size or params.batch_size
params["data_dir"] = flags_obj.data_dir
params["model_dir"] = flags_obj.model_dir
params["num_parallel_calls"] = flags_obj.num_parallel_calls
params["tpu"] = flags_obj.tpu
params["use_tpu"] = bool(flags_obj.tpu) # was a tpu specified.
params["batch_size"] = flags_obj.batch_size or (
params["default_batch_size_tpu"] if params["use_tpu"]
else params["default_batch_size"])
params["static_batch"] = flags_obj.static_batch or params["use_tpu"]
params["allow_ffn_pad"] = not params["use_tpu"]
schedule_manager = schedule.Manager(
train_steps=flags_obj.train_steps,
steps_between_evals=flags_obj.steps_between_evals,
train_epochs=flags_obj.train_epochs,
epochs_between_evals=flags_obj.epochs_between_evals,
default_train_epochs=DEFAULT_TRAIN_EPOCHS,
batch_size=params["batch_size"],
max_length=params["max_length"],
use_tpu=params["use_tpu"],
num_tpu_shards=flags_obj.num_tpu_shards
)
params["repeat_dataset"] = schedule_manager.repeat_dataset
# Create hooks that log information about the training and metric values
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
tensors_to_log=TENSORS_TO_LOG, # used for logging hooks
batch_size=params.batch_size # for ExamplesPerSecondHook
batch_size=schedule_manager.batch_size, # for ExamplesPerSecondHook
use_tpu=params["use_tpu"] # Not all hooks can run with TPUs
)
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info(
model_name="transformer",
dataset_name="wmt_translate_ende",
run_params=params.__dict__,
run_params=params,
test_id=flags_obj.benchmark_test_id)
# Train and evaluate transformer model
estimator = tf.estimator.Estimator(
model_fn=model_fn, model_dir=flags_obj.model_dir, params=params)
train_schedule(
estimator = construct_estimator(flags_obj, params, schedule_manager)
run_loop(
estimator=estimator,
# Training arguments
train_eval_iterations=train_eval_iterations,
single_iteration_train_steps=single_iteration_train_steps,
single_iteration_train_epochs=single_iteration_train_epochs,
schedule_manager=schedule_manager,
train_hooks=train_hooks,
benchmark_logger=benchmark_logger,
# BLEU calculation arguments
......
......@@ -134,8 +134,8 @@ def translate_file(
"file.")
tf.logging.info("Writing to file %s" % output_file)
with tf.gfile.Open(output_file, "w") as f:
for index in xrange(len(sorted_keys)):
f.write("%s\n" % translations[sorted_keys[index]])
for index, key in enumerate(sorted_keys):
f.write("%s\n" % translations[key])
def translate_text(estimator, subtokenizer, txt):
......@@ -168,10 +168,10 @@ def main(unused_argv):
# Set up estimator and params
params = transformer_main.PARAMS_MAP[FLAGS.param_set]
params.beam_size = _BEAM_SIZE
params.alpha = _ALPHA
params.extra_decode_length = _EXTRA_DECODE_LENGTH
params.batch_size = _DECODE_BATCH_SIZE
params["beam_size"] = _BEAM_SIZE
params["alpha"] = _ALPHA
params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
params["batch_size"] = _DECODE_BATCH_SIZE
estimator = tf.estimator.Estimator(
model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
params=params)
......
......@@ -190,7 +190,8 @@ def _batch_examples(dataset, batch_size, max_length):
def _read_and_batch_from_files(
file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat):
file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat,
static_batch=False):
"""Create dataset where each item is a dict of "inputs" and "targets".
Args:
......@@ -201,6 +202,17 @@ def _read_and_batch_from_files(
shuffle: If true, randomizes order of elements.
repeat: Number of times to repeat the dataset. If None, the dataset is
repeated forever.
static_batch: Whether the batches in the dataset should have static shapes.
If True, the input is batched so that every batch has the
shape [batch_size // max_length, max_length]. If False, the input is
grouped by length, and batched so that batches may have different
shapes [N, M], where:
N * M <= batch_size
M <= max_length
In general, this setting should be False. Dynamic shapes allow the inputs
to be grouped so that the number of padding tokens is minimized, which helps
model training. In cases where the input shape must be static
(e.g. running on TPU), this setting should be set to True.
Returns:
tf.data.Dataset object containing examples loaded from the files.
......@@ -225,8 +237,13 @@ def _read_and_batch_from_files(
# Remove examples where the input or target length exceeds the maximum length.
dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
# Batch such that each batch has examples of similar length.
if static_batch:
dataset = dataset.apply(tf.contrib.data.padded_batch_and_drop_remainder(
batch_size // max_length, ([max_length], [max_length])))
else:
# Group and batch such that each batch has examples of similar length.
dataset = _batch_examples(dataset, batch_size, max_length)
dataset = dataset.repeat(repeat)
# Prefetch the next element to improve speed of input pipeline.
......@@ -236,15 +253,17 @@ def _read_and_batch_from_files(
def train_input_fn(params):
"""Load and return dataset of batched examples for use during training."""
file_pattern = os.path.join(getattr(params, "data_dir", ""), "*train*")
file_pattern = os.path.join(params.get("data_dir", ""), "*train*")
return _read_and_batch_from_files(
file_pattern, params.batch_size, params.max_length,
params.num_parallel_calls, shuffle=True, repeat=params.repeat_dataset)
file_pattern, params["batch_size"], params["max_length"],
params["num_parallel_calls"], shuffle=True,
repeat=params["repeat_dataset"], static_batch=params["static_batch"])
def eval_input_fn(params):
"""Load and return dataset of batched examples for use during evaluation."""
file_pattern = os.path.join(getattr(params, "data_dir", ""), "*dev*")
file_pattern = os.path.join(params.get("data_dir", ""), "*dev*")
return _read_and_batch_from_files(
file_pattern, params.batch_size, params.max_length,
params.num_parallel_calls, shuffle=False, repeat=1)
file_pattern, params["batch_size"], params["max_length"],
params["num_parallel_calls"], shuffle=False, repeat=1,
static_batch=params["static_batch"])
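To make the static_batch behavior concrete, here is a sketch with toy sizes showing that padded_batch_and_drop_remainder yields a fully static batch shape, which is what the TPU path requires:

```python
import tensorflow as tf

batch_size = 32768  # default_batch_size_tpu, measured in tokens
max_length = 256
print(batch_size // max_length)  # => 128 examples per static batch

# Toy dataset of variable-length examples: pad to a fixed length and drop
# the remainder so every batch has the static shape [2, 3].
toy = tf.data.Dataset.from_generator(
    lambda: iter([[1, 2, 3], [4, 5], [6]]),
    tf.int64, tf.TensorShape([None]))
toy = toy.apply(tf.contrib.data.padded_batch_and_drop_remainder(2, [3]))
batch = toy.make_one_shot_iterator().get_next()
print(batch.shape)  # (2, 3): static, TPU-friendly
```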
......@@ -118,12 +118,20 @@ def get_eval_metrics(logits, labels, params):
"accuracy_per_sequence": _convert_to_eval_metric(
padded_sequence_accuracy)(logits, labels),
"neg_log_perplexity": _convert_to_eval_metric(padded_neg_log_perplexity)(
logits, labels, params.vocab_size),
"approx_bleu_score": _convert_to_eval_metric(bleu_score)(logits, labels),
"rouge_2_fscore": _convert_to_eval_metric(rouge_2_fscore)(logits, labels),
"rouge_L_fscore": _convert_to_eval_metric(rouge_l_fscore)(logits, labels),
logits, labels, params["vocab_size"]),
}
if not params["use_tpu"]:
# TPU does not support tf.py_func
metrics.update({
"approx_bleu_score": _convert_to_eval_metric(
bleu_score)(logits, labels),
"rouge_2_fscore": _convert_to_eval_metric(
rouge_2_fscore)(logits, labels),
"rouge_L_fscore": _convert_to_eval_metric(
rouge_l_fscore)(logits, labels),
})
# Prefix each of the metric names with "metrics/". This allows the metric
# graphs to display under the "metrics" category in TensorBoard.
metrics = {"metrics/%s" % k: v for k, v in six.iteritems(metrics)}
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstract training on a step or epoch basis."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import time
import tensorflow as tf
from official.transformer.utils import dataset
_TRAIN, _EVAL = tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL
NUM_EXAMPLES = {
tf.estimator.ModeKeys.TRAIN: 4572160,
# # Examples that are too long are filtered out, thus the total is less
# # than the total number of lines.
# 2399123 + # news-commentary-v12.de-en
# 1920209 + # commoncrawl.de-en
# 270769, # europarl-v7.de-en
tf.estimator.ModeKeys.EVAL: 3000, # newstest2013
}
class Manager(object):
"""Container for convenience functions to abstract step or epoch basis.
Transformer allows users to specify an epoch basis (generally recommended for
full training) or a number of steps basis (convenient since epochs are rather
large). TPUs furthermore require a step basis; however, epochs are the norm in
the machine learning community, and it is desirable to allow users to specify
epochs even when running with TPUs, which requires behind-the-scenes
conversions.
This container simply groups what are largely mundane checks and conversions
rather than interspersing them throughout the run loop code.
"""
def __init__(self, train_steps, steps_between_evals, train_epochs,
epochs_between_evals, default_train_epochs, batch_size,
max_length, use_tpu=False, num_tpu_shards=8):
if train_steps and train_epochs:
raise ValueError("Both train_steps or train_epochs were be defined.")
# Determine training schedule based on flags.
if train_steps:
self.train_eval_iterations = train_steps // steps_between_evals
self._single_iteration_train_steps = steps_between_evals
self._single_iteration_train_epochs = None
else:
train_epochs = train_epochs or default_train_epochs
self.train_eval_iterations = train_epochs // epochs_between_evals
self._single_iteration_train_steps = None
self._single_iteration_train_epochs = epochs_between_evals
self.max_length = max_length
self.batch_size = batch_size
self.use_tpu = use_tpu
self.num_tpu_shards = num_tpu_shards
if self.use_tpu:
assert (self.batch_size // self.max_length) % self.num_tpu_shards == 0
@property
def single_iteration_train_steps(self):
if self._single_iteration_train_steps or not self.use_tpu:
return self._single_iteration_train_steps
return self.epochs_to_steps(
num_epochs=self._single_iteration_train_epochs, mode=_TRAIN)
@property
def single_iteration_eval_steps(self):
if not self.use_tpu:
return None
return self.epochs_to_steps(num_epochs=1, mode=_EVAL)
@property
def train_increment_str(self):
if self._single_iteration_train_steps:
return "{} steps.".format(self._single_iteration_train_steps)
if not self.use_tpu:
return "{} epochs.".format(self._single_iteration_train_epochs)
return "~{} epochs. ({} steps)".format(
self._single_iteration_train_epochs,
self.single_iteration_train_steps)
@property
def repeat_dataset(self):
if (self._single_iteration_train_epochs is None and
self._single_iteration_train_steps > NUM_EXAMPLES[_TRAIN]):
return math.ceil(self._single_iteration_train_steps /
NUM_EXAMPLES[_TRAIN])
return self._single_iteration_train_epochs
def epochs_to_steps(self, num_epochs, mode):
"""Converts a number of epochs to a number of training steps.
TPU only: This function assumes that static_batch is True.
TPUs cannot tolerate an OutOfRange error from a dataset. As a result, the
number of examples to be processed must be known ahead of time. TPUs also
do not allow partial batches, so this function rounds down.
Args:
num_epochs: An integer number of epochs to convert to steps.
mode: The estimator ModeKeys value of the computation. The batch size and
max length stored on the manager are used for the conversion.
Returns:
An integer of the number of equivalent steps rounded down.
"""
assert self.use_tpu, "epochs_to_steps should only be reached when using TPU"
total_num_tokens = NUM_EXAMPLES[mode] * self.max_length * num_epochs
return total_num_tokens // self.batch_size
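A worked example of this conversion using the TPU defaults (one training epoch of the WMT data, with batch_size measured in tokens):

```python
# Worked example of epochs_to_steps with the TPU defaults.
num_examples = 4572160   # NUM_EXAMPLES[_TRAIN]
max_length = 256         # every example is padded to this length
batch_size = 32768       # default_batch_size_tpu, in tokens

total_num_tokens = num_examples * max_length      # 1,170,472,960 tokens
steps_per_epoch = total_num_tokens // batch_size
print(steps_per_epoch)  # 35720 steps, i.e. 128 padded examples per step
```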
def _sleep_if_tpu(self):
"""Sleep for a minute if TPUs are used.
There is currently an issue with TPUs where starting a train or evaluation
before all of the TPU queues have cleared causes the TPU to freeze. This
is a temporary workaround until the issue can be properly resolved.
"""
if self.use_tpu:
tf.logging.info("Sleeping to allow TPU queues to clear.")
time.sleep(60)
def post_train(self):
self._sleep_if_tpu()
def post_eval(self):
self._sleep_if_tpu()
......@@ -203,7 +203,7 @@ def _load_vocab_file(vocab_file, reserved_tokens=None):
def _native_to_unicode(s):
"""Convert string to unicode (required in Python 2)."""
if six.PY2:
return s if isinstance(s, unicode) else s.decode("utf-8")
return s if isinstance(s, unicode) else s.decode("utf-8") # pylint: disable=undefined-variable
else:
return s
......@@ -211,7 +211,7 @@ def _native_to_unicode(s):
def _unicode_to_native(s):
"""Convert string from unicode to native format (required in Python 2)."""
if six.PY2:
return s.encode("utf-8") if isinstance(s, unicode) else s
return s.encode("utf-8") if isinstance(s, unicode) else s # pylint: disable=undefined-variable
else:
return s
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions specific to running TensorFlow on TPUs."""
import time
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
# "local" is a magic word in the TPU cluster resolver; it informs the resolver
# to use the local CPU as the compute device. This is useful for testing and
# debugging; the code flow is ostensibly identical, but without the need to
# actually have a TPU on the other end.
LOCAL = "local"
def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
"""Construct a host call to log scalars when training on TPU.
Args:
metric_dict: A dict of the tensors to be logged.
model_dir: The location to write the summary.
prefix: The prefix (if any) to prepend to the metric names.
Returns:
A tuple of (function, args_to_be_passed_to_said_function)
"""
# type: (dict, str) -> (function, list)
metric_names = list(metric_dict.keys())
def host_call_fn(global_step, *args):
"""Training host call. Creates scalar summaries for training metrics.
This function is executed on the CPU and should not directly reference
any Tensors in the rest of the `model_fn`. To pass Tensors from the
model to the `metric_fn`, provide as part of the `host_call`. See
https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `host_call`.
Args:
global_step: `Tensor` with shape `[batch]` for the global_step.
*args: Remaining tensors to log.
Returns:
List of summary ops to run on the CPU host.
"""
step = global_step[0]
with tf.contrib.summary.create_file_writer(
logdir=model_dir, filename_suffix=".host_call").as_default():
with tf.contrib.summary.always_record_summaries():
for i, name in enumerate(metric_names):
tf.contrib.summary.scalar(prefix + name, args[i][0], step=step)
return tf.contrib.summary.all_summary_ops()
# To log the current learning rate, and gradient norm for Tensorboard, the
# summary op needs to be run on the host CPU via host_call. host_call
# expects [batch_size, ...] Tensors, thus reshape to introduce a batch
# dimension. These Tensors are implicitly concatenated to
# [params['batch_size']].
global_step_tensor = tf.reshape(tf.train.get_or_create_global_step(), [1])
other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]
return host_call_fn, [global_step_tensor] + other_tensors
def embedding_matmul(embedding_table, values, mask, name='embedding_matmul'):
"""Performs embedding lookup via a matmul.
The matrix to be multiplied by the embedding table Tensor is constructed
via an implementation of scatter based on broadcasting embedding indices
and performing an equality comparison against a broadcasted
range(num_embedding_table_rows).
Args:
embedding_table: Tensor of embedding table.
Rank 2 (table_size x embedding dim)
values: Tensor of embedding indices. Rank 2 (batch x n_indices)
mask: Tensor of mask / weights. Rank 2 (batch x n_indices)
name: Optional name scope for created ops
Returns:
Rank 3 tensor of embedding vectors.
"""
with ops.name_scope(name):
n_embeddings, embedding_dim = embedding_table.get_shape().as_list()
batch_size, padded_size = values.shape.as_list()
emb_idcs = array_ops.tile(
array_ops.reshape(values, (batch_size, padded_size, 1)), (1, 1,
n_embeddings))
emb_weights = array_ops.tile(
array_ops.reshape(mask, (batch_size, padded_size, 1)),
(1, 1, n_embeddings))
col_idcs = array_ops.tile(
array_ops.reshape(math_ops.range(n_embeddings), (1, 1, n_embeddings)),
(batch_size, padded_size, 1))
one_hot = array_ops.where(
math_ops.equal(emb_idcs, col_idcs), emb_weights,
array_ops.zeros((batch_size, padded_size, n_embeddings)))
return math_ops.tensordot(one_hot, embedding_table, 1)
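A usage sketch of embedding_matmul with toy shapes (this assumes the function above is in scope; note that values must have a fully static shape, as on TPU):

```python
import tensorflow as tf

table = tf.reshape(tf.range(20, dtype=tf.float32), [5, 4])    # [vocab, hidden]
values = tf.constant([[2, 4, 0], [1, 0, 0]], dtype=tf.int32)  # [batch, length]
mask = tf.to_float(tf.not_equal(values, 0))  # zero out the padding id 0

embedded = embedding_matmul(table, values, mask)

with tf.Session() as sess:
    print(sess.run(embedded).shape)  # (2, 3, 4); padding rows are all zeros
```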
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags for managing compute devices. Currently only contains TPU flags."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
from official.utils.flags._conventions import help_wrap
def require_cloud_storage(flag_names):
"""Register a validator to check directory flags.
Args:
flag_names: An iterable of strings containing the names of flags to be
checked.
"""
msg = "TPU requires GCS path for {}".format(", ".join(flag_names))
@flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
def _path_check(flag_values): # pylint: disable=missing-docstring
if flag_values["tpu"] is None:
return True
valid_flags = True
for key in flag_names:
if not flag_values[key].startswith("gs://"):
tf.logging.error("{} must be a GCS path.".format(key))
valid_flags = False
return valid_flags
def define_device(tpu=True):
"""Register device specific flags.
Args:
tpu: If True, create flags to specify TPU operation.
Returns:
A list of flags for core.py to mark as key flags.
"""
key_flags = []
if tpu:
flags.DEFINE_string(
name="tpu", default=None,
help=help_wrap(
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a "
"grpc://ip.address.of.tpu:8470 url. Passing `local` will use the"
"CPU of the local instance instead. (Good for debugging.)"))
key_flags.append("tpu")
flags.DEFINE_string(
name="tpu_zone", default=None,
help=help_wrap(
"[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE "
"project from metadata."))
flags.DEFINE_string(
name="tpu_gcp_project", default=None,
help=help_wrap(
"[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE "
"project from metadata."))
flags.DEFINE_integer(name="num_tpu_shards", default=8,
help=help_wrap("Number of shards (TPU chips)."))
return key_flags
......@@ -30,6 +30,7 @@ from absl import flags
from official.utils.flags import _base
from official.utils.flags import _benchmark
from official.utils.flags import _conventions
from official.utils.flags import _device
from official.utils.flags import _misc
from official.utils.flags import _performance
......@@ -72,6 +73,7 @@ define_base_eager = register_key_flags_in_core(functools.partial(
_base.define_base, epochs_between_evals=False, stop_threshold=False,
multi_gpu=False, hooks=False))
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image)
define_performance = register_key_flags_in_core(_performance.define_performance)
......@@ -83,3 +85,4 @@ get_num_gpus = _base.get_num_gpus
get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale
DTYPE_MAP = _performance.DTYPE_MAP
require_cloud_storage = _device.require_cloud_storage
......@@ -35,13 +35,15 @@ _TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate',
'train_accuracy'])
def get_train_hooks(name_list, **kwargs):
def get_train_hooks(name_list, use_tpu=False, **kwargs):
"""Factory for getting a list of TensorFlow hooks for training by name.
Args:
name_list: a list of strings to name desired hook classes. Allowed:
LoggingTensorHook, ProfilerHook, ExamplesPerSecondHook, which are defined
as keys in HOOKS
use_tpu: Boolean of whether computation occurs on a TPU. If True, all
hooks are disabled.
**kwargs: a dictionary of arguments to the hooks.
Returns:
......@@ -54,6 +56,11 @@ def get_train_hooks(name_list, **kwargs):
if not name_list:
return []
if use_tpu:
tf.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
"specified. No hooks will be used.".format(name_list))
return []
train_hooks = []
for name in name_list:
hook_name = HOOKS.get(name.strip().lower())
......