Commit 441c8f40 authored by qianyj's avatar qianyj
Browse files

update TF code

parent ec90ad8e
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of embedding layer with shared weights."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import model_utils
from official.utils.accelerator import tpu as tpu_utils
class EmbeddingSharedWeights(tf.layers.Layer):
"""Calculates input embeddings and pre-softmax linear with shared weights."""
def __init__(self, vocab_size, hidden_size, method="gather"):
"""Specify characteristic parameters of embedding layer.
Args:
vocab_size: Number of tokens in the embedding. (Typically ~32,000)
hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
method: Strategy for performing embedding lookup. "gather" uses tf.gather
which performs well on CPUs and GPUs, but very poorly on TPUs. "matmul"
one-hot encodes the indicies and formulates the embedding as a sparse
matrix multiplication. The matmul formulation is wasteful as it does
extra work, however matrix multiplication is very fast on TPUs which
makes "matmul" considerably faster than "gather" on TPUs.
"""
super(EmbeddingSharedWeights, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
if method not in ("gather", "matmul"):
raise ValueError("method {} must be 'gather' or 'matmul'".format(method))
self.method = method
def build(self, _):
with tf.variable_scope("embedding_and_softmax", reuse=tf.AUTO_REUSE):
# Create and initialize weights. The random normal initializer was chosen
# randomly, and works well.
self.shared_weights = tf.get_variable(
"weights", [self.vocab_size, self.hidden_size],
initializer=tf.random_normal_initializer(
0., self.hidden_size ** -0.5))
self.built = True
def call(self, x):
"""Get token embeddings of x.
Args:
x: An int64 tensor with shape [batch_size, length]
Returns:
embeddings: float32 tensor with shape [batch_size, length, embedding_size]
padding: float32 tensor with shape [batch_size, length] indicating the
locations of the padding tokens in x.
"""
with tf.name_scope("embedding"):
# Create binary mask of size [batch_size, length]
mask = tf.to_float(tf.not_equal(x, 0))
if self.method == "gather":
embeddings = tf.gather(self.shared_weights, x)
embeddings *= tf.expand_dims(mask, -1)
else: # matmul
embeddings = tpu_utils.embedding_matmul(
embedding_table=self.shared_weights,
values=tf.cast(x, dtype=tf.int32),
mask=mask
)
# embedding_matmul already zeros out masked positions, so
# `embeddings *= tf.expand_dims(mask, -1)` is unnecessary.
# Scale embedding by the sqrt of the hidden size
embeddings *= self.hidden_size ** 0.5
return embeddings
def linear(self, x):
"""Computes logits by running x through a linear layer.
Args:
x: A float32 tensor with shape [batch_size, length, hidden_size]
Returns:
float32 tensor with shape [batch_size, length, vocab_size].
"""
with tf.name_scope("presoftmax_linear"):
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
x = tf.reshape(x, [-1, self.hidden_size])
logits = tf.matmul(x, self.shared_weights, transpose_b=True)
return tf.reshape(logits, [batch_size, length, self.vocab_size])
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of fully connected network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class FeedFowardNetwork(tf.layers.Layer):
"""Fully connected feedforward network."""
def __init__(self, hidden_size, filter_size, relu_dropout, train, allow_pad):
super(FeedFowardNetwork, self).__init__()
self.hidden_size = hidden_size
self.filter_size = filter_size
self.relu_dropout = relu_dropout
self.train = train
self.allow_pad = allow_pad
self.filter_dense_layer = tf.layers.Dense(
filter_size, use_bias=True, activation=tf.nn.relu, name="filter_layer")
self.output_dense_layer = tf.layers.Dense(
hidden_size, use_bias=True, name="output_layer")
def call(self, x, padding=None):
"""Return outputs of the feedforward network.
Args:
x: tensor with shape [batch_size, length, hidden_size]
padding: (optional) If set, the padding values are temporarily removed
from x (provided self.allow_pad is set). The padding values are placed
back in the output tensor in the same locations.
shape [batch_size, length]
Returns:
Output of the feedforward network.
tensor with shape [batch_size, length, hidden_size]
"""
padding = None if not self.allow_pad else padding
# Retrieve dynamically known shapes
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
if padding is not None:
with tf.name_scope("remove_padding"):
# Flatten padding to [batch_size*length]
pad_mask = tf.reshape(padding, [-1])
nonpad_ids = tf.to_int32(tf.where(pad_mask < 1e-9))
# Reshape x to [batch_size*length, hidden_size] to remove padding
x = tf.reshape(x, [-1, self.hidden_size])
x = tf.gather_nd(x, indices=nonpad_ids)
# Reshape x from 2 dimensions to 3 dimensions.
x.set_shape([None, self.hidden_size])
x = tf.expand_dims(x, axis=0)
output = self.filter_dense_layer(x)
if self.train:
output = tf.nn.dropout(output, 1.0 - self.relu_dropout)
output = self.output_dense_layer(output)
if padding is not None:
with tf.name_scope("re_add_padding"):
output = tf.squeeze(output, axis=0)
output = tf.scatter_nd(
indices=nonpad_ids,
updates=output,
shape=[batch_size * length, self.hidden_size]
)
output = tf.reshape(output, [batch_size, length, self.hidden_size])
return output
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines Transformer model parameters."""
from collections import defaultdict
BASE_PARAMS = defaultdict(
lambda: None, # Set default value to None.
# Input params
default_batch_size=2048, # Maximum number of tokens per batch of examples.
default_batch_size_tpu=32768,
max_length=256, # Maximum number of tokens per example.
# Model params
initializer_gain=1.0, # Used in trainable variable initialization.
vocab_size=33708, # Number of tokens defined in the vocabulary file.
hidden_size=512, # Model dimension in the hidden layers.
num_hidden_layers=6, # Number of layers in the encoder and decoder stacks.
num_heads=8, # Number of heads to use in multi-headed attention.
filter_size=2048, # Inner layer dimension in the feedforward network.
# Dropout values (only used when training)
layer_postprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,
# Training params
label_smoothing=0.1,
learning_rate=2.0,
learning_rate_decay_rate=1.0,
learning_rate_warmup_steps=16000,
# Optimizer params
optimizer_adam_beta1=0.9,
optimizer_adam_beta2=0.997,
optimizer_adam_epsilon=1e-09,
# Default prediction params
extra_decode_length=50,
beam_size=4,
alpha=0.6, # used to calculate length normalization in beam search
# TPU specific parameters
use_tpu=False,
static_batch=False,
allow_ffn_pad=True,
)
BIG_PARAMS = BASE_PARAMS.copy()
BIG_PARAMS.update(
default_batch_size=4096,
# default batch size is smaller than for BASE_PARAMS due to memory limits.
default_batch_size_tpu=16384,
hidden_size=1024,
filter_size=4096,
num_heads=16,
)
# Parameters for running the model in multi gpu. These should not change the
# params that modify the model shape (such as the hidden_size or num_heads).
BASE_MULTI_GPU_PARAMS = BASE_PARAMS.copy()
BASE_MULTI_GPU_PARAMS.update(
learning_rate_warmup_steps=8000
)
BIG_MULTI_GPU_PARAMS = BIG_PARAMS.copy()
BIG_MULTI_GPU_PARAMS.update(
layer_postprocess_dropout=0.3,
learning_rate_warmup_steps=8000
)
# Parameters for testing the model
TINY_PARAMS = BASE_PARAMS.copy()
TINY_PARAMS.update(
default_batch_size=1024,
default_batch_size_tpu=1024,
hidden_size=32,
num_heads=4,
filter_size=256,
)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer model helper methods."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
_NEG_INF = -1e9
def get_position_encoding(
length, hidden_size, min_timescale=1.0, max_timescale=1.0e4):
"""Return positional encoding.
Calculates the position encoding as a mix of sine and cosine functions with
geometrically increasing wavelengths.
Defined and formulized in Attention is All You Need, section 3.5.
Args:
length: Sequence length.
hidden_size: Size of the
min_timescale: Minimum scale that will be applied at each position
max_timescale: Maximum scale that will be applied at each position
Returns:
Tensor with shape [length, hidden_size]
"""
position = tf.to_float(tf.range(length))
num_timescales = hidden_size // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(tf.to_float(num_timescales) - 1))
inv_timescales = min_timescale * tf.exp(
tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
return signal
def get_decoder_self_attention_bias(length):
"""Calculate bias for decoder that maintains model's autoregressive property.
Creates a tensor that masks out locations that correspond to illegal
connections, so prediction at position i cannot draw information from future
positions.
Args:
length: int length of sequences in batch.
Returns:
float tensor of shape [1, 1, length, length]
"""
with tf.name_scope("decoder_self_attention_bias"):
valid_locs = tf.matrix_band_part(tf.ones([length, length]), -1, 0)
valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
decoder_bias = _NEG_INF * (1.0 - valid_locs)
return decoder_bias
def get_padding(x, padding_value=0):
"""Return float tensor representing the padding values in x.
Args:
x: int tensor with any shape
padding_value: int value that
Returns:
flaot tensor with same shape as x containing values 0 or 1.
0 -> non-padding, 1 -> padding
"""
with tf.name_scope("padding"):
return tf.to_float(tf.equal(x, padding_value))
def get_padding_bias(x):
"""Calculate bias tensor from padding values in tensor.
Bias tensor that is added to the pre-softmax multi-headed attention logits,
which has shape [batch_size, num_heads, length, length]. The tensor is zero at
non-padding locations, and -1e9 (negative infinity) at padding locations.
Args:
x: int tensor with shape [batch_size, length]
Returns:
Attention bias tensor of shape [batch_size, 1, 1, length].
"""
with tf.name_scope("attention_bias"):
padding = get_padding(x)
attention_bias = padding * _NEG_INF
attention_bias = tf.expand_dims(
tf.expand_dims(attention_bias, axis=1), axis=1)
return attention_bias
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Transformer model helper methods."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import model_utils
NEG_INF = -1e9
class ModelUtilsTest(tf.test.TestCase):
def test_get_padding(self):
x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
padding = model_utils.get_padding(x, padding_value=0)
with self.test_session() as sess:
padding = sess.run(padding)
self.assertAllEqual([[0, 1, 1, 1, 0], [0, 0, 1, 1, 1], [1, 0, 0, 1, 0]],
padding)
def test_get_padding_bias(self):
x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
bias = model_utils.get_padding_bias(x)
bias_shape = tf.shape(bias)
flattened_bias = tf.reshape(bias, [3, 5])
with self.test_session() as sess:
flattened_bias, bias_shape = sess.run((flattened_bias, bias_shape))
self.assertAllEqual([[0, NEG_INF, NEG_INF, NEG_INF, 0],
[0, 0, NEG_INF, NEG_INF, NEG_INF],
[NEG_INF, 0, 0, NEG_INF, 0]],
flattened_bias)
self.assertAllEqual([3, 1, 1, 5], bias_shape)
def test_get_decoder_self_attention_bias(self):
length = 5
bias = model_utils.get_decoder_self_attention_bias(length)
with self.test_session() as sess:
bias = sess.run(bias)
self.assertAllEqual([[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
[0, 0, NEG_INF, NEG_INF, NEG_INF],
[0, 0, 0, NEG_INF, NEG_INF],
[0, 0, 0, 0, NEG_INF],
[0, 0, 0, 0, 0]]]],
bias)
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the Transformer model, and its encoder and decoder stacks.
Model paper: https://arxiv.org/pdf/1706.03762.pdf
Transformer model code source: https://github.com/tensorflow/tensor2tensor
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import attention_layer
from official.transformer.model import beam_search
from official.transformer.model import embedding_layer
from official.transformer.model import ffn_layer
from official.transformer.model import model_utils
from official.transformer.utils.tokenizer import EOS_ID
_NEG_INF = -1e9
class Transformer(object):
"""Transformer model for sequence to sequence data.
Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf
The Transformer model consists of an encoder and decoder. The input is an int
sequence (or a batch of sequences). The encoder produces a continous
representation, and the decoder uses the encoder output to generate
probabilities for the output sequence.
"""
def __init__(self, params, train):
"""Initialize layers to build Transformer model.
Args:
params: hyperparameter object defining layer sizes, dropout values, etc.
train: boolean indicating whether the model is in training mode. Used to
determine if dropout layers should be added.
"""
self.train = train
self.params = params
self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
params["vocab_size"], params["hidden_size"],
method="matmul" if params["tpu"] else "gather")
self.encoder_stack = EncoderStack(params, train)
self.decoder_stack = DecoderStack(params, train)
def __call__(self, inputs, targets=None):
"""Calculate target logits or inferred target sequences.
Args:
inputs: int tensor with shape [batch_size, input_length].
targets: None or int tensor with shape [batch_size, target_length].
Returns:
If targets is defined, then return logits for each word in the target
sequence. float tensor with shape [batch_size, target_length, vocab_size]
If target is none, then generate output sequence one token at a time.
returns a dictionary {
output: [batch_size, decoded length]
score: [batch_size, float]}
"""
# Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well.
initializer = tf.variance_scaling_initializer(
self.params["initializer_gain"], mode="fan_avg", distribution="uniform")
with tf.variable_scope("Transformer", initializer=initializer):
# Calculate attention bias for encoder self-attention and decoder
# multi-headed attention layers.
attention_bias = model_utils.get_padding_bias(inputs)
# Run the inputs through the encoder layer to map the symbol
# representations to continuous representations.
encoder_outputs = self.encode(inputs, attention_bias)
# Generate output sequence if targets is None, or return logits if target
# sequence is known.
if targets is None:
return self.predict(encoder_outputs, attention_bias)
else:
logits = self.decode(targets, encoder_outputs, attention_bias)
return logits
def encode(self, inputs, attention_bias):
"""Generate continuous representation for inputs.
Args:
inputs: int tensor with shape [batch_size, input_length].
attention_bias: float tensor with shape [batch_size, 1, 1, input_length]
Returns:
float tensor with shape [batch_size, input_length, hidden_size]
"""
with tf.name_scope("encode"):
# Prepare inputs to the layer stack by adding positional encodings and
# applying dropout.
embedded_inputs = self.embedding_softmax_layer(inputs)
inputs_padding = model_utils.get_padding(inputs)
with tf.name_scope("add_pos_encoding"):
length = tf.shape(embedded_inputs)[1]
pos_encoding = model_utils.get_position_encoding(
length, self.params["hidden_size"])
encoder_inputs = embedded_inputs + pos_encoding
if self.train:
encoder_inputs = tf.nn.dropout(
encoder_inputs, 1 - self.params["layer_postprocess_dropout"])
return self.encoder_stack(encoder_inputs, attention_bias, inputs_padding)
def decode(self, targets, encoder_outputs, attention_bias):
"""Generate logits for each value in the target sequence.
Args:
targets: target values for the output sequence.
int tensor with shape [batch_size, target_length]
encoder_outputs: continuous representation of input sequence.
float tensor with shape [batch_size, input_length, hidden_size]
attention_bias: float tensor with shape [batch_size, 1, 1, input_length]
Returns:
float32 tensor with shape [batch_size, target_length, vocab_size]
"""
with tf.name_scope("decode"):
# Prepare inputs to decoder layers by shifting targets, adding positional
# encoding and applying dropout.
decoder_inputs = self.embedding_softmax_layer(targets)
with tf.name_scope("shift_targets"):
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(
decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
with tf.name_scope("add_pos_encoding"):
length = tf.shape(decoder_inputs)[1]
decoder_inputs += model_utils.get_position_encoding(
length, self.params["hidden_size"])
if self.train:
decoder_inputs = tf.nn.dropout(
decoder_inputs, 1 - self.params["layer_postprocess_dropout"])
# Run values
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
length)
outputs = self.decoder_stack(
decoder_inputs, encoder_outputs, decoder_self_attention_bias,
attention_bias)
logits = self.embedding_softmax_layer.linear(outputs)
return logits
def _get_symbols_to_logits_fn(self, max_decode_length):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal = model_utils.get_position_encoding(
max_decode_length + 1, self.params["hidden_size"])
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
max_decode_length)
def symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next potential IDs.
Args:
ids: Current decoded sequences.
int tensor with shape [batch_size * beam_size, i + 1]
i: Loop index
cache: dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
Returns:
Tuple of
(logits with shape [batch_size * beam_size, vocab_size],
updated cache values)
"""
# Set decoder input to the last generated IDs
decoder_input = ids[:, -1:]
# Preprocess decoder input by getting embeddings and adding timing signal.
decoder_input = self.embedding_softmax_layer(decoder_input)
decoder_input += timing_signal[i:i + 1]
self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
decoder_outputs = self.decoder_stack(
decoder_input, cache.get("encoder_outputs"), self_attention_bias,
cache.get("encoder_decoder_attention_bias"), cache)
logits = self.embedding_softmax_layer.linear(decoder_outputs)
logits = tf.squeeze(logits, axis=[1])
return logits, cache
return symbols_to_logits_fn
def predict(self, encoder_outputs, encoder_decoder_attention_bias):
"""Return predicted sequence."""
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
max_decode_length = input_length + self.params["extra_decode_length"]
symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
# Create initial set of IDs that will be passed into symbols_to_logits_fn.
initial_ids = tf.zeros([batch_size], dtype=tf.int32)
# Create cache storing decoder attention values for each layer.
cache = {
"layer_%d" % layer: {
"k": tf.zeros([batch_size, 0, self.params["hidden_size"]]),
"v": tf.zeros([batch_size, 0, self.params["hidden_size"]]),
} for layer in range(self.params["num_hidden_layers"])}
# Add encoder output and attention bias to the cache.
cache["encoder_outputs"] = encoder_outputs
cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
# Use beam search to find the top beam_size sequences and scores.
decoded_ids, scores = beam_search.sequence_beam_search(
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
vocab_size=self.params["vocab_size"],
beam_size=self.params["beam_size"],
alpha=self.params["alpha"],
max_decode_length=max_decode_length,
eos_id=EOS_ID)
# Get the top sequence for each batch element
top_decoded_ids = decoded_ids[:, 0, 1:]
top_scores = scores[:, 0]
return {"outputs": top_decoded_ids, "scores": top_scores}
class LayerNormalization(tf.layers.Layer):
"""Applies layer normalization."""
def __init__(self, hidden_size):
super(LayerNormalization, self).__init__()
self.hidden_size = hidden_size
def build(self, _):
self.scale = tf.get_variable("layer_norm_scale", [self.hidden_size],
initializer=tf.ones_initializer())
self.bias = tf.get_variable("layer_norm_bias", [self.hidden_size],
initializer=tf.zeros_initializer())
self.built = True
def call(self, x, epsilon=1e-6):
mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
return norm_x * self.scale + self.bias
class PrePostProcessingWrapper(object):
"""Wrapper class that applies layer pre-processing and post-processing."""
def __init__(self, layer, params, train):
self.layer = layer
self.postprocess_dropout = params["layer_postprocess_dropout"]
self.train = train
# Create normalization layer
self.layer_norm = LayerNormalization(params["hidden_size"])
def __call__(self, x, *args, **kwargs):
# Preprocessing: apply layer normalization
y = self.layer_norm(x)
# Get layer output
y = self.layer(y, *args, **kwargs)
# Postprocessing: apply dropout and residual connection
if self.train:
y = tf.nn.dropout(y, 1 - self.postprocess_dropout)
return x + y
class EncoderStack(tf.layers.Layer):
"""Transformer encoder stack.
The encoder stack is made up of N identical layers. Each layer is composed
of the sublayers:
1. Self-attention layer
2. Feedforward network (which is 2 fully-connected layers)
"""
def __init__(self, params, train):
super(EncoderStack, self).__init__()
self.layers = []
for _ in range(params["num_hidden_layers"]):
# Create sublayers for each layer.
self_attention_layer = attention_layer.SelfAttention(
params["hidden_size"], params["num_heads"],
params["attention_dropout"], train)
feed_forward_network = ffn_layer.FeedFowardNetwork(
params["hidden_size"], params["filter_size"],
params["relu_dropout"], train, params["allow_ffn_pad"])
self.layers.append([
PrePostProcessingWrapper(self_attention_layer, params, train),
PrePostProcessingWrapper(feed_forward_network, params, train)])
# Create final layer normalization layer.
self.output_normalization = LayerNormalization(params["hidden_size"])
def call(self, encoder_inputs, attention_bias, inputs_padding):
"""Return the output of the encoder layer stacks.
Args:
encoder_inputs: tensor with shape [batch_size, input_length, hidden_size]
attention_bias: bias for the encoder self-attention layer.
[batch_size, 1, 1, input_length]
inputs_padding: P
Returns:
Output of encoder layer stack.
float32 tensor with shape [batch_size, input_length, hidden_size]
"""
for n, layer in enumerate(self.layers):
# Run inputs through the sublayers.
self_attention_layer = layer[0]
feed_forward_network = layer[1]
with tf.variable_scope("layer_%d" % n):
with tf.variable_scope("self_attention"):
encoder_inputs = self_attention_layer(encoder_inputs, attention_bias)
with tf.variable_scope("ffn"):
encoder_inputs = feed_forward_network(encoder_inputs, inputs_padding)
return self.output_normalization(encoder_inputs)
class DecoderStack(tf.layers.Layer):
"""Transformer decoder stack.
Like the encoder stack, the decoder stack is made up of N identical layers.
Each layer is composed of the sublayers:
1. Self-attention layer
2. Multi-headed attention layer combining encoder outputs with results from
the previous self-attention layer.
3. Feedforward network (2 fully-connected layers)
"""
def __init__(self, params, train):
super(DecoderStack, self).__init__()
self.layers = []
for _ in range(params["num_hidden_layers"]):
self_attention_layer = attention_layer.SelfAttention(
params["hidden_size"], params["num_heads"],
params["attention_dropout"], train)
enc_dec_attention_layer = attention_layer.Attention(
params["hidden_size"], params["num_heads"],
params["attention_dropout"], train)
feed_forward_network = ffn_layer.FeedFowardNetwork(
params["hidden_size"], params["filter_size"],
params["relu_dropout"], train, params["allow_ffn_pad"])
self.layers.append([
PrePostProcessingWrapper(self_attention_layer, params, train),
PrePostProcessingWrapper(enc_dec_attention_layer, params, train),
PrePostProcessingWrapper(feed_forward_network, params, train)])
self.output_normalization = LayerNormalization(params["hidden_size"])
def call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias,
attention_bias, cache=None):
"""Return the output of the decoder layer stacks.
Args:
decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
decoder_self_attention_bias: bias for decoder self-attention layer.
[1, 1, target_len, target_length]
attention_bias: bias for encoder-decoder attention layer.
[batch_size, 1, 1, input_length]
cache: (Used for fast decoding) A nested dictionary storing previous
decoder self-attention values. The items are:
{layer_n: {"k": tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]},
...}
Returns:
Output of decoder layer stack.
float32 tensor with shape [batch_size, target_length, hidden_size]
"""
for n, layer in enumerate(self.layers):
self_attention_layer = layer[0]
enc_dec_attention_layer = layer[1]
feed_forward_network = layer[2]
# Run inputs through the sublayers.
layer_name = "layer_%d" % n
layer_cache = cache[layer_name] if cache is not None else None
with tf.variable_scope(layer_name):
with tf.variable_scope("self_attention"):
decoder_inputs = self_attention_layer(
decoder_inputs, decoder_self_attention_bias, cache=layer_cache)
with tf.variable_scope("encdec_attention"):
decoder_inputs = enc_dec_attention_layer(
decoder_inputs, encoder_outputs, attention_bias)
with tf.variable_scope("ffn"):
decoder_inputs = feed_forward_network(decoder_inputs)
return self.output_normalization(decoder_inputs)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train and evaluate the Transformer model.
See README for description of setting the training schedule and evaluating the
BLEU score.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
# pylint: disable=g-bad-import-order
from six.moves import xrange # pylint: disable=redefined-builtin
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer import compute_bleu
from official.transformer import translate
from official.transformer.model import model_params
from official.transformer.model import transformer
from official.transformer.utils import dataset
from official.transformer.utils import metrics
from official.transformer.utils import schedule
from official.transformer.utils import tokenizer
from official.utils.accelerator import tpu as tpu_util
from official.utils.export import export
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
PARAMS_MAP = {
"tiny": model_params.TINY_PARAMS,
"base": model_params.BASE_PARAMS,
"big": model_params.BIG_PARAMS,
}
DEFAULT_TRAIN_EPOCHS = 10
INF = int(1e9)
BLEU_DIR = "bleu"
# Dictionary containing tensors that are logged by the logging hooks. Each item
# maps a string to the tensor name.
TENSORS_TO_LOG = {
"learning_rate": "model/get_train_op/learning_rate/learning_rate",
"cross_entropy_loss": "model/cross_entropy"}
def model_fn(features, labels, mode, params):
"""Defines how to train, evaluate and predict from the transformer model."""
with tf.variable_scope("model"):
inputs, targets = features, labels
# Create model and get output logits.
model = transformer.Transformer(params, mode == tf.estimator.ModeKeys.TRAIN)
logits = model(inputs, targets)
# When in prediction mode, the labels/targets is None. The model output
# is the prediction
if mode == tf.estimator.ModeKeys.PREDICT:
if params["use_tpu"]:
raise NotImplementedError("Prediction is not yet supported on TPUs.")
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT,
predictions=logits,
export_outputs={
"translate": tf.estimator.export.PredictOutput(logits)
})
# Explicitly set the shape of the logits for XLA (TPU). This is needed
# because the logits are passed back to the host VM CPU for metric
# evaluation, and the shape of [?, ?, vocab_size] is too vague. However
# it is known from Transformer that the first two dimensions of logits
# are the dimensions of targets. Note that the ambiguous shape of logits is
# not a problem when computing xentropy, because padded_cross_entropy_loss
# resolves the shape on the TPU.
logits.set_shape(targets.shape.as_list() + logits.shape.as_list()[2:])
# Calculate model loss.
# xentropy contains the cross entropy loss of every nonpadding token in the
# targets.
xentropy, weights = metrics.padded_cross_entropy_loss(
logits, targets, params["label_smoothing"], params["vocab_size"])
loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
# Save loss as named tensor that will be logged with the logging hook.
tf.identity(loss, "cross_entropy")
if mode == tf.estimator.ModeKeys.EVAL:
if params["use_tpu"]:
# host call functions should only have tensors as arguments.
# This lambda pre-populates params so that metric_fn is
# TPUEstimator compliant.
metric_fn = lambda logits, labels: (
metrics.get_eval_metrics(logits, labels, params=params))
eval_metrics = (metric_fn, [logits, labels])
return tf.contrib.tpu.TPUEstimatorSpec(
mode=mode, loss=loss, predictions={"predictions": logits},
eval_metrics=eval_metrics)
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, predictions={"predictions": logits},
eval_metric_ops=metrics.get_eval_metrics(logits, labels, params))
else:
train_op, metric_dict = get_train_op_and_metrics(loss, params)
# Epochs can be quite long. This gives some intermediate information
# in TensorBoard.
metric_dict["minibatch_loss"] = loss
if params["use_tpu"]:
return tf.contrib.tpu.TPUEstimatorSpec(
mode=mode, loss=loss, train_op=train_op,
host_call=tpu_util.construct_scalar_host_call(
metric_dict=metric_dict, model_dir=params["model_dir"],
prefix="training/")
)
record_scalars(metric_dict)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def record_scalars(metric_dict):
for key, value in metric_dict.items():
tf.contrib.summary.scalar(name=key, tensor=value)
def get_learning_rate(learning_rate, hidden_size, learning_rate_warmup_steps):
"""Calculate learning rate with linear warmup and rsqrt decay."""
with tf.name_scope("learning_rate"):
warmup_steps = tf.to_float(learning_rate_warmup_steps)
step = tf.to_float(tf.train.get_or_create_global_step())
learning_rate *= (hidden_size ** -0.5)
# Apply linear warmup
learning_rate *= tf.minimum(1.0, step / warmup_steps)
# Apply rsqrt decay
learning_rate *= tf.rsqrt(tf.maximum(step, warmup_steps))
# Create a named tensor that will be logged using the logging hook.
# The full name includes variable and names scope. In this case, the name
# is model/get_train_op/learning_rate/learning_rate
tf.identity(learning_rate, "learning_rate")
return learning_rate
def get_train_op_and_metrics(loss, params):
"""Generate training op and metrics to save in TensorBoard."""
with tf.variable_scope("get_train_op"):
learning_rate = get_learning_rate(
learning_rate=params["learning_rate"],
hidden_size=params["hidden_size"],
learning_rate_warmup_steps=params["learning_rate_warmup_steps"])
# Create optimizer. Use LazyAdamOptimizer from TF contrib, which is faster
# than the TF core Adam optimizer.
optimizer = tf.contrib.opt.LazyAdamOptimizer(
learning_rate,
beta1=params["optimizer_adam_beta1"],
beta2=params["optimizer_adam_beta2"],
epsilon=params["optimizer_adam_epsilon"])
if params["use_tpu"] and params["tpu"] != tpu_util.LOCAL:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
# Calculate and apply gradients using LazyAdamOptimizer.
global_step = tf.train.get_global_step()
tvars = tf.trainable_variables()
gradients = optimizer.compute_gradients(
loss, tvars, colocate_gradients_with_ops=True)
minimize_op = optimizer.apply_gradients(
gradients, global_step=global_step, name="train")
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = tf.group(minimize_op, update_ops)
train_metrics = {"learning_rate": learning_rate}
if not params["use_tpu"]:
# gradient norm is not included as a summary when running on TPU, as
# it can cause instability between the TPU and the host controller.
gradient_norm = tf.global_norm(list(zip(*gradients))[0])
train_metrics["global_norm/gradient_norm"] = gradient_norm
return train_op, train_metrics
def translate_and_compute_bleu(estimator, subtokenizer, bleu_source, bleu_ref):
"""Translate file and report the cased and uncased bleu scores."""
# Create temporary file to store translation.
tmp = tempfile.NamedTemporaryFile(delete=False)
tmp_filename = tmp.name
translate.translate_file(
estimator, subtokenizer, bleu_source, output_file=tmp_filename,
print_all_translations=False)
# Compute uncased and cased bleu scores.
uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
cased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, True)
os.remove(tmp_filename)
return uncased_score, cased_score
def get_global_step(estimator):
"""Return estimator's last checkpoint."""
return int(estimator.latest_checkpoint().split("-")[-1])
def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file):
"""Calculate and record the BLEU score."""
subtokenizer = tokenizer.Subtokenizer(vocab_file)
uncased_score, cased_score = translate_and_compute_bleu(
estimator, subtokenizer, bleu_source, bleu_ref)
tf.logging.info("Bleu score (uncased): %d", uncased_score)
tf.logging.info("Bleu score (cased): %d", cased_score)
return uncased_score, cased_score
def _validate_file(filepath):
"""Make sure that file exists."""
if not tf.gfile.Exists(filepath):
raise tf.errors.NotFoundError(None, None, "File %s not found." % filepath)
def run_loop(
estimator, schedule_manager, train_hooks=None, benchmark_logger=None,
bleu_source=None, bleu_ref=None, bleu_threshold=None, vocab_file=None):
"""Train and evaluate model, and optionally compute model's BLEU score.
**Step vs. Epoch vs. Iteration**
Steps and epochs are canonical terms used in TensorFlow and general machine
learning. They are used to describe running a single process (train/eval):
- Step refers to running the process through a single or batch of examples.
- Epoch refers to running the process through an entire dataset.
E.g. training a dataset with 100 examples. The dataset is
divided into 20 batches with 5 examples per batch. A single training step
trains the model on one batch. After 20 training steps, the model will have
trained on every batch in the dataset, or, in other words, one epoch.
Meanwhile, iteration is used in this implementation to describe running
multiple processes (training and eval).
- A single iteration:
1. trains the model for a specific number of steps or epochs.
2. evaluates the model.
3. (if source and ref files are provided) compute BLEU score.
This function runs through multiple train+eval+bleu iterations.
Args:
estimator: tf.Estimator containing model to train.
schedule_manager: A schedule.Manager object to guide the run loop.
train_hooks: List of hooks to pass to the estimator during training.
benchmark_logger: a BenchmarkLogger object that logs evaluation data
bleu_source: File containing text to be translated for BLEU calculation.
bleu_ref: File containing reference translations for BLEU calculation.
bleu_threshold: minimum BLEU score before training is stopped.
vocab_file: Path to vocab file that will be used to subtokenize bleu_source.
Raises:
ValueError: if both or none of single_iteration_train_steps and
single_iteration_train_epochs were defined.
NotFoundError: if the vocab file or bleu files don't exist.
"""
if bleu_source:
_validate_file(bleu_source)
if bleu_ref:
_validate_file(bleu_ref)
if vocab_file:
_validate_file(vocab_file)
evaluate_bleu = bleu_source is not None and bleu_ref is not None
if evaluate_bleu and schedule_manager.use_tpu:
raise ValueError("BLEU score can not be computed when training with a TPU, "
"as it requires estimator.predict which is not yet "
"supported.")
# Print details of training schedule.
tf.logging.info("Training schedule:")
tf.logging.info(
"\t1. Train for {}".format(schedule_manager.train_increment_str))
tf.logging.info("\t2. Evaluate model.")
if evaluate_bleu:
tf.logging.info("\t3. Compute BLEU score.")
if bleu_threshold is not None:
tf.logging.info("Repeat above steps until the BLEU score reaches %f" %
bleu_threshold)
if not evaluate_bleu or bleu_threshold is None:
tf.logging.info("Repeat above steps %d times." %
schedule_manager.train_eval_iterations)
if evaluate_bleu:
# Create summary writer to log bleu score (values can be displayed in
# Tensorboard).
bleu_writer = tf.summary.FileWriter(
os.path.join(estimator.model_dir, BLEU_DIR))
if bleu_threshold is not None:
# Change loop stopping condition if bleu_threshold is defined.
schedule_manager.train_eval_iterations = INF
# Loop training/evaluation/bleu cycles
for i in xrange(schedule_manager.train_eval_iterations):
tf.logging.info("Starting iteration %d" % (i + 1))
# Train the model for single_iteration_train_steps or until the input fn
# runs out of examples (if single_iteration_train_steps is None).
estimator.train(
dataset.train_input_fn,
steps=schedule_manager.single_iteration_train_steps,
hooks=train_hooks)
eval_results = estimator.evaluate(
input_fn=dataset.eval_input_fn,
steps=schedule_manager.single_iteration_eval_steps)
tf.logging.info("Evaluation results (iter %d/%d):" %
(i + 1, schedule_manager.train_eval_iterations))
tf.logging.info(eval_results)
benchmark_logger.log_evaluation_result(eval_results)
# The results from estimator.evaluate() are measured on an approximate
# translation, which utilize the target golden values provided. The actual
# bleu score must be computed using the estimator.predict() path, which
# outputs translations that are not based on golden values. The translations
# are compared to reference file to get the actual bleu score.
if evaluate_bleu:
uncased_score, cased_score = evaluate_and_log_bleu(
estimator, bleu_source, bleu_ref, vocab_file)
# Write actual bleu scores using summary writer and benchmark logger
global_step = get_global_step(estimator)
summary = tf.Summary(value=[
tf.Summary.Value(tag="bleu/uncased", simple_value=uncased_score),
tf.Summary.Value(tag="bleu/cased", simple_value=cased_score),
])
bleu_writer.add_summary(summary, global_step)
bleu_writer.flush()
benchmark_logger.log_metric(
"bleu_uncased", uncased_score, global_step=global_step)
benchmark_logger.log_metric(
"bleu_cased", cased_score, global_step=global_step)
# Stop training if bleu stopping threshold is met.
if model_helpers.past_stop_threshold(bleu_threshold, uncased_score):
bleu_writer.close()
break
def define_transformer_flags():
"""Add flags and flag validators for running transformer_main."""
# Add common flags (data_dir, model_dir, train_epochs, etc.).
flags_core.define_base()
flags_core.define_performance(
num_parallel_calls=True,
inter_op=False,
intra_op=False,
synthetic_data=True,
max_train_steps=False,
dtype=False,
all_reduce_alg=True
)
flags_core.define_benchmark()
flags_core.define_device(tpu=True)
# Set flags from the flags_core module as "key flags" so they're listed when
# the '-h' flag is used. Without this line, the flags defined above are
# only shown in the full `--helpful` help text.
flags.adopt_module_key_flags(flags_core)
# Add transformer-specific flags
flags.DEFINE_enum(
name="param_set", short_name="mp", default="big",
enum_values=PARAMS_MAP.keys(),
help=flags_core.help_wrap(
"Parameter set to use when creating and training the model. The "
"parameters define the input shape (batch size and max length), "
"model configuration (size of embedding, # of hidden layers, etc.), "
"and various other settings. The big parameter set increases the "
"default batch size, embedding/hidden size, and filter size. For a "
"complete list of parameters, please see model/model_params.py."))
flags.DEFINE_bool(
name="static_batch", default=False,
help=flags_core.help_wrap(
"Whether the batches in the dataset should have static shapes. In "
"general, this setting should be False. Dynamic shapes allow the "
"inputs to be grouped so that the number of padding tokens is "
"minimized, and helps model training. In cases where the input shape "
"must be static (e.g. running on TPU), this setting will be ignored "
"and static batching will always be used."))
# Flags for training with steps (may be used for debugging)
flags.DEFINE_integer(
name="train_steps", short_name="ts", default=None,
help=flags_core.help_wrap("The number of steps used to train."))
flags.DEFINE_integer(
name="steps_between_evals", short_name="sbe", default=1000,
help=flags_core.help_wrap(
"The Number of training steps to run between evaluations. This is "
"used if --train_steps is defined."))
# BLEU score computation
flags.DEFINE_string(
name="bleu_source", short_name="bls", default=None,
help=flags_core.help_wrap(
"Path to source file containing text translate when calculating the "
"official BLEU score. Both --bleu_source and --bleu_ref must be set. "
"Use the flag --stop_threshold to stop the script based on the "
"uncased BLEU score."))
flags.DEFINE_string(
name="bleu_ref", short_name="blr", default=None,
help=flags_core.help_wrap(
"Path to source file containing text translate when calculating the "
"official BLEU score. Both --bleu_source and --bleu_ref must be set. "
"Use the flag --stop_threshold to stop the script based on the "
"uncased BLEU score."))
flags.DEFINE_string(
name="vocab_file", short_name="vf", default=None,
help=flags_core.help_wrap(
"Path to subtoken vocabulary file. If data_download.py was used to "
"download and encode the training data, look in the data_dir to find "
"the vocab file."))
flags_core.set_defaults(data_dir="/tmp/translate_ende",
model_dir="/tmp/transformer_model",
batch_size=None,
train_epochs=None)
@flags.multi_flags_validator(
["train_epochs", "train_steps"],
message="Both --train_steps and --train_epochs were set. Only one may be "
"defined.")
def _check_train_limits(flag_dict):
return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None
@flags.multi_flags_validator(
["bleu_source", "bleu_ref"],
message="Both or neither --bleu_source and --bleu_ref must be defined.")
def _check_bleu_files(flags_dict):
return (flags_dict["bleu_source"] is None) == (
flags_dict["bleu_ref"] is None)
@flags.multi_flags_validator(
["bleu_source", "bleu_ref", "vocab_file"],
message="--vocab_file must be defined if --bleu_source and --bleu_ref "
"are defined.")
def _check_bleu_vocab_file(flags_dict):
if flags_dict["bleu_source"] and flags_dict["bleu_ref"]:
return flags_dict["vocab_file"] is not None
return True
@flags.multi_flags_validator(
["export_dir", "vocab_file"],
message="--vocab_file must be defined if --export_dir is set.")
def _check_export_vocab_file(flags_dict):
if flags_dict["export_dir"]:
return flags_dict["vocab_file"] is not None
return True
flags_core.require_cloud_storage(["data_dir", "model_dir", "export_dir"])
def construct_estimator(flags_obj, params, schedule_manager):
"""Construct an estimator from either Estimator or TPUEstimator.
Args:
flags_obj: The FLAGS object parsed from command line.
params: A dict of run specific parameters.
schedule_manager: A schedule.Manager object containing the run schedule.
Returns:
An estimator object to be used for training and eval.
"""
if not params["use_tpu"]:
distribution_strategy = distribution_utils.get_distribution_strategy(
flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
return tf.estimator.Estimator(
model_fn=model_fn, model_dir=flags_obj.model_dir, params=params,
config=tf.estimator.RunConfig(train_distribute=distribution_strategy))
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
tpu=flags_obj.tpu,
zone=flags_obj.tpu_zone,
project=flags_obj.tpu_gcp_project
)
tpu_config = tf.contrib.tpu.TPUConfig(
iterations_per_loop=schedule_manager.single_iteration_train_steps,
num_shards=flags_obj.num_tpu_shards)
run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=flags_obj.model_dir,
session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True),
tpu_config=tpu_config)
return tf.contrib.tpu.TPUEstimator(
model_fn=model_fn,
use_tpu=params["use_tpu"] and flags_obj.tpu != tpu_util.LOCAL,
train_batch_size=schedule_manager.batch_size,
eval_batch_size=schedule_manager.batch_size,
params={
# TPUEstimator needs to populate batch_size itself due to sharding.
key: value for key, value in params.items() if key != "batch_size"},
config=run_config)
def run_transformer(flags_obj):
"""Create tf.Estimator to train and evaluate transformer model.
Args:
flags_obj: Object containing parsed flag values.
"""
num_gpus = flags_core.get_num_gpus(flags_obj)
# Add flag-defined parameters to params object
params = PARAMS_MAP[flags_obj.param_set]
if num_gpus > 1:
if flags_obj.param_set == "big":
params = model_params.BIG_MULTI_GPU_PARAMS
elif flags_obj.param_set == "base":
params = model_params.BASE_MULTI_GPU_PARAMS
params["data_dir"] = flags_obj.data_dir
params["model_dir"] = flags_obj.model_dir
params["num_parallel_calls"] = flags_obj.num_parallel_calls
params["tpu"] = flags_obj.tpu
params["use_tpu"] = bool(flags_obj.tpu) # was a tpu specified.
params["static_batch"] = flags_obj.static_batch or params["use_tpu"]
params["allow_ffn_pad"] = not params["use_tpu"]
params["use_synthetic_data"] = flags_obj.use_synthetic_data
# Set batch size parameter, which depends on the availability of
# TPU and GPU, and distribution settings.
params["batch_size"] = (flags_obj.batch_size or (
params["default_batch_size_tpu"] if params["use_tpu"]
else params["default_batch_size"]))
if not params["use_tpu"]:
params["batch_size"] = distribution_utils.per_device_batch_size(
params["batch_size"], num_gpus)
schedule_manager = schedule.Manager(
train_steps=flags_obj.train_steps,
steps_between_evals=flags_obj.steps_between_evals,
train_epochs=flags_obj.train_epochs,
epochs_between_evals=flags_obj.epochs_between_evals,
default_train_epochs=DEFAULT_TRAIN_EPOCHS,
batch_size=params["batch_size"],
max_length=params["max_length"],
use_tpu=params["use_tpu"],
num_tpu_shards=flags_obj.num_tpu_shards
)
params["repeat_dataset"] = schedule_manager.repeat_dataset
model_helpers.apply_clean(flags.FLAGS)
# Create hooks that log information about the training and metric values
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
model_dir=flags_obj.model_dir,
tensors_to_log=TENSORS_TO_LOG, # used for logging hooks
batch_size=schedule_manager.batch_size, # for ExamplesPerSecondHook
use_tpu=params["use_tpu"] # Not all hooks can run with TPUs
)
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info(
model_name="transformer",
dataset_name="wmt_translate_ende",
run_params=params,
test_id=flags_obj.benchmark_test_id)
# Train and evaluate transformer model
estimator = construct_estimator(flags_obj, params, schedule_manager)
run_loop(
estimator=estimator,
# Training arguments
schedule_manager=schedule_manager,
train_hooks=train_hooks,
benchmark_logger=benchmark_logger,
# BLEU calculation arguments
bleu_source=flags_obj.bleu_source,
bleu_ref=flags_obj.bleu_ref,
bleu_threshold=flags_obj.stop_threshold,
vocab_file=flags_obj.vocab_file)
if flags_obj.export_dir and not params["use_tpu"]:
serving_input_fn = export.build_tensor_serving_input_receiver_fn(
shape=[None], dtype=tf.int64, batch_size=None)
# Export saved model, and save the vocab file as an extra asset. The vocab
# file is saved to allow consistent input encoding and output decoding.
# (See the "Export trained model" section in the README for an example of
# how to use the vocab file.)
# Since the model itself does not use the vocab file, this file is saved as
# an extra asset rather than a core asset.
estimator.export_savedmodel(
flags_obj.export_dir, serving_input_fn,
assets_extra={"vocab.txt": flags_obj.vocab_file},
strip_default_attrs=True)
def main(_):
with logger.benchmark_context(flags.FLAGS):
run_transformer(flags.FLAGS)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
define_transformer_flags()
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Translate text or files using trained transformer model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# pylint: disable=g-bad-import-order
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
_DECODE_BATCH_SIZE = 32
_EXTRA_DECODE_LENGTH = 100
_BEAM_SIZE = 4
_ALPHA = 0.6
def _get_sorted_inputs(filename):
"""Read and sort lines from the file sorted by decreasing length.
Args:
filename: String name of file to read inputs from.
Returns:
Sorted list of inputs, and dictionary mapping original index->sorted index
of each element.
"""
with tf.gfile.Open(filename) as f:
records = f.read().split("\n")
inputs = [record.strip() for record in records]
if not inputs[-1]:
inputs.pop()
input_lens = [(i, len(line.split())) for i, line in enumerate(inputs)]
sorted_input_lens = sorted(input_lens, key=lambda x: x[1], reverse=True)
sorted_inputs = [None] * len(sorted_input_lens)
sorted_keys = [0] * len(sorted_input_lens)
for i, (index, _) in enumerate(sorted_input_lens):
sorted_inputs[i] = inputs[index]
sorted_keys[index] = i
return sorted_inputs, sorted_keys
def _encode_and_add_eos(line, subtokenizer):
"""Encode line with subtokenizer, and add EOS id to the end."""
return subtokenizer.encode(line) + [tokenizer.EOS_ID]
def _trim_and_decode(ids, subtokenizer):
"""Trim EOS and PAD tokens from ids, and decode to return a string."""
try:
index = list(ids).index(tokenizer.EOS_ID)
return subtokenizer.decode(ids[:index])
except ValueError: # No EOS found in sequence
return subtokenizer.decode(ids)
def translate_file(
estimator, subtokenizer, input_file, output_file=None,
print_all_translations=True):
"""Translate lines in file, and save to output file if specified.
Args:
estimator: tf.Estimator used to generate the translations.
subtokenizer: Subtokenizer object for encoding and decoding source and
translated lines.
input_file: file containing lines to translate
output_file: file that stores the generated translations.
print_all_translations: If true, all translations are printed to stdout.
Raises:
ValueError: if output file is invalid.
"""
batch_size = _DECODE_BATCH_SIZE
# Read and sort inputs by length. Keep dictionary (original index-->new index
# in sorted list) to write translations in the original order.
sorted_inputs, sorted_keys = _get_sorted_inputs(input_file)
num_decode_batches = (len(sorted_inputs) - 1) // batch_size + 1
def input_generator():
"""Yield encoded strings from sorted_inputs."""
for i, line in enumerate(sorted_inputs):
if i % batch_size == 0:
batch_num = (i // batch_size) + 1
tf.logging.info("Decoding batch %d out of %d." %
(batch_num, num_decode_batches))
yield _encode_and_add_eos(line, subtokenizer)
def input_fn():
"""Created batched dataset of encoded inputs."""
ds = tf.data.Dataset.from_generator(
input_generator, tf.int64, tf.TensorShape([None]))
ds = ds.padded_batch(batch_size, [None])
return ds
translations = []
for i, prediction in enumerate(estimator.predict(input_fn)):
translation = _trim_and_decode(prediction["outputs"], subtokenizer)
translations.append(translation)
if print_all_translations:
tf.logging.info("Translating:\n\tInput: %s\n\tOutput: %s" %
(sorted_inputs[i], translation))
# Write translations in the order they appeared in the original file.
if output_file is not None:
if tf.gfile.IsDirectory(output_file):
raise ValueError("File output is a directory, will not save outputs to "
"file.")
tf.logging.info("Writing to file %s" % output_file)
with tf.gfile.Open(output_file, "w") as f:
for i in sorted_keys:
f.write("%s\n" % translations[i])
def translate_text(estimator, subtokenizer, txt):
"""Translate a single string."""
encoded_txt = _encode_and_add_eos(txt, subtokenizer)
def input_fn():
ds = tf.data.Dataset.from_tensors(encoded_txt)
ds = ds.batch(_DECODE_BATCH_SIZE)
return ds
predictions = estimator.predict(input_fn)
translation = next(predictions)["outputs"]
translation = _trim_and_decode(translation, subtokenizer)
tf.logging.info("Translation of \"%s\": \"%s\"" % (txt, translation))
def main(unused_argv):
from official.transformer import transformer_main
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.text is None and FLAGS.file is None:
tf.logging.warn("Nothing to translate. Make sure to call this script using "
"flags --text or --file.")
return
subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)
# Set up estimator and params
params = transformer_main.PARAMS_MAP[FLAGS.param_set]
params["beam_size"] = _BEAM_SIZE
params["alpha"] = _ALPHA
params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
params["batch_size"] = _DECODE_BATCH_SIZE
estimator = tf.estimator.Estimator(
model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
params=params)
if FLAGS.text is not None:
tf.logging.info("Translating text: %s" % FLAGS.text)
translate_text(estimator, subtokenizer, FLAGS.text)
if FLAGS.file is not None:
input_file = os.path.abspath(FLAGS.file)
tf.logging.info("Translating file: %s" % input_file)
if not tf.gfile.Exists(FLAGS.file):
raise ValueError("File does not exist: %s" % input_file)
output_file = None
if FLAGS.file_out is not None:
output_file = os.path.abspath(FLAGS.file_out)
tf.logging.info("File output specified: %s" % output_file)
translate_file(estimator, subtokenizer, input_file, output_file)
def define_translate_flags():
"""Define flags used for translation script."""
# Model flags
flags.DEFINE_string(
name="model_dir", short_name="md", default="/tmp/transformer_model",
help=flags_core.help_wrap(
"Directory containing Transformer model checkpoints."))
flags.DEFINE_enum(
name="param_set", short_name="mp", default="big",
enum_values=["base", "big"],
help=flags_core.help_wrap(
"Parameter set to use when creating and training the model. The "
"parameters define the input shape (batch size and max length), "
"model configuration (size of embedding, # of hidden layers, etc.), "
"and various other settings. The big parameter set increases the "
"default batch size, embedding/hidden size, and filter size. For a "
"complete list of parameters, please see model/model_params.py."))
flags.DEFINE_string(
name="vocab_file", short_name="vf", default=None,
help=flags_core.help_wrap(
"Path to subtoken vocabulary file. If data_download.py was used to "
"download and encode the training data, look in the data_dir to find "
"the vocab file."))
flags.mark_flag_as_required("vocab_file")
flags.DEFINE_string(
name="text", default=None,
help=flags_core.help_wrap(
"Text to translate. Output will be printed to console."))
flags.DEFINE_string(
name="file", default=None,
help=flags_core.help_wrap(
"File containing text to translate. Translation will be printed to "
"console and, if --file_out is provided, saved to an output file."))
flags.DEFINE_string(
name="file_out", default=None,
help=flags_core.help_wrap(
"If --file flag is specified, save translation to this file."))
if __name__ == "__main__":
define_translate_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input pipeline for the transformer model to read, filter, and batch examples.
Two things to note in the pipeline:
1. Batching scheme
The examples encoded in the TFRecord files contain data in the format:
{"inputs": [variable length array of integers],
"targets": [variable length array of integers]}
Where integers in the arrays refer to tokens in the English and German vocab
file (named `vocab.ende.32768`).
Prior to batching, elements in the dataset are grouped by length (max between
"inputs" and "targets" length). Each group is then batched such that:
group_batch_size * length <= batch_size.
Another way to view batch_size is the maximum number of tokens in each batch.
Once batched, each element in the dataset will have the shape:
{"inputs": [group_batch_size, padded_input_length],
"targets": [group_batch_size, padded_target_length]}
Lengths are padded to the longest "inputs" or "targets" sequence in the batch
(padded_input_length and padded_target_length can be different).
This batching scheme decreases the fraction of padding tokens per training
batch, thus improving the training speed significantly.
2. Shuffling
While training, the dataset is shuffled in two places in the code. The first
is the list of training files. Second, while reading records using
`parallel_interleave`, the `sloppy` argument is used to generate randomness
in the order of the examples.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import tensorflow as tf
from official.utils.misc import model_helpers
# Buffer size for reading records from a TFRecord file. Each training file is
# 7.2 MB, so 8 MB allows an entire file to be kept in memory.
_READ_RECORD_BUFFER = 8 * 1000 * 1000
# Example grouping constants. Defines length boundaries for each group.
# These values are the defaults used in Tensor2Tensor.
_MIN_BOUNDARY = 8
_BOUNDARY_SCALE = 1.1
def _load_records(filename):
"""Read file and return a dataset of tf.Examples."""
return tf.data.TFRecordDataset(filename, buffer_size=_READ_RECORD_BUFFER)
def _parse_example(serialized_example):
"""Return inputs and targets Tensors from a serialized tf.Example."""
data_fields = {
"inputs": tf.VarLenFeature(tf.int64),
"targets": tf.VarLenFeature(tf.int64)
}
parsed = tf.parse_single_example(serialized_example, data_fields)
inputs = tf.sparse_tensor_to_dense(parsed["inputs"])
targets = tf.sparse_tensor_to_dense(parsed["targets"])
return inputs, targets
def _filter_max_length(example, max_length=256):
"""Indicates whether the example's length is lower than the maximum length."""
return tf.logical_and(tf.size(example[0]) <= max_length,
tf.size(example[1]) <= max_length)
def _get_example_length(example):
"""Returns the maximum length between the example inputs and targets."""
length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
return length
def _create_min_max_boundaries(
max_length, min_boundary=_MIN_BOUNDARY, boundary_scale=_BOUNDARY_SCALE):
"""Create min and max boundary lists up to max_length.
For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
returned values will be:
buckets_min = [0, 4, 8, 16, 24]
buckets_max = [4, 8, 16, 24, 25]
Args:
max_length: The maximum length of example in dataset.
min_boundary: Minimum length in boundary.
boundary_scale: Amount to scale consecutive boundaries in the list.
Returns:
min and max boundary lists
"""
# Create bucket boundaries list by scaling the previous boundary or adding 1
# (to ensure increasing boundary sizes).
bucket_boundaries = []
x = min_boundary
while x < max_length:
bucket_boundaries.append(x)
x = max(x + 1, int(x * boundary_scale))
# Create min and max boundary lists from the initial list.
buckets_min = [0] + bucket_boundaries
buckets_max = bucket_boundaries + [max_length + 1]
return buckets_min, buckets_max
def _batch_examples(dataset, batch_size, max_length):
"""Group examples by similar lengths, and return batched dataset.
Each batch of similar-length examples are padded to the same length, and may
have different number of elements in each batch, such that:
group_batch_size * padded_length <= batch_size.
This decreases the number of padding tokens per batch, which improves the
training speed.
Args:
dataset: Dataset of unbatched examples.
batch_size: Max number of tokens per batch of examples.
max_length: Max number of tokens in an example input or target sequence.
Returns:
Dataset of batched examples with similar lengths.
"""
# Get min and max boundary lists for each example. These are used to calculate
# the `bucket_id`, which is the index at which:
# buckets_min[bucket_id] <= len(example) < buckets_max[bucket_id]
# Note that using both min and max lists improves the performance.
buckets_min, buckets_max = _create_min_max_boundaries(max_length)
# Create list of batch sizes for each bucket_id, so that
# bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
bucket_batch_sizes = [batch_size // x for x in buckets_max]
# bucket_id will be a tensor, so convert this list to a tensor as well.
bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
def example_to_bucket_id(example_input, example_target):
"""Return int64 bucket id for this example, calculated based on length."""
seq_length = _get_example_length((example_input, example_target))
# TODO: investigate whether removing code branching improves performance.
conditions_c = tf.logical_and(
tf.less_equal(buckets_min, seq_length),
tf.less(seq_length, buckets_max))
bucket_id = tf.reduce_min(tf.where(conditions_c))
return bucket_id
def window_size_fn(bucket_id):
"""Return number of examples to be grouped when given a bucket id."""
return bucket_batch_sizes[bucket_id]
def batching_fn(bucket_id, grouped_dataset):
"""Batch and add padding to a dataset of elements with similar lengths."""
bucket_batch_size = window_size_fn(bucket_id)
# Batch the dataset and add padding so that all input sequences in the
# examples have the same length, and all target sequences have the same
# lengths as well. Resulting lengths of inputs and targets can differ.
return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))
return dataset.apply(tf.contrib.data.group_by_window(
key_func=example_to_bucket_id,
reduce_func=batching_fn,
window_size=None,
window_size_func=window_size_fn))
def _read_and_batch_from_files(
file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat,
static_batch=False):
"""Create dataset where each item is a dict of "inputs" and "targets".
Args:
file_pattern: String used to match the input TFRecord files.
batch_size: Maximum number of tokens per batch of examples
max_length: Maximum number of tokens per example
num_parallel_calls: Number of cpu cores for parallel input processing.
shuffle: If true, randomizes order of elements.
repeat: Number of times to repeat the dataset. If None, the dataset is
repeated forever.
static_batch: Whether the batches in the dataset should have static shapes.
If True, the input is batched so that every batch has the
shape [batch_size // max_length, max_length]. If False, the input is
grouped by length, and batched so that batches may have different
shapes [N, M], where:
N * M <= batch_size
M <= max_length
In general, this setting should be False. Dynamic shapes allow the inputs
to be grouped so that the number of padding tokens is minimized, and helps
model training. In cases where the input shape must be static
(e.g. running on TPU), this setting should be set to True.
Returns:
tf.data.Dataset object containing examples loaded from the files.
"""
dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)
# Read files and interleave results. When training, the order of the examples
# will be non-deterministic.
dataset = dataset.apply(
tf.contrib.data.parallel_interleave(
_load_records, sloppy=shuffle, cycle_length=num_parallel_calls))
# Parse each tf.Example into a dictionary
# TODO: Look into prefetch_input_elements for performance optimization.
dataset = dataset.map(_parse_example,
num_parallel_calls=num_parallel_calls)
# Remove examples where the input or target length exceeds the maximum length,
dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
if static_batch:
dataset = dataset.apply(tf.contrib.data.padded_batch_and_drop_remainder(
batch_size // max_length, ([max_length], [max_length])))
else:
# Group and batch such that each batch has examples of similar length.
dataset = _batch_examples(dataset, batch_size, max_length)
dataset = dataset.repeat(repeat)
# Prefetch the next element to improve speed of input pipeline.
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return dataset
def _generate_synthetic_data(params):
"""Create synthetic data based on the parameter batch size."""
batch = length = int(math.sqrt(params["batch_size"]))
return model_helpers.generate_synthetic_data(
input_shape=tf.TensorShape([batch, length]),
input_value=1,
input_dtype=tf.int32,
label_shape=tf.TensorShape([batch, length]),
label_value=1,
label_dtype=tf.int32,
)
def train_input_fn(params):
"""Load and return dataset of batched examples for use during training."""
file_pattern = os.path.join(params["data_dir"] or "", "*train*")
if params["use_synthetic_data"]:
return _generate_synthetic_data(params)
return _read_and_batch_from_files(
file_pattern, params["batch_size"], params["max_length"],
params["num_parallel_calls"], shuffle=True,
repeat=params["repeat_dataset"], static_batch=params["static_batch"])
def eval_input_fn(params):
"""Load and return dataset of batched examples for use during evaluation."""
file_pattern = os.path.join(params["data_dir"] or "", "*dev*")
if params["use_synthetic_data"]:
return _generate_synthetic_data(params)
return _read_and_batch_from_files(
file_pattern, params["batch_size"], params["max_length"],
params["num_parallel_calls"], shuffle=False, repeat=1,
static_batch=params["static_batch"])
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for calculating loss, accuracy, and other model metrics.
Metrics:
- Padded loss, accuracy, and negative log perplexity. Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/metrics.py
- BLEU approximation. Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
- ROUGE score. Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/rouge.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import math
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
def _pad_tensors_to_same_length(x, y):
"""Pad x and y so that the results have the same length (second dimension)."""
with tf.name_scope("pad_to_same_length"):
x_length = tf.shape(x)[1]
y_length = tf.shape(y)[1]
max_length = tf.maximum(x_length, y_length)
x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
return x, y
def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
"""Calculate cross entropy loss while ignoring padding.
Args:
logits: Tensor of size [batch_size, length_logits, vocab_size]
labels: Tensor of size [batch_size, length_labels]
smoothing: Label smoothing constant, used to determine the on and off values
vocab_size: int size of the vocabulary
Returns:
Returns the cross entropy loss and weight tensors: float32 tensors with
shape [batch_size, max(length_logits, length_labels)]
"""
with tf.name_scope("loss", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
# Calculate smoothing cross entropy
with tf.name_scope("smoothing_cross_entropy", values=[logits, labels]):
confidence = 1.0 - smoothing
low_confidence = (1.0 - confidence) / tf.to_float(vocab_size - 1)
soft_targets = tf.one_hot(
tf.cast(labels, tf.int32),
depth=vocab_size,
on_value=confidence,
off_value=low_confidence)
xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
logits=logits, labels=soft_targets)
# Calculate the best (lowest) possible value of cross entropy, and
# subtract from the cross entropy loss.
normalizing_constant = -(
confidence * tf.log(confidence) + tf.to_float(vocab_size - 1) *
low_confidence * tf.log(low_confidence + 1e-20))
xentropy -= normalizing_constant
weights = tf.to_float(tf.not_equal(labels, 0))
return xentropy * weights, weights
def _convert_to_eval_metric(metric_fn):
"""Wrap a metric fn that returns scores and weights as an eval metric fn.
The input metric_fn returns values for the current batch. The wrapper
aggregates the return values collected over all of the batches evaluated.
Args:
metric_fn: function that returns scores and weights for the current batch's
logits and predicted labels.
Returns:
function that aggregates the scores and weights from metric_fn.
"""
def problem_metric_fn(*args):
"""Returns an aggregation of the metric_fn's returned values."""
(scores, weights) = metric_fn(*args)
# The tf.metrics.mean function assures correct aggregation.
return tf.metrics.mean(scores, weights)
return problem_metric_fn
def get_eval_metrics(logits, labels, params):
"""Return dictionary of model evaluation metrics."""
metrics = {
"accuracy": _convert_to_eval_metric(padded_accuracy)(logits, labels),
"accuracy_top5": _convert_to_eval_metric(padded_accuracy_top5)(
logits, labels),
"accuracy_per_sequence": _convert_to_eval_metric(
padded_sequence_accuracy)(logits, labels),
"neg_log_perplexity": _convert_to_eval_metric(padded_neg_log_perplexity)(
logits, labels, params["vocab_size"]),
}
if not params["use_tpu"]:
# TPU does not support tf.py_func
metrics.update({
"approx_bleu_score": _convert_to_eval_metric(
bleu_score)(logits, labels),
"rouge_2_fscore": _convert_to_eval_metric(
rouge_2_fscore)(logits, labels),
"rouge_L_fscore": _convert_to_eval_metric(
rouge_l_fscore)(logits, labels),
})
# Prefix each of the metric names with "metrics/". This allows the metric
# graphs to display under the "metrics" category in TensorBoard.
metrics = {"metrics/%s" % k: v for k, v in six.iteritems(metrics)}
return metrics
def padded_accuracy(logits, labels):
"""Percentage of times that predictions matches labels on non-0s."""
with tf.variable_scope("padded_accuracy", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
weights = tf.to_float(tf.not_equal(labels, 0))
outputs = tf.to_int32(tf.argmax(logits, axis=-1))
padded_labels = tf.to_int32(labels)
return tf.to_float(tf.equal(outputs, padded_labels)), weights
def padded_accuracy_topk(logits, labels, k):
"""Percentage of times that top-k predictions matches labels on non-0s."""
with tf.variable_scope("padded_accuracy_topk", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
weights = tf.to_float(tf.not_equal(labels, 0))
effective_k = tf.minimum(k, tf.shape(logits)[-1])
_, outputs = tf.nn.top_k(logits, k=effective_k)
outputs = tf.to_int32(outputs)
padded_labels = tf.to_int32(labels)
padded_labels = tf.expand_dims(padded_labels, axis=-1)
padded_labels += tf.zeros_like(outputs) # Pad to same shape.
same = tf.to_float(tf.equal(outputs, padded_labels))
same_topk = tf.reduce_sum(same, axis=-1)
return same_topk, weights
def padded_accuracy_top5(logits, labels):
return padded_accuracy_topk(logits, labels, 5)
def padded_sequence_accuracy(logits, labels):
"""Percentage of times that predictions matches labels everywhere (non-0)."""
with tf.variable_scope("padded_sequence_accuracy", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
weights = tf.to_float(tf.not_equal(labels, 0))
outputs = tf.to_int32(tf.argmax(logits, axis=-1))
padded_labels = tf.to_int32(labels)
not_correct = tf.to_float(tf.not_equal(outputs, padded_labels)) * weights
axis = list(range(1, len(outputs.get_shape())))
correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
return correct_seq, tf.constant(1.0)
def padded_neg_log_perplexity(logits, labels, vocab_size):
"""Average log-perplexity excluding padding 0s. No smoothing."""
num, den = padded_cross_entropy_loss(logits, labels, 0, vocab_size)
return -num, den
def bleu_score(logits, labels):
"""Approximate BLEU score computation between labels and predictions.
An approximate BLEU scoring method since we do not glue word pieces or
decode the ids and tokenize the output. By default, we use ngram order of 4
and use brevity penalty. Also, this does not have beam search.
Args:
logits: Tensor of size [batch_size, length_logits, vocab_size]
labels: Tensor of size [batch-size, length_labels]
Returns:
bleu: int, approx bleu score
"""
predictions = tf.to_int32(tf.argmax(logits, axis=-1))
# TODO: Look into removing use of py_func
bleu = tf.py_func(compute_bleu, (labels, predictions), tf.float32)
return bleu, tf.constant(1.0)
def _get_ngrams_with_counter(segment, max_order):
"""Extracts all n-grams up to a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
Returns:
The Counter containing all n-grams upto max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in xrange(1, max_order + 1):
for i in xrange(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus, translation_corpus, max_order=4,
use_bp=True):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
use_bp: boolean, whether to apply brevity penalty.
Returns:
BLEU score.
"""
reference_length = 0
translation_length = 0
bp = 1.0
geo_mean = 0
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
precisions = []
for (references, translations) in zip(reference_corpus, translation_corpus):
reference_length += len(references)
translation_length += len(translations)
ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)
overlap = dict((ngram,
min(count, translation_ngram_counts[ngram]))
for ngram, count in ref_ngram_counts.items())
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for ngram in translation_ngram_counts:
possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
ngram]
precisions = [0] * max_order
smooth = 1.0
for i in xrange(0, max_order):
if possible_matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
if matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[
i]
else:
smooth *= 2
precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
else:
precisions[i] = 0.0
if max(precisions) > 0:
p_log_sum = sum(math.log(p) for p in precisions if p)
geo_mean = math.exp(p_log_sum / max_order)
if use_bp:
ratio = translation_length / reference_length
bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
bleu = geo_mean * bp
return np.float32(bleu)
def rouge_2_fscore(logits, labels):
"""ROUGE-2 F1 score computation between labels and predictions.
This is an approximate ROUGE scoring method since we do not glue word pieces
or decode the ids and tokenize the output.
Args:
logits: tensor, model predictions
labels: tensor, gold output.
Returns:
rouge2_fscore: approx rouge-2 f1 score.
"""
predictions = tf.to_int32(tf.argmax(logits, axis=-1))
# TODO: Look into removing use of py_func
rouge_2_f_score = tf.py_func(rouge_n, (predictions, labels), tf.float32)
return rouge_2_f_score, tf.constant(1.0)
def _get_ngrams(n, text):
"""Calculates n-grams.
Args:
n: which n-grams to calculate
text: An array of tokens
Returns:
A set of n-grams
"""
ngram_set = set()
text_length = len(text)
max_index_ngram_start = text_length - n
for i in range(max_index_ngram_start + 1):
ngram_set.add(tuple(text[i:i + n]))
return ngram_set
def rouge_n(eval_sentences, ref_sentences, n=2):
"""Computes ROUGE-N f1 score of two text collections of sentences.
Source: https://www.microsoft.com/en-us/research/publication/
rouge-a-package-for-automatic-evaluation-of-summaries/
Args:
eval_sentences: Predicted sentences.
ref_sentences: Sentences from the reference set
n: Size of ngram. Defaults to 2.
Returns:
f1 score for ROUGE-N
"""
f1_scores = []
for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
eval_ngrams = _get_ngrams(n, eval_sentence)
ref_ngrams = _get_ngrams(n, ref_sentence)
ref_count = len(ref_ngrams)
eval_count = len(eval_ngrams)
# Count the overlapping ngrams between evaluated and reference
overlapping_ngrams = eval_ngrams.intersection(ref_ngrams)
overlapping_count = len(overlapping_ngrams)
# Handle edge case. This isn't mathematically correct, but it's good enough
if eval_count == 0:
precision = 0.0
else:
precision = float(overlapping_count) / eval_count
if ref_count == 0:
recall = 0.0
else:
recall = float(overlapping_count) / ref_count
f1_scores.append(2.0 * ((precision * recall) / (precision + recall + 1e-8)))
# return overlapping_count / reference_count
return np.mean(f1_scores, dtype=np.float32)
def rouge_l_fscore(predictions, labels):
"""ROUGE scores computation between labels and predictions.
This is an approximate ROUGE scoring method since we do not glue word pieces
or decode the ids and tokenize the output.
Args:
predictions: tensor, model predictions
labels: tensor, gold output.
Returns:
rouge_l_fscore: approx rouge-l f1 score.
"""
outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
rouge_l_f_score = tf.py_func(rouge_l_sentence_level, (outputs, labels),
tf.float32)
return rouge_l_f_score, tf.constant(1.0)
def rouge_l_sentence_level(eval_sentences, ref_sentences):
"""Computes ROUGE-L (sentence level) of two collections of sentences.
Source: https://www.microsoft.com/en-us/research/publication/
rouge-a-package-for-automatic-evaluation-of-summaries/
Calculated according to:
R_lcs = LCS(X,Y)/m
P_lcs = LCS(X,Y)/n
F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
where:
X = reference summary
Y = Candidate summary
m = length of reference summary
n = length of candidate summary
Args:
eval_sentences: The sentences that have been picked by the summarizer
ref_sentences: The sentences from the reference set
Returns:
A float: F_lcs
"""
f1_scores = []
for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
m = float(len(ref_sentence))
n = float(len(eval_sentence))
lcs = _len_lcs(eval_sentence, ref_sentence)
f1_scores.append(_f_lcs(lcs, m, n))
return np.mean(f1_scores, dtype=np.float32)
def _len_lcs(x, y):
"""Returns the length of the Longest Common Subsequence between two seqs.
Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
Args:
x: sequence of words
y: sequence of words
Returns
integer: Length of LCS between x and y
"""
table = _lcs(x, y)
n, m = len(x), len(y)
return table[n, m]
def _lcs(x, y):
"""Computes the length of the LCS between two seqs.
The implementation below uses a DP programming algorithm and runs
in O(nm) time where n = len(x) and m = len(y).
Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
Args:
x: collection of words
y: collection of words
Returns:
Table of dictionary of coord and len lcs
"""
n, m = len(x), len(y)
table = dict()
for i in range(n + 1):
for j in range(m + 1):
if i == 0 or j == 0:
table[i, j] = 0
elif x[i - 1] == y[j - 1]:
table[i, j] = table[i - 1, j - 1] + 1
else:
table[i, j] = max(table[i - 1, j], table[i, j - 1])
return table
def _f_lcs(llcs, m, n):
"""Computes the LCS-based F-measure score.
Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/
rouge-working-note-v1.3.1.pdf
Args:
llcs: Length of LCS
m: number of words in reference summary
n: number of words in candidate summary
Returns:
Float. LCS-based F-measure score
"""
r_lcs = llcs / m
p_lcs = llcs / n
beta = p_lcs / (r_lcs + 1e-12)
num = (1 + (beta ** 2)) * r_lcs * p_lcs
denom = r_lcs + ((beta ** 2) * p_lcs)
f_lcs = num / (denom + 1e-12)
return f_lcs
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstract training on a step or epoch basis."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
_TRAIN, _EVAL = tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL
NUM_EXAMPLES = {
tf.estimator.ModeKeys.TRAIN: 4572160,
# # Examples that are too long are filtered out, thus the total is less
# # than the total number of lines.
# 2399123 + # news-commentary-v12.de-en
# 1920209 + # commoncrawl.de-en
# 270769, # europarl-v7.de-en
tf.estimator.ModeKeys.EVAL: 3000, # newstest2013
}
class Manager(object):
"""Container for convenience functions to abstract step or epoch basis.
Transformer allows users to specify an epoch basis (generally recommended for
full training) or a number of steps basis (convenient since epochs are rather
large). TPUs furthermore require a step basis; however epochs are the norm in
the machine learning community and it is desirable to allow users to specify
epochs even when running with TPUS which requires behind the scenes
conversions.
This container simply groups what are largely mundane checks and conversions
rather than interspersing them throughout the run loop code.
"""
def __init__(self, train_steps, steps_between_evals, train_epochs,
epochs_between_evals, default_train_epochs, batch_size,
max_length, use_tpu=False, num_tpu_shards=8):
if train_steps and train_epochs:
raise ValueError("Both train_steps or train_epochs were be defined.")
# Determine training schedule based on flags.
if train_steps:
self.train_eval_iterations = train_steps // steps_between_evals
self._single_iteration_train_steps = steps_between_evals
self._single_iteration_train_epochs = None
else:
train_epochs = train_epochs or default_train_epochs
self.train_eval_iterations = train_epochs // epochs_between_evals
self._single_iteration_train_steps = None
self._single_iteration_train_epochs = epochs_between_evals
self.max_length = max_length
self.batch_size = batch_size
self.use_tpu = use_tpu
self.num_tpu_shards = num_tpu_shards
if self.use_tpu:
assert (self.batch_size // self.max_length) % self.num_tpu_shards == 0
@property
def single_iteration_train_steps(self):
if self._single_iteration_train_steps or not self.use_tpu:
return self._single_iteration_train_steps
return self.epochs_to_steps(
num_epochs=self._single_iteration_train_epochs, mode=_TRAIN)
@property
def single_iteration_eval_steps(self):
if not self.use_tpu:
return None
return self.epochs_to_steps(num_epochs=1, mode=_EVAL)
@property
def train_increment_str(self):
if self._single_iteration_train_steps:
return "{} steps.".format(self._single_iteration_train_steps)
if not self.use_tpu:
return "{} epochs.".format(self._single_iteration_train_epochs)
return "~{} epochs. ({} steps)".format(
self._single_iteration_train_epochs,
self.single_iteration_train_steps)
@property
def repeat_dataset(self):
if (self._single_iteration_train_epochs is None and
self._single_iteration_train_steps > NUM_EXAMPLES[_TRAIN]):
return math.ceil(self._single_iteration_train_steps /
NUM_EXAMPLES[_TRAIN])
return self._single_iteration_train_epochs
def epochs_to_steps(self, num_epochs, mode):
"""Converts a number of epochs to a number of training steps.
TPU only: This function assumes that static_batch is True.
TPU can not tolerate an OutOfRange error from a dataset. As a result the
number of examples to be processed must be known ahead of time. TPUs also
do not allow partial batches, so this function rounds down.
Args:
num_epochs: An integer of the number of epochs to convert to steps.
mode: The estimator ModeKey of the computation
Returns:
An integer of the number of equivalent steps rounded down.
"""
assert self.use_tpu, "epochs_to_steps should only be reached when using TPU"
total_num_tokens = NUM_EXAMPLES[mode] * self.max_length * num_epochs
return total_num_tokens // self.batch_size
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Transformer's schedule manager."""
import tensorflow as tf
from official.transformer.utils import schedule
class ScheduleBaseTester(tf.test.TestCase):
def test_mutual_exclusivity(self):
with self.assertRaises(ValueError):
schedule.Manager(
train_steps=100, steps_between_evals=100, train_epochs=2,
epochs_between_evals=1, default_train_epochs=None, batch_size=2048,
max_length=256)
def test_step_basis(self):
manager = schedule.Manager(
train_steps=1000, steps_between_evals=100, train_epochs=None,
epochs_between_evals=None, default_train_epochs=None, batch_size=2048,
max_length=256)
self.assertEqual(manager.single_iteration_train_steps, 100)
# Evaluation uses the full set
self.assertIsNone(manager.single_iteration_eval_steps)
self.assertIsNone(manager.repeat_dataset)
def test_epoch_basis(self):
manager = schedule.Manager(
train_steps=None, steps_between_evals=None, train_epochs=10,
epochs_between_evals=2, default_train_epochs=None, batch_size=2048,
max_length=256)
# For non-TPU, estimator relies on dataset exhausion
self.assertIsNone(manager.single_iteration_train_steps)
self.assertIsNone(manager.single_iteration_eval_steps)
self.assertEqual(manager.repeat_dataset, 2)
def test_step_basis_tpu(self):
manager = schedule.Manager(
train_steps=1000, steps_between_evals=100, train_epochs=None,
epochs_between_evals=None, default_train_epochs=None, batch_size=2048,
max_length=256, use_tpu=True)
self.assertEqual(manager.single_iteration_train_steps, 100)
# num_eval_examples / (batch_size / max_length) == 3000 / (2048 / 256)
self.assertEqual(manager.single_iteration_eval_steps, 375)
self.assertIsNone(manager.repeat_dataset)
def test_epoch_basis_tpu(self):
manager = schedule.Manager(
train_steps=None, steps_between_evals=None, train_epochs=10,
epochs_between_evals=2, default_train_epochs=None, batch_size=2048,
max_length=256, use_tpu=True)
self.assertEqual(
manager.single_iteration_train_steps,
schedule.NUM_EXAMPLES[tf.estimator.ModeKeys.TRAIN] * 2 // (2048 / 256)
)
# num_eval_examples / (batch_size / max_length) == 3000 / (2048 / 256)
self.assertEqual(manager.single_iteration_eval_steps, 375)
self.assertEqual(manager.repeat_dataset, 2)
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines Subtokenizer class to encode and decode strings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import sys
import unicodedata
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
PAD = "<pad>"
PAD_ID = 0
EOS = "<EOS>"
EOS_ID = 1
RESERVED_TOKENS = [PAD, EOS]
# Set of characters that will be used in the function _escape_token() (see func
# docstring for more details).
# This set is added to the alphabet list to ensure that all escaped tokens can
# be encoded.
_ESCAPE_CHARS = set(u"\\_u;0123456789")
# Regex for the function _unescape_token(), the inverse of _escape_token().
# This is used to find "\u", "\\", and "\###;" substrings in the token.
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
_UNDEFINED_UNICODE = u"\u3013"
# Set contains all letter and number characters.
_ALPHANUMERIC_CHAR_SET = set(
six.unichr(i) for i in xrange(sys.maxunicode)
if (unicodedata.category(six.unichr(i)).startswith("L") or
unicodedata.category(six.unichr(i)).startswith("N")))
# min_count is the minimum number of times a subtoken must appear in the data
# before before it is added to the vocabulary. The value is found using binary
# search to obtain the target vocabulary size.
_MIN_MIN_COUNT = 1 # min value to use when binary searching for min_count
_MAX_MIN_COUNT = 1000 # max value to use when binary searching for min_count
class Subtokenizer(object):
"""Encodes and decodes strings to/from integer IDs."""
def __init__(self, vocab_file, reserved_tokens=None):
"""Initializes class, creating a vocab file if data_files is provided."""
tf.logging.info("Initializing Subtokenizer from file %s." % vocab_file)
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
self.subtoken_list = _load_vocab_file(vocab_file, reserved_tokens)
self.alphabet = _generate_alphabet_dict(self.subtoken_list)
self.subtoken_to_id_dict = _list_to_index_dict(self.subtoken_list)
self.max_subtoken_length = 0
for subtoken in self.subtoken_list:
self.max_subtoken_length = max(self.max_subtoken_length, len(subtoken))
# Create cache to speed up subtokenization
self._cache_size = 2 ** 20
self._cache = [(None, None)] * self._cache_size
@staticmethod
def init_from_files(
vocab_file, files, target_vocab_size, threshold, min_count=None,
file_byte_limit=1e6, reserved_tokens=None):
"""Create subtoken vocabulary based on files, and save vocab to file.
Args:
vocab_file: String name of vocab file to store subtoken vocabulary.
files: List of file paths that will be used to generate vocabulary.
target_vocab_size: target vocabulary size to generate.
threshold: int threshold of vocabulary size to accept.
min_count: int minimum count to use for generating the vocabulary. The min
count is the minimum number of times a subtoken should appear in the
files before it is added to the vocabulary. If set to none, this value
is found using binary search.
file_byte_limit: (Default 1e6) Maximum number of bytes of sample text that
will be drawn from the files.
reserved_tokens: List of string tokens that are guaranteed to be at the
beginning of the subtoken vocabulary list.
Returns:
Subtokenizer object
"""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
if tf.gfile.Exists(vocab_file):
tf.logging.info("Vocab file already exists (%s)" % vocab_file)
else:
tf.logging.info("Begin steps to create subtoken vocabulary...")
token_counts = _count_tokens(files, file_byte_limit)
alphabet = _generate_alphabet_dict(token_counts)
subtoken_list = _generate_subtokens_with_target_vocab_size(
token_counts, alphabet, target_vocab_size, threshold, min_count,
reserved_tokens)
tf.logging.info("Generated vocabulary with %d subtokens." %
len(subtoken_list))
_save_vocab_file(vocab_file, subtoken_list)
return Subtokenizer(vocab_file)
def encode(self, raw_string, add_eos=False):
"""Encodes a string into a list of int subtoken ids."""
ret = []
tokens = _split_string_to_tokens(_native_to_unicode(raw_string))
for token in tokens:
ret.extend(self._token_to_subtoken_ids(token))
if add_eos:
ret.append(EOS_ID)
return ret
def _token_to_subtoken_ids(self, token):
"""Encode a single token into a list of subtoken ids."""
cache_location = hash(token) % self._cache_size
cache_key, cache_value = self._cache[cache_location]
if cache_key == token:
return cache_value
ret = _split_token_to_subtokens(
_escape_token(token, self.alphabet), self.subtoken_to_id_dict,
self.max_subtoken_length)
ret = [self.subtoken_to_id_dict[subtoken_id] for subtoken_id in ret]
self._cache[cache_location] = (token, ret)
return ret
def decode(self, subtokens):
"""Converts list of int subtokens ids into a string."""
if isinstance(subtokens, np.ndarray):
# Note that list(subtokens) converts subtokens to a python list, but the
# items remain as np.int32. This converts both the array and its items.
subtokens = subtokens.tolist()
if not subtokens:
return ""
assert isinstance(subtokens, list) and isinstance(subtokens[0], int), (
"Subtokens argument passed into decode() must be a list of integers.")
return _unicode_to_native(
_join_tokens_to_string(self._subtoken_ids_to_tokens(subtokens)))
def _subtoken_ids_to_tokens(self, subtokens):
"""Convert list of int subtoken ids to a list of string tokens."""
escaped_tokens = "".join([
self.subtoken_list[s] for s in subtokens
if s < len(self.subtoken_list)])
escaped_tokens = escaped_tokens.split("_")
# All tokens in the vocabulary list have been escaped (see _escape_token())
# so each token must be unescaped when decoding.
ret = []
for token in escaped_tokens:
if token:
ret.append(_unescape_token(token))
return ret
def _save_vocab_file(vocab_file, subtoken_list):
"""Save subtokens to file."""
with tf.gfile.Open(vocab_file, mode="w") as f:
for subtoken in subtoken_list:
f.write("'%s'\n" % _unicode_to_native(subtoken))
def _load_vocab_file(vocab_file, reserved_tokens=None):
"""Load vocabulary while ensuring reserved tokens are at the top."""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
subtoken_list = []
with tf.gfile.Open(vocab_file, mode="r") as f:
for line in f:
subtoken = _native_to_unicode(line.strip())
subtoken = subtoken[1:-1] # Remove surrounding single-quotes
if subtoken in reserved_tokens:
continue
subtoken_list.append(_native_to_unicode(subtoken))
return reserved_tokens + subtoken_list
def _native_to_unicode(s):
"""Convert string to unicode (required in Python 2)."""
try: # Python 2
return s if isinstance(s, unicode) else s.decode("utf-8")
except NameError: # Python 3
return s
def _unicode_to_native(s):
"""Convert string from unicode to native format (required in Python 2)."""
try: # Python 2
return s.encode("utf-8") if isinstance(s, unicode) else s
except NameError: # Python 3
return s
def _split_string_to_tokens(text):
"""Splits text to a list of string tokens."""
if not text:
return []
ret = []
token_start = 0
# Classify each character in the input string
is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text]
for pos in xrange(1, len(text)):
if is_alnum[pos] != is_alnum[pos - 1]:
token = text[token_start:pos]
if token != u" " or token_start == 0:
ret.append(token)
token_start = pos
final_token = text[token_start:]
ret.append(final_token)
return ret
def _join_tokens_to_string(tokens):
"""Join a list of string tokens into a single string."""
token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
ret = []
for i, token in enumerate(tokens):
if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
ret.append(u" ")
ret.append(token)
return "".join(ret)
def _escape_token(token, alphabet):
r"""Replace characters that aren't in the alphabet and append "_" to token.
Apply three transformations to the token:
1. Replace underline character "_" with "\u", and backslash "\" with "\\".
2. Replace characters outside of the alphabet with "\###;", where ### is the
character's Unicode code point.
3. Appends "_" to mark the end of a token.
Args:
token: unicode string to be escaped
alphabet: list of all known characters
Returns:
escaped string
"""
token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
return u"".join(ret) + "_"
def _unescape_token(token):
r"""Replaces escaped characters in the token with their unescaped versions.
Applies inverse transformations as _escape_token():
1. Replace "\u" with "_", and "\\" with "\".
2. Replace "\###;" with the unicode character the ### refers to.
Args:
token: escaped string
Returns:
unescaped string
"""
def match(m):
r"""Returns replacement string for matched object.
Matched objects contain one of the strings that matches the regex pattern:
r"\\u|\\\\|\\([0-9]+);"
The strings can be '\u', '\\', or '\###;' (### is any digit number).
m.group(0) refers to the entire matched string ('\u', '\\', or '\###;').
m.group(1) refers to the first parenthesized subgroup ('###').
m.group(0) exists for all match objects, while m.group(1) exists only for
the string '\###;'.
This function looks to see if m.group(1) exists. If it doesn't, then the
matched string must be '\u' or '\\' . In this case, the corresponding
replacement ('_' and '\') are returned. Note that in python, a single
backslash is written as '\\', and double backslash as '\\\\'.
If m.goup(1) exists, then use the integer in m.group(1) to return a
unicode character.
Args:
m: match object
Returns:
String to replace matched object with.
"""
# Check if the matched strings are '\u' or '\\'.
if m.group(1) is None:
return u"_" if m.group(0) == u"\\u" else u"\\"
# If m.group(1) exists, try and return unicode character.
try:
return six.unichr(int(m.group(1)))
except (ValueError, OverflowError) as _:
return _UNDEFINED_UNICODE
# Use match function to replace escaped substrings in the token.
return _UNESCAPE_REGEX.sub(match, token)
def _count_tokens(files, file_byte_limit=1e6):
"""Return token counts of words in the files.
Samples file_byte_limit bytes from each file, and counts the words that appear
in the samples. The samples are semi-evenly distributed across the file.
Args:
files: List of filepaths
file_byte_limit: Max number of bytes that will be read from each file.
Returns:
Dictionary mapping tokens to the number of times they appear in the sampled
lines from the files.
"""
token_counts = collections.defaultdict(int)
for filepath in files:
with tf.gfile.Open(filepath, mode="r") as reader:
file_byte_budget = file_byte_limit
counter = 0
lines_to_skip = int(reader.size() / (file_byte_budget * 2))
for line in reader:
if counter < lines_to_skip:
counter += 1
else:
if file_byte_budget < 0:
break
line = line.strip()
file_byte_budget -= len(line)
counter = 0
# Add words to token counts
for token in _split_string_to_tokens(_native_to_unicode(line)):
token_counts[token] += 1
return token_counts
def _list_to_index_dict(lst):
"""Create dictionary mapping list items to their indices in the list."""
return {item: n for n, item in enumerate(lst)}
def _split_token_to_subtokens(token, subtoken_dict, max_subtoken_length):
"""Splits a token into subtokens defined in the subtoken dict."""
ret = []
start = 0
token_len = len(token)
while start < token_len:
# Find the longest subtoken, so iterate backwards.
for end in xrange(min(token_len, start + max_subtoken_length), start, -1):
subtoken = token[start:end]
if subtoken in subtoken_dict:
ret.append(subtoken)
start = end
break
else: # Did not break
# If there is no possible encoding of the escaped token then one of the
# characters in the token is not in the alphabet. This should be
# impossible and would be indicative of a bug.
raise ValueError("Was unable to split token \"%s\" into subtokens." %
token)
return ret
def _generate_subtokens_with_target_vocab_size(
token_counts, alphabet, target_size, threshold, min_count=None,
reserved_tokens=None):
"""Generate subtoken vocabulary close to the target size."""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
if min_count is not None:
tf.logging.info("Using min_count=%d to generate vocab with target size %d" %
(min_count, target_size))
return _generate_subtokens(
token_counts, alphabet, min_count, reserved_tokens=reserved_tokens)
def bisect(min_val, max_val):
"""Recursive function to binary search for subtoken vocabulary."""
cur_count = (min_val + max_val) // 2
tf.logging.info("Binary search: trying min_count=%d (%d %d)" %
(cur_count, min_val, max_val))
subtoken_list = _generate_subtokens(
token_counts, alphabet, cur_count, reserved_tokens=reserved_tokens)
val = len(subtoken_list)
tf.logging.info("Binary search: min_count=%d resulted in %d tokens" %
(cur_count, val))
within_threshold = abs(val - target_size) < threshold
if within_threshold or min_val >= max_val or cur_count < 2:
return subtoken_list
if val > target_size:
other_subtoken_list = bisect(cur_count + 1, max_val)
else:
other_subtoken_list = bisect(min_val, cur_count - 1)
# Return vocabulary dictionary with the closest number of tokens.
other_val = len(other_subtoken_list)
if abs(other_val - target_size) < abs(val - target_size):
return other_subtoken_list
return subtoken_list
tf.logging.info("Finding best min_count to get target size of %d" %
target_size)
return bisect(_MIN_MIN_COUNT, _MAX_MIN_COUNT)
def _generate_alphabet_dict(iterable, reserved_tokens=None):
"""Create set of characters that appear in any element in the iterable."""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
alphabet = {c for token in iterable for c in token}
alphabet |= {c for token in reserved_tokens for c in token}
alphabet |= _ESCAPE_CHARS # Add escape characters to alphabet set.
return alphabet
def _count_and_gen_subtokens(
token_counts, alphabet, subtoken_dict, max_subtoken_length):
"""Count number of times subtokens appear, and generate new subtokens.
Args:
token_counts: dict mapping tokens to the number of times they appear in the
original files.
alphabet: list of allowed characters. Used to escape the tokens, which
guarantees that all tokens can be split into subtokens.
subtoken_dict: dict mapping subtokens to ids.
max_subtoken_length: maximum length of subtoken in subtoken_dict.
Returns:
A defaultdict mapping subtokens to the number of times they appear in the
tokens. The dict may contain new subtokens.
"""
subtoken_counts = collections.defaultdict(int)
for token, count in six.iteritems(token_counts):
token = _escape_token(token, alphabet)
subtokens = _split_token_to_subtokens(
token, subtoken_dict, max_subtoken_length)
# Generate new subtokens by taking substrings from token.
start = 0
for subtoken in subtokens:
for end in xrange(start + 1, len(token) + 1):
new_subtoken = token[start:end]
subtoken_counts[new_subtoken] += count
start += len(subtoken)
return subtoken_counts
def _filter_and_bucket_subtokens(subtoken_counts, min_count):
"""Return a bucketed list of subtokens that are filtered by count.
Args:
subtoken_counts: defaultdict mapping subtokens to their counts
min_count: int count used to filter subtokens
Returns:
List of subtoken sets, where subtokens in set i have the same length=i.
"""
# Create list of buckets, where subtokens in bucket i have length i.
subtoken_buckets = []
for subtoken, count in six.iteritems(subtoken_counts):
if count < min_count: # Filter out subtokens that don't appear enough
continue
while len(subtoken_buckets) <= len(subtoken):
subtoken_buckets.append(set())
subtoken_buckets[len(subtoken)].add(subtoken)
return subtoken_buckets
def _gen_new_subtoken_list(
subtoken_counts, min_count, alphabet, reserved_tokens=None):
"""Generate candidate subtokens ordered by count, and new max subtoken length.
Add subtokens to the candiate list in order of length (longest subtokens
first). When a subtoken is added, the counts of each of its prefixes are
decreased. Prefixes that don't appear much outside the subtoken are not added
to the candidate list.
For example:
subtoken being added to candidate list: 'translate'
subtoken_counts: {'translate':10, 't':40, 'tr':16, 'tra':12, ...}
min_count: 5
When 'translate' is added, subtoken_counts is updated to:
{'translate':0, 't':30, 'tr':6, 'tra': 2, ...}
The subtoken 'tra' will not be added to the candidate list, because it appears
twice (less than min_count) outside of 'translate'.
Args:
subtoken_counts: defaultdict mapping str subtokens to int counts
min_count: int minumum count requirement for subtokens
alphabet: set of characters. Each character is added to the subtoken list to
guarantee that all tokens can be encoded.
reserved_tokens: list of tokens that will be added to the beginning of the
returned subtoken list.
Returns:
List of candidate subtokens in decreasing count order, and maximum subtoken
length
"""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
# Create a list of (count, subtoken) for each candidate subtoken.
subtoken_candidates = []
# Use bucketted list to iterate through subtokens in order of length.
# subtoken_buckets[i] = set(subtokens), where each subtoken has length i.
subtoken_buckets = _filter_and_bucket_subtokens(subtoken_counts, min_count)
max_subtoken_length = len(subtoken_buckets) - 1
# Go through the list in reverse order to consider longer subtokens first.
for subtoken_len in xrange(max_subtoken_length, 0, -1):
for subtoken in subtoken_buckets[subtoken_len]:
count = subtoken_counts[subtoken]
# Possible if this subtoken is a prefix of another token.
if count < min_count:
continue
# Ignore alphabet/reserved tokens, which will be added manually later.
if subtoken not in alphabet and subtoken not in reserved_tokens:
subtoken_candidates.append((count, subtoken))
# Decrement count of the subtoken's prefixes (if a longer subtoken is
# added, its prefixes lose priority to be added).
for end in xrange(1, subtoken_len):
subtoken_counts[subtoken[:end]] -= count
# Add alphabet subtokens (guarantees that all strings are encodable).
subtoken_candidates.extend((subtoken_counts.get(a, 0), a) for a in alphabet)
# Order subtoken candidates by decreasing count.
subtoken_list = [t for _, t in sorted(subtoken_candidates, reverse=True)]
# Add reserved tokens to beginning of the list.
subtoken_list = reserved_tokens + subtoken_list
return subtoken_list, max_subtoken_length
def _generate_subtokens(
token_counts, alphabet, min_count, num_iterations=4,
reserved_tokens=None):
"""Create a list of subtokens in decreasing order of frequency.
Args:
token_counts: dict mapping str tokens -> int count
alphabet: set of characters
min_count: int minimum number of times a subtoken must appear before it is
added to the vocabulary.
num_iterations: int number of iterations to generate new tokens.
reserved_tokens: list of tokens that will be added to the beginning to the
returned subtoken list.
Returns:
Sorted list of subtokens (most frequent first)
"""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
# Use alphabet set to create initial list of subtokens
subtoken_list = reserved_tokens + list(alphabet)
max_subtoken_length = 1
# On each iteration, segment all words using the subtokens defined in
# subtoken_dict, count how often the resulting subtokens appear, and update
# the dictionary with subtokens w/ high enough counts.
for i in xrange(num_iterations):
tf.logging.info("\tGenerating subtokens: iteration %d" % i)
# Generate new subtoken->id dictionary using the new subtoken list.
subtoken_dict = _list_to_index_dict(subtoken_list)
# Create dict mapping subtoken->count, with additional subtokens created
# from substrings taken from the tokens.
subtoken_counts = _count_and_gen_subtokens(
token_counts, alphabet, subtoken_dict, max_subtoken_length)
# Generate new list of subtokens sorted by subtoken count.
subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
subtoken_counts, min_count, alphabet, reserved_tokens)
tf.logging.info("\tVocab size: %d" % len(subtoken_list))
return subtoken_list
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Subtokenizer and string helper methods."""
import collections
import tempfile
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.utils import tokenizer
class SubtokenizerTest(tf.test.TestCase):
def _init_subtokenizer(self, vocab_list):
temp_file = tempfile.NamedTemporaryFile(delete=False)
with tf.gfile.Open(temp_file.name, 'w') as w:
for subtoken in vocab_list:
w.write("'%s'" % subtoken)
w.write("\n")
return tokenizer.Subtokenizer(temp_file.name, reserved_tokens=[])
def test_encode(self):
vocab_list = ["123_", "test", "ing_"]
subtokenizer = self._init_subtokenizer(vocab_list)
s = "testing 123"
encoded_list = subtokenizer.encode(s)
self.assertEqual([1, 2, 0], encoded_list)
def test_decode(self):
vocab_list = ["123_", "test", "ing_"]
subtokenizer = self._init_subtokenizer(vocab_list)
encoded_list = [1, 2, 0] # testing 123
decoded_str = subtokenizer.decode(encoded_list)
self.assertEqual("testing 123", decoded_str)
def test_subtoken_ids_to_tokens(self):
vocab_list = ["123_", "test", "ing_"]
subtokenizer = self._init_subtokenizer(vocab_list)
encoded_list = [1, 2, 0] # testing 123
token_list = subtokenizer._subtoken_ids_to_tokens(encoded_list)
self.assertEqual([u"testing", u"123"], token_list)
class StringHelperTest(tf.test.TestCase):
def test_split_string_to_tokens(self):
text = "test? testing 123."
tokens = tokenizer._split_string_to_tokens(text)
self.assertEqual(["test", "? ", "testing", "123", "."], tokens)
def test_join_tokens_to_string(self):
tokens = ["test", "? ", "testing", "123", "."]
s = tokenizer._join_tokens_to_string(tokens)
self.assertEqual("test? testing 123.", s)
def test_escape_token(self):
token = u"abc_\\4"
alphabet = set("abc_\\u;")
escaped_token = tokenizer._escape_token(token, alphabet)
self.assertEqual("abc\\u\\\\\\52;_", escaped_token)
def test_unescape_token(self):
escaped_token = u"Underline: \\u, Backslash: \\\\, Unicode: \\52;"
unescaped_token = tokenizer._unescape_token(escaped_token)
self.assertEqual(
"Underline: _, Backslash: \\, Unicode: 4", unescaped_token)
def test_list_to_index_dict(self):
lst = ["test", "strings"]
d = tokenizer._list_to_index_dict(lst)
self.assertDictEqual({"test": 0, "strings": 1}, d)
def test_split_token_to_subtokens(self):
token = "abc"
subtoken_dict = {"a": 0, "b": 1, "c": 2, "ab": 3}
max_subtoken_length = 2
subtokens = tokenizer._split_token_to_subtokens(
token, subtoken_dict, max_subtoken_length)
self.assertEqual(["ab", "c"], subtokens)
def test_generate_alphabet_dict(self):
s = ["testing", "123"]
reserved_tokens = ["???"]
alphabet = tokenizer._generate_alphabet_dict(s, reserved_tokens)
self.assertIn("?", alphabet)
self.assertIn("t", alphabet)
self.assertIn("e", alphabet)
self.assertIn("s", alphabet)
self.assertIn("i", alphabet)
self.assertIn("n", alphabet)
self.assertIn("g", alphabet)
self.assertIn("1", alphabet)
self.assertIn("2", alphabet)
self.assertIn("3", alphabet)
def test_count_and_gen_subtokens(self):
token_counts = {"abc": 5}
alphabet = set("abc_")
subtoken_dict = {"a": 0, "b": 1, "c": 2, "_": 3}
max_subtoken_length = 2
subtoken_counts = tokenizer._count_and_gen_subtokens(
token_counts, alphabet, subtoken_dict, max_subtoken_length)
self.assertIsInstance(subtoken_counts, collections.defaultdict)
self.assertDictEqual(
{"a": 5, "b": 5, "c": 5, "_": 5, "ab": 5, "bc": 5, "c_": 5,
"abc": 5, "bc_": 5, "abc_": 5}, subtoken_counts)
def test_filter_and_bucket_subtokens(self):
subtoken_counts = collections.defaultdict(
int, {"a": 2, "b": 4, "c": 1, "ab": 6, "ac": 3, "abbc": 5})
min_count = 3
subtoken_buckets = tokenizer._filter_and_bucket_subtokens(
subtoken_counts, min_count)
self.assertEqual(len(subtoken_buckets[0]), 0)
self.assertEqual(set("b"), subtoken_buckets[1])
self.assertEqual(set(["ab", "ac"]), subtoken_buckets[2])
self.assertEqual(len(subtoken_buckets[3]), 0)
self.assertEqual(set(["abbc"]), subtoken_buckets[4])
def test_gen_new_subtoken_list(self):
subtoken_counts = collections.defaultdict(
int, {"translate": 10, "t": 40, "tr": 16, "tra": 12})
min_count = 5
alphabet = set("translate")
reserved_tokens = ["reserved", "tokens"]
subtoken_list, max_token_length = tokenizer._gen_new_subtoken_list(
subtoken_counts, min_count, alphabet, reserved_tokens)
# Check that "tra" isn"t in the list (its count should be decremented to 2,
# so it should not be added to the canddiate list).
self.assertNotIn("tra", subtoken_list)
self.assertIn("tr", subtoken_list)
self.assertIn("t", subtoken_list)
self.assertEqual(len("translate"), max_token_length)
def test_generate_subtokens(self):
token_counts = {"ab": 1, "bc": 3, "abc": 5}
alphabet = set("abc_")
min_count = 100
num_iterations = 1
reserved_tokens = ["reserved", "tokens"]
vocab_list = tokenizer._generate_subtokens(
token_counts, alphabet, min_count, num_iterations, reserved_tokens)
# Check that reserved tokens are at the front of the list
self.assertEqual(vocab_list[:2], reserved_tokens)
# Check that each character in alphabet is in the vocab list
for c in alphabet:
self.assertIn(c, vocab_list)
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions specific to running TensorFlow on TPUs."""
import tensorflow as tf
# "local" is a magic word in the TPU cluster resolver; it informs the resolver
# to use the local CPU as the compute device. This is useful for testing and
# debugging; the code flow is ostensibly identical, but without the need to
# actually have a TPU on the other end.
LOCAL = "local"
def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
"""Construct a host call to log scalars when training on TPU.
Args:
metric_dict: A dict of the tensors to be logged.
model_dir: The location to write the summary.
prefix: The prefix (if any) to prepend to the metric names.
Returns:
A tuple of (function, args_to_be_passed_to_said_function)
"""
# type: (dict, str) -> (function, list)
metric_names = list(metric_dict.keys())
def host_call_fn(global_step, *args):
"""Training host call. Creates scalar summaries for training metrics.
This function is executed on the CPU and should not directly reference
any Tensors in the rest of the `model_fn`. To pass Tensors from the
model to the `metric_fn`, provide as part of the `host_call`. See
https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `host_call`.
Args:
global_step: `Tensor with shape `[batch]` for the global_step
*args: Remaining tensors to log.
Returns:
List of summary ops to run on the CPU host.
"""
step = global_step[0]
with tf.contrib.summary.create_file_writer(
logdir=model_dir, filename_suffix=".host_call").as_default():
with tf.contrib.summary.always_record_summaries():
for i, name in enumerate(metric_names):
tf.contrib.summary.scalar(prefix + name, args[i][0], step=step)
return tf.contrib.summary.all_summary_ops()
# To log the current learning rate, and gradient norm for Tensorboard, the
# summary op needs to be run on the host CPU via host_call. host_call
# expects [batch_size, ...] Tensors, thus reshape to introduce a batch
# dimension. These Tensors are implicitly concatenated to
# [params['batch_size']].
global_step_tensor = tf.reshape(tf.train.get_or_create_global_step(), [1])
other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]
return host_call_fn, [global_step_tensor] + other_tensors
def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"):
"""Performs embedding lookup via a matmul.
The matrix to be multiplied by the embedding table Tensor is constructed
via an implementation of scatter based on broadcasting embedding indices
and performing an equality comparison against a broadcasted
range(num_embedding_table_rows). All masked positions will produce an
embedding vector of zeros.
Args:
embedding_table: Tensor of embedding table.
Rank 2 (table_size x embedding dim)
values: Tensor of embedding indices. Rank 2 (batch x n_indices)
mask: Tensor of mask / weights. Rank 2 (batch x n_indices)
name: Optional name scope for created ops
Returns:
Rank 3 tensor of embedding vectors.
"""
with tf.name_scope(name):
n_embeddings = embedding_table.get_shape().as_list()[0]
batch_size, padded_size = values.shape.as_list()
emb_idcs = tf.tile(
tf.reshape(values, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
emb_weights = tf.tile(
tf.reshape(mask, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
col_idcs = tf.tile(
tf.reshape(tf.range(n_embeddings), (1, 1, n_embeddings)),
(batch_size, padded_size, 1))
one_hot = tf.where(
tf.equal(emb_idcs, col_idcs), emb_weights,
tf.zeros((batch_size, padded_size, n_embeddings)))
return tf.tensordot(one_hot, embedding_table, 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment