Commit 3b158095 authored by Ilya Mironov's avatar Ilya Mironov
Browse files

Merge branch 'master' of https://github.com/ilyamironov/models

parents a90db800 be659c2f
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search to find the translated sequence with the highest probability.
Source implementation from Tensor2Tensor:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/beam_search.py
"""
import tensorflow as tf
from tensorflow.python.util import nest
# Default value for INF
INF = 1. * 1e7
class _StateKeys(object):
"""Keys to dictionary storing the state of the beam search loop."""
# Variable storing the loop index.
CUR_INDEX = "CUR_INDEX"
# Top sequences that are alive for each batch item. Alive sequences are ones
# that have not generated an EOS token. Sequences that reach EOS are marked as
# finished and moved to the FINISHED_SEQ tensor.
# Has shape [batch_size, beam_size, CUR_INDEX + 1]
ALIVE_SEQ = "ALIVE_SEQ"
# Log probabilities of each alive sequence. Shape [batch_size, beam_size]
ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS"
# Dictionary of cached values for each alive sequence. The cache stores
# the encoder output, attention bias, and the decoder attention output from
# the previous iteration.
ALIVE_CACHE = "ALIVE_CACHE"
# Top finished sequences for each batch item.
# Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
# shorter than CUR_INDEX + 1 are padded with 0s.
FINISHED_SEQ = "FINISHED_SEQ"
# Scores for each finished sequence. Score = log probability / length norm
# Shape [batch_size, beam_size]
FINISHED_SCORES = "FINISHED_SCORES"
# Flags indicating which sequences in the finished sequences are finished.
# At the beginning, all of the sequences in FINISHED_SEQ are filler values.
# True -> finished sequence, False -> filler. Shape [batch_size, beam_size]
FINISHED_FLAGS = "FINISHED_FLAGS"
class SequenceBeamSearch(object):
"""Implementation of beam search loop."""
def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id):
self.symbols_to_logits_fn = symbols_to_logits_fn
self.vocab_size = vocab_size
self.batch_size = batch_size
self.beam_size = beam_size
self.alpha = alpha
self.max_decode_length = max_decode_length
self.eos_id = eos_id
def search(self, initial_ids, initial_cache):
"""Beam search for sequences with highest scores."""
state, state_shapes = self._create_initial_state(initial_ids, initial_cache)
finished_state = tf.while_loop(
self._continue_search, self._search_step, loop_vars=[state],
shape_invariants=[state_shapes], parallel_iterations=1, back_prop=False)
finished_state = finished_state[0]
alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]
# Account for corner case where there are no finished sequences for a
# particular batch item. In that case, return alive sequences for that batch
# item.
finished_seq = tf.where(
tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
finished_scores = tf.where(
tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
return finished_seq, finished_scores
def _create_initial_state(self, initial_ids, initial_cache):
"""Return initial state dictionary and its shape invariants.
Args:
initial_ids: initial ids to pass into the symbols_to_logits_fn.
int tensor with shape [batch_size, 1]
initial_cache: dictionary storing values to be passed into the
symbols_to_logits_fn.
Returns:
state and shape invariant dictionaries with keys from _StateKeys
"""
# Current loop index (starts at 0)
cur_index = tf.constant(0)
# Create alive sequence with shape [batch_size, beam_size, 1]
alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
alive_seq = tf.expand_dims(alive_seq, axis=2)
# Create tensor for storing initial log probabilities.
# Assume initial_ids are prob 1.0
initial_log_probs = tf.constant(
[[0.] + [-float("inf")] * (self.beam_size - 1)])
alive_log_probs = tf.tile(initial_log_probs, [self.batch_size, 1])
# Expand all values stored in the dictionary to the beam size, so that each
# beam has a separate cache.
alive_cache = nest.map_structure(
lambda t: _expand_to_beam_size(t, self.beam_size), initial_cache)
# Initialize tensor storing finished sequences with filler values.
finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)
# Set scores of the initial finished seqs to negative infinity.
finished_scores = tf.ones([self.batch_size, self.beam_size]) * -INF
# Initialize finished flags with all False values.
finished_flags = tf.zeros([self.batch_size, self.beam_size], tf.bool)
# Create state dictionary
state = {
_StateKeys.CUR_INDEX: cur_index,
_StateKeys.ALIVE_SEQ: alive_seq,
_StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
_StateKeys.ALIVE_CACHE: alive_cache,
_StateKeys.FINISHED_SEQ: finished_seq,
_StateKeys.FINISHED_SCORES: finished_scores,
_StateKeys.FINISHED_FLAGS: finished_flags
}
# Create state invariants for each value in the state dictionary. Each
# dimension must be a constant or None. A None dimension means either:
# 1) the dimension's value is a tensor that remains the same but may
# depend on the input sequence to the model (e.g. batch size).
# 2) the dimension may have different values on different iterations.
state_shape_invariants = {
_StateKeys.CUR_INDEX: tf.TensorShape([]),
_StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]),
_StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]),
_StateKeys.ALIVE_CACHE: nest.map_structure(
_get_shape_keep_last_dim, alive_cache),
_StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]),
_StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]),
_StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size])
}
return state, state_shape_invariants
def _continue_search(self, state):
"""Return whether to continue the search loop.
The loops should terminate when
1) when decode length has been reached, or
2) when the worst score in the finished sequences is better than the best
score in the alive sequences (i.e. the finished sequences are provably
unchanging)
Args:
state: A dictionary with the current loop state.
Returns:
Bool tensor with value True if loop should continue, False if loop should
terminate.
"""
i = state[_StateKeys.CUR_INDEX]
alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
finished_scores = state[_StateKeys.FINISHED_SCORES]
finished_flags = state[_StateKeys.FINISHED_FLAGS]
not_at_max_decode_length = tf.less(i, self.max_decode_length)
# Calculate largest length penalty (the larger penalty, the better score).
max_length_norm = _length_normalization(self.alpha, self.max_decode_length)
# Get the best possible scores from alive sequences.
best_alive_scores = alive_log_probs[:, 0] / max_length_norm
# Compute worst score in finished sequences for each batch element
finished_scores *= tf.to_float(finished_flags) # set filler scores to zero
lowest_finished_scores = tf.reduce_min(finished_scores, axis=1)
# If there are no finished sequences in a batch element, then set the lowest
# finished score to -INF for that element.
finished_batches = tf.reduce_any(finished_flags, 1)
lowest_finished_scores += (1. - tf.to_float(finished_batches)) * -INF
worst_finished_score_better_than_best_alive_score = tf.reduce_all(
tf.greater(lowest_finished_scores, best_alive_scores)
)
return tf.logical_and(
not_at_max_decode_length,
tf.logical_not(worst_finished_score_better_than_best_alive_score)
)
def _search_step(self, state):
"""Beam search loop body.
Grow alive sequences by a single ID. Sequences that have reached the EOS
token are marked as finished. The alive and finished sequences with the
highest log probabilities and scores are returned.
A sequence's finished score is calculating by dividing the log probability
by the length normalization factor. Without length normalization, the
search is more likely to return shorter sequences.
Args:
state: A dictionary with the current loop state.
Returns:
new state dictionary.
"""
# Grow alive sequences by one token.
new_seq, new_log_probs, new_cache = self._grow_alive_seq(state)
# Collect top beam_size alive sequences
alive_state = self._get_new_alive_state(new_seq, new_log_probs, new_cache)
# Combine newly finished sequences with existing finished sequences, and
# collect the top k scoring sequences.
finished_state = self._get_new_finished_state(state, new_seq, new_log_probs)
# Increment loop index and create new state dictionary
new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1}
new_state.update(alive_state)
new_state.update(finished_state)
return [new_state]
def _grow_alive_seq(self, state):
"""Grow alive sequences by one token, and collect top 2*beam_size sequences.
2*beam_size sequences are collected because some sequences may have reached
the EOS token. 2*beam_size ensures that at least beam_size sequences are
still alive.
Args:
state: A dictionary with the current loop state.
Returns:
Tuple of
(Top 2*beam_size sequences [batch_size, 2 * beam_size, cur_index + 1],
Scores of returned sequences [batch_size, 2 * beam_size],
New alive cache, for each of the 2 * beam_size sequences)
"""
i = state[_StateKeys.CUR_INDEX]
alive_seq = state[_StateKeys.ALIVE_SEQ]
alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
alive_cache = state[_StateKeys.ALIVE_CACHE]
beams_to_keep = 2 * self.beam_size
# Get logits for the next candidate IDs for the alive sequences. Get the new
# cache values at the same time.
flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size]
flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)
flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)
# Unflatten logits to shape [batch_size, beam_size, vocab_size]
logits = _unflatten_beam_dim(flat_logits, self.batch_size, self.beam_size)
new_cache = nest.map_structure(
lambda t: _unflatten_beam_dim(t, self.batch_size, self.beam_size),
flat_cache)
# Convert logits to normalized log probs
candidate_log_probs = _log_prob_from_logits(logits)
# Calculate new log probabilities if each of the alive sequences were
# extended # by the the candidate IDs.
# Shape [batch_size, beam_size, vocab_size]
log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)
# Each batch item has beam_size * vocab_size candidate sequences. For each
# batch item, get the k candidates with the highest log probabilities.
flat_log_probs = tf.reshape(log_probs,
[-1, self.beam_size * self.vocab_size])
topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep)
# Extract the alive sequences that generate the highest log probabilities
# after being extended.
topk_beam_indices = topk_indices // self.vocab_size
topk_seq, new_cache = _gather_beams(
[alive_seq, new_cache], topk_beam_indices, self.batch_size,
beams_to_keep)
# Append the most probable IDs to the topk sequences
topk_ids = topk_indices % self.vocab_size
topk_ids = tf.expand_dims(topk_ids, axis=2)
topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
return topk_seq, topk_log_probs, new_cache
def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
"""Gather the top k sequences that are still alive.
Args:
new_seq: New sequences generated by growing the current alive sequences
int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1]
new_log_probs: Log probabilities of new sequences
float32 tensor with shape [batch_size, beam_size]
new_cache: Dict of cached values for each sequence.
Returns:
Dictionary with alive keys from _StateKeys:
{Top beam_size sequences that are still alive (don't end with eos_id)
Log probabilities of top alive sequences
Dict cache storing decoder states for top alive sequences}
"""
# To prevent finished sequences from being considered, set log probs to -INF
new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
new_log_probs += tf.to_float(new_finished_flags) * -INF
top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
[new_seq, new_log_probs, new_cache], new_log_probs, self.batch_size,
self.beam_size)
return {
_StateKeys.ALIVE_SEQ: top_alive_seq,
_StateKeys.ALIVE_LOG_PROBS: top_alive_log_probs,
_StateKeys.ALIVE_CACHE: top_alive_cache
}
def _get_new_finished_state(self, state, new_seq, new_log_probs):
"""Combine new and old finished sequences, and gather the top k sequences.
Args:
state: A dictionary with the current loop state.
new_seq: New sequences generated by growing the current alive sequences
int32 tensor with shape [batch_size, beam_size, i + 1]
new_log_probs: Log probabilities of new sequences
float32 tensor with shape [batch_size, beam_size]
Returns:
Dictionary with finished keys from _StateKeys:
{Top beam_size finished sequences based on score,
Scores of finished sequences,
Finished flags of finished sequences}
"""
i = state[_StateKeys.CUR_INDEX]
finished_seq = state[_StateKeys.FINISHED_SEQ]
finished_scores = state[_StateKeys.FINISHED_SCORES]
finished_flags = state[_StateKeys.FINISHED_FLAGS]
# First append a column of 0-ids to finished_seq to increment the length.
# New shape of finished_seq: [batch_size, beam_size, i + 1]
finished_seq = tf.concat(
[finished_seq,
tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2)
# Calculate new seq scores from log probabilities.
length_norm = _length_normalization(self.alpha, i + 1)
new_scores = new_log_probs / length_norm
# Set the scores of the still-alive seq in new_seq to large negative values.
new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
new_scores += (1. - tf.to_float(new_finished_flags)) * -INF
# Combine sequences, scores, and flags.
finished_seq = tf.concat([finished_seq, new_seq], axis=1)
finished_scores = tf.concat([finished_scores, new_scores], axis=1)
finished_flags = tf.concat([finished_flags, new_finished_flags], axis=1)
# Return the finished sequences with the best scores.
top_finished_seq, top_finished_scores, top_finished_flags = (
_gather_topk_beams([finished_seq, finished_scores, finished_flags],
finished_scores, self.batch_size, self.beam_size))
return {
_StateKeys.FINISHED_SEQ: top_finished_seq,
_StateKeys.FINISHED_SCORES: top_finished_scores,
_StateKeys.FINISHED_FLAGS: top_finished_flags
}
def sequence_beam_search(
symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
alpha, max_decode_length, eos_id):
"""Search for sequence of subtoken ids with the largest probability.
Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape:
ids -> [batch_size * beam_size, index]
index -> [] (scalar)
cache -> nested dictionary of tensors [batch_size * beam_size, ...]
The function must return logits and new cache.
logits -> [batch * beam_size, vocab_size]
new cache -> same shape/structure as inputted cache
initial_ids: Starting ids for each batch item.
int32 tensor with shape [batch_size]
initial_cache: dict containing starting decoder variables information
vocab_size: int size of tokens
beam_size: int number of beams
alpha: float defining the strength of length normalization
max_decode_length: maximum length to decoded sequence
eos_id: int id of eos token, used to determine when a sequence has finished
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
"""
batch_size = tf.shape(initial_ids)[0]
sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id)
return sbs.search(initial_ids, initial_cache)
def _log_prob_from_logits(logits):
return logits - tf.reduce_logsumexp(logits, axis=2, keep_dims=True)
def _length_normalization(alpha, length):
"""Return length normalization factor."""
return tf.pow(((5. + tf.to_float(length)) / 6.), alpha)
def _expand_to_beam_size(tensor, beam_size):
"""Tiles a given tensor by beam_size.
Args:
tensor: tensor to tile [batch_size, ...]
beam_size: How much to tile the tensor by.
Returns:
Tiled tensor [batch_size, beam_size, ...]
"""
tensor = tf.expand_dims(tensor, axis=1)
tile_dims = [1] * tensor.shape.ndims
tile_dims[1] = beam_size
return tf.tile(tensor, tile_dims)
def _shape_list(tensor):
"""Return a list of the tensor's shape, and ensure no None values in list."""
# Get statically known shape (may contain None's for unknown dimensions)
shape = tensor.get_shape().as_list()
# Ensure that the shape values are not None
dynamic_shape = tf.shape(tensor)
for i in range(len(shape)): # pylint: disable=consider-using-enumerate
if shape[i] is None:
shape[i] = dynamic_shape[i]
return shape
def _get_shape_keep_last_dim(tensor):
shape_list = _shape_list(tensor)
# Only the last
for i in range(len(shape_list) - 1):
shape_list[i] = None
if isinstance(shape_list[-1], tf.Tensor):
shape_list[-1] = None
return tf.TensorShape(shape_list)
def _flatten_beam_dim(tensor):
"""Reshapes first two dimensions in to single dimension.
Args:
tensor: Tensor to reshape of shape [A, B, ...]
Returns:
Reshaped tensor of shape [A*B, ...]
"""
shape = _shape_list(tensor)
shape[0] *= shape[1]
shape.pop(1) # Remove beam dim
return tf.reshape(tensor, shape)
def _unflatten_beam_dim(tensor, batch_size, beam_size):
"""Reshapes first dimension back to [batch_size, beam_size].
Args:
tensor: Tensor to reshape of shape [batch_size*beam_size, ...]
batch_size: Tensor, original batch size.
beam_size: int, original beam size.
Returns:
Reshaped tensor of shape [batch_size, beam_size, ...]
"""
shape = _shape_list(tensor)
new_shape = [batch_size, beam_size] + shape[1:]
return tf.reshape(tensor, new_shape)
def _gather_beams(nested, beam_indices, batch_size, new_beam_size):
"""Gather beams from nested structure of tensors.
Each tensor in nested represents a batch of beams, where beam refers to a
single search state (beam search involves searching through multiple states
in parallel).
This function is used to gather the top beams, specified by
beam_indices, from the nested tensors.
Args:
nested: Nested structure (tensor, list, tuple or dict) containing tensors
with shape [batch_size, beam_size, ...].
beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each
value in beam_indices must be between [0, beam_size), and are not
necessarily unique.
batch_size: int size of batch
new_beam_size: int number of beams to be pulled from the nested tensors.
Returns:
Nested structure containing tensors with shape
[batch_size, new_beam_size, ...]
"""
# Computes the i'th coodinate that contains the batch index for gather_nd.
# Batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..].
batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size
batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size])
# Create coordinates to be passed to tf.gather_nd. Stacking creates a tensor
# with shape [batch_size, beam_size, 2], where the last dimension contains
# the (i, j) gathering coordinates.
coordinates = tf.stack([batch_pos, beam_indices], axis=2)
return nest.map_structure(
lambda state: tf.gather_nd(state, coordinates), nested)
def _gather_topk_beams(nested, score_or_log_prob, batch_size, beam_size):
"""Gather top beams from nested structure."""
_, topk_indexes = tf.nn.top_k(score_or_log_prob, k=beam_size)
return _gather_beams(nested, topk_indexes, batch_size, beam_size)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test beam search helper methods."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import beam_search
class BeamSearchHelperTests(tf.test.TestCase):
def test_expand_to_beam_size(self):
x = tf.ones([7, 4, 2, 5])
x = beam_search._expand_to_beam_size(x, 3)
with self.test_session() as sess:
shape = sess.run(tf.shape(x))
self.assertAllEqual([7, 3, 4, 2, 5], shape)
def test_shape_list(self):
y = tf.constant(4.0)
x = tf.ones([7, tf.to_int32(tf.sqrt(y)), 2, 5])
shape = beam_search._shape_list(x)
self.assertIsInstance(shape[0], int)
self.assertIsInstance(shape[1], tf.Tensor)
self.assertIsInstance(shape[2], int)
self.assertIsInstance(shape[3], int)
def test_get_shape_keep_last_dim(self):
y = tf.constant(4.0)
x = tf.ones([7, tf.to_int32(tf.sqrt(y)), 2, 5])
shape = beam_search._get_shape_keep_last_dim(x)
self.assertAllEqual([None, None, None, 5],
shape.as_list())
def test_flatten_beam_dim(self):
x = tf.ones([7, 4, 2, 5])
x = beam_search._flatten_beam_dim(x)
with self.test_session() as sess:
shape = sess.run(tf.shape(x))
self.assertAllEqual([28, 2, 5], shape)
def test_unflatten_beam_dim(self):
x = tf.ones([28, 2, 5])
x = beam_search._unflatten_beam_dim(x, 7, 4)
with self.test_session() as sess:
shape = sess.run(tf.shape(x))
self.assertAllEqual([7, 4, 2, 5], shape)
def test_gather_beams(self):
x = tf.reshape(tf.range(24), [2, 3, 4])
# x looks like: [[[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]]
#
# [[12 13 14 15]
# [16 17 18 19]
# [20 21 22 23]]]
y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
with self.test_session() as sess:
y = sess.run(y)
self.assertAllEqual([[[4, 5, 6, 7],
[8, 9, 10, 11]],
[[12, 13, 14, 15],
[20, 21, 22, 23]]],
y)
def test_gather_topk_beams(self):
x = tf.reshape(tf.range(24), [2, 3, 4])
x_scores = [[0, 1, 1], [1, 0, 1]]
y = beam_search._gather_topk_beams(x, x_scores, 2, 2)
with self.test_session() as sess:
y = sess.run(y)
self.assertAllEqual([[[4, 5, 6, 7],
[8, 9, 10, 11]],
[[12, 13, 14, 15],
[20, 21, 22, 23]]],
y)
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of embedding layer with shared weights."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import model_utils
class EmbeddingSharedWeights(tf.layers.Layer):
"""Calculates input embeddings and pre-softmax linear with shared weights."""
def __init__(self, vocab_size, hidden_size):
super(EmbeddingSharedWeights, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
def build(self, _):
with tf.variable_scope("embedding_and_softmax", reuse=tf.AUTO_REUSE):
# Create and initialize weights. The random normal initializer was chosen
# randomly, and works well.
self.shared_weights = tf.get_variable(
"weights", [self.vocab_size, self.hidden_size],
initializer=tf.random_normal_initializer(
0., self.hidden_size ** -0.5))
self.built = True
def call(self, x):
"""Get token embeddings of x.
Args:
x: An int64 tensor with shape [batch_size, length]
Returns:
embeddings: float32 tensor with shape [batch_size, length, embedding_size]
padding: float32 tensor with shape [batch_size, length] indicating the
locations of the padding tokens in x.
"""
with tf.name_scope("embedding"):
embeddings = tf.gather(self.shared_weights, x)
# Scale embedding by the sqrt of the hidden size
embeddings *= self.hidden_size ** 0.5
# Create binary array of size [batch_size, length]
# where 1 = padding, 0 = not padding
padding = model_utils.get_padding(x)
# Set all padding embedding values to 0
embeddings *= tf.expand_dims(1 - padding, -1)
return embeddings
def linear(self, x):
"""Computes logits by running x through a linear layer.
Args:
x: A float32 tensor with shape [batch_size, length, hidden_size]
Returns:
float32 tensor with shape [batch_size, length, vocab_size].
"""
with tf.name_scope("presoftmax_linear"):
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
x = tf.reshape(x, [-1, self.hidden_size])
logits = tf.matmul(x, self.shared_weights, transpose_b=True)
return tf.reshape(logits, [batch_size, length, self.vocab_size])
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of fully connected network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class FeedFowardNetwork(tf.layers.Layer):
"""Fully connected feedforward network."""
def __init__(self, hidden_size, filter_size, relu_dropout, train):
super(FeedFowardNetwork, self).__init__()
self.hidden_size = hidden_size
self.filter_size = filter_size
self.relu_dropout = relu_dropout
self.train = train
self.filter_dense_layer = tf.layers.Dense(
filter_size, use_bias=True, activation=tf.nn.relu, name="filter_layer")
self.output_dense_layer = tf.layers.Dense(
hidden_size, use_bias=True, name="output_layer")
def call(self, x, padding=None):
"""Return outputs of the feedforward network.
Args:
x: tensor with shape [batch_size, length, hidden_size]
padding: (optional) If set, the padding values are temporarily removed
from x. The padding values are placed back in the output tensor in the
same locations. shape [batch_size, length]
Returns:
Output of the feedforward network.
tensor with shape [batch_size, length, hidden_size]
"""
# Retrieve dynamically known shapes
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
if padding is not None:
with tf.name_scope("remove_padding"):
# Flatten padding to [batch_size*length]
pad_mask = tf.reshape(padding, [-1])
nonpad_ids = tf.to_int32(tf.where(pad_mask < 1e-9))
# Reshape x to [batch_size*length, hidden_size] to remove padding
x = tf.reshape(x, [-1, self.hidden_size])
x = tf.gather_nd(x, indices=nonpad_ids)
# Reshape x from 2 dimensions to 3 dimensions.
x.set_shape([None, self.hidden_size])
x = tf.expand_dims(x, axis=0)
output = self.filter_dense_layer(x)
if self.train:
output = tf.nn.dropout(output, 1.0 - self.relu_dropout)
output = self.output_dense_layer(output)
if padding is not None:
with tf.name_scope("re_add_padding"):
output = tf.squeeze(output, axis=0)
output = tf.scatter_nd(
indices=nonpad_ids,
updates=output,
shape=[batch_size * length, self.hidden_size]
)
output = tf.reshape(output, [batch_size, length, self.hidden_size])
return output
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines Transformer model parameters."""
class TransformerBaseParams(object):
"""Parameters for the base Transformer model."""
# Input params
batch_size = 2048 # Maximum number of tokens per batch of examples.
max_length = 256 # Maximum number of tokens per example.
# Model params
initializer_gain = 1.0 # Used in trainable variable initialization.
vocab_size = 33708 # Number of tokens defined in the vocabulary file.
hidden_size = 512 # Model dimension in the hidden layers.
num_hidden_layers = 6 # Number of layers in the encoder and decoder stacks.
num_heads = 8 # Number of heads to use in multi-headed attention.
filter_size = 2048 # Inner layer dimensionality in the feedforward network.
# Dropout values (only used when training)
layer_postprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# Training params
label_smoothing = 0.1
learning_rate = 2.0
learning_rate_decay_rate = 1.0
learning_rate_warmup_steps = 16000
# Optimizer params
optimizer_adam_beta1 = 0.9
optimizer_adam_beta2 = 0.997
optimizer_adam_epsilon = 1e-09
# Default prediction params
extra_decode_length = 50
beam_size = 4
alpha = 0.6 # used to calculate length normalization in beam search
class TransformerBigParams(TransformerBaseParams):
"""Parameters for the big Transformer model."""
batch_size = 4096
hidden_size = 1024
filter_size = 4096
num_heads = 16
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer model helper methods."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
_NEG_INF = -1e9
def get_position_encoding(
length, hidden_size, min_timescale=1.0, max_timescale=1.0e4):
"""Return positional encoding.
Calculates the position encoding as a mix of sine and cosine functions with
geometrically increasing wavelengths.
Defined and formulized in Attention is All You Need, section 3.5.
Args:
length: Sequence length.
hidden_size: Size of the
min_timescale: Minimum scale that will be applied at each position
max_timescale: Maximum scale that will be applied at each position
Returns:
Tensor with shape [length, hidden_size]
"""
position = tf.to_float(tf.range(length))
num_timescales = hidden_size // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(tf.to_float(num_timescales) - 1))
inv_timescales = min_timescale * tf.exp(
tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
return signal
def get_decoder_self_attention_bias(length):
"""Calculate bias for decoder that maintains model's autoregressive property.
Creates a tensor that masks out locations that correspond to illegal
connections, so prediction at position i cannot draw information from future
positions.
Args:
length: int length of sequences in batch.
Returns:
float tensor of shape [1, 1, length, length]
"""
with tf.name_scope("decoder_self_attention_bias"):
valid_locs = tf.matrix_band_part(tf.ones([length, length]), -1, 0)
valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
decoder_bias = _NEG_INF * (1.0 - valid_locs)
return decoder_bias
def get_padding(x, padding_value=0):
"""Return float tensor representing the padding values in x.
Args:
x: int tensor with any shape
padding_value: int value that
Returns:
flaot tensor with same shape as x containing values 0 or 1.
0 -> non-padding, 1 -> padding
"""
with tf.name_scope("padding"):
return tf.to_float(tf.equal(x, padding_value))
def get_padding_bias(x):
"""Calculate bias tensor from padding values in tensor.
Bias tensor that is added to the pre-softmax multi-headed attention logits,
which has shape [batch_size, num_heads, length, length]. The tensor is zero at
non-padding locations, and -1e9 (negative infinity) at padding locations.
Args:
x: int tensor with shape [batch_size, length]
Returns:
Attention bias tensor of shape [batch_size, 1, 1, length].
"""
with tf.name_scope("attention_bias"):
padding = get_padding(x)
attention_bias = padding * _NEG_INF
attention_bias = tf.expand_dims(
tf.expand_dims(attention_bias, axis=1), axis=1)
return attention_bias
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Transformer model helper methods."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import model_utils
NEG_INF = -1e9
class ModelUtilsTest(tf.test.TestCase):
def test_get_padding(self):
x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
padding = model_utils.get_padding(x, padding_value=0)
with self.test_session() as sess:
padding = sess.run(padding)
self.assertAllEqual([[0, 1, 1, 1, 0], [0, 0, 1, 1, 1], [1, 0, 0, 1, 0]],
padding)
def test_get_padding_bias(self):
x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
bias = model_utils.get_padding_bias(x)
bias_shape = tf.shape(bias)
flattened_bias = tf.reshape(bias, [3, 5])
with self.test_session() as sess:
flattened_bias, bias_shape = sess.run((flattened_bias, bias_shape))
self.assertAllEqual([[0, NEG_INF, NEG_INF, NEG_INF, 0],
[0, 0, NEG_INF, NEG_INF, NEG_INF],
[NEG_INF, 0, 0, NEG_INF, 0]],
flattened_bias)
self.assertAllEqual([3, 1, 1, 5], bias_shape)
def test_get_decoder_self_attention_bias(self):
length = 5
bias = model_utils.get_decoder_self_attention_bias(length)
with self.test_session() as sess:
bias = sess.run(bias)
self.assertAllEqual([[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
[0, 0, NEG_INF, NEG_INF, NEG_INF],
[0, 0, 0, NEG_INF, NEG_INF],
[0, 0, 0, 0, NEG_INF],
[0, 0, 0, 0, 0]]]],
bias)
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the Transformer model, and its encoder and decoder stacks.
Model paper: https://arxiv.org/pdf/1706.03762.pdf
Transformer model code source: https://github.com/tensorflow/tensor2tensor
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.model import attention_layer
from official.transformer.model import beam_search
from official.transformer.model import embedding_layer
from official.transformer.model import ffn_layer
from official.transformer.model import model_utils
from official.transformer.utils.tokenizer import EOS_ID
_NEG_INF = -1e9
class Transformer(object):
"""Transformer model for sequence to sequence data.
Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf
The Transformer model consists of an encoder and decoder. The input is an int
sequence (or a batch of sequences). The encoder produces a continous
representation, and the decoder uses the encoder output to generate
probabilities for the output sequence.
"""
def __init__(self, params, train):
"""Initialize layers to build Transformer model.
Args:
params: hyperparameter object defining layer sizes, dropout values, etc.
train: boolean indicating whether the model is in training mode. Used to
determine if dropout layers should be added.
"""
self.train = train
self.params = params
self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
params.vocab_size, params.hidden_size)
self.encoder_stack = EncoderStack(params, train)
self.decoder_stack = DecoderStack(params, train)
def __call__(self, inputs, targets=None):
"""Calculate target logits or inferred target sequences.
Args:
inputs: int tensor with shape [batch_size, input_length].
targets: None or int tensor with shape [batch_size, target_length].
Returns:
If targets is defined, then return logits for each word in the target
sequence. float tensor with shape [batch_size, target_length, vocab_size]
If target is none, then generate output sequence one token at a time.
returns a dictionary {
output: [batch_size, decoded length]
score: [batch_size, float]}
"""
# Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well.
initializer = tf.variance_scaling_initializer(
self.params.initializer_gain, mode="fan_avg", distribution="uniform")
with tf.variable_scope("Transformer", initializer=initializer):
# Calculate attention bias for encoder self-attention and decoder
# multi-headed attention layers.
attention_bias = model_utils.get_padding_bias(inputs)
# Run the inputs through the encoder layer to map the symbol
# representations to continuous representations.
encoder_outputs = self.encode(inputs, attention_bias)
# Generate output sequence if targets is None, or return logits if target
# sequence is known.
if targets is None:
return self.predict(encoder_outputs, attention_bias)
else:
logits = self.decode(targets, encoder_outputs, attention_bias)
return logits
def encode(self, inputs, attention_bias):
"""Generate continuous representation for inputs.
Args:
inputs: int tensor with shape [batch_size, input_length].
attention_bias: float tensor with shape [batch_size, 1, 1, input_length]
Returns:
float tensor with shape [batch_size, input_length, hidden_size]
"""
with tf.name_scope("encode"):
# Prepare inputs to the layer stack by adding positional encodings and
# applying dropout.
embedded_inputs = self.embedding_softmax_layer(inputs)
inputs_padding = model_utils.get_padding(inputs)
with tf.name_scope("add_pos_encoding"):
length = tf.shape(embedded_inputs)[1]
pos_encoding = model_utils.get_position_encoding(
length, self.params.hidden_size)
encoder_inputs = embedded_inputs + pos_encoding
if self.train:
encoder_inputs = tf.nn.dropout(
encoder_inputs, 1 - self.params.layer_postprocess_dropout)
return self.encoder_stack(encoder_inputs, attention_bias, inputs_padding)
def decode(self, targets, encoder_outputs, attention_bias):
"""Generate logits for each value in the target sequence.
Args:
targets: target values for the output sequence.
int tensor with shape [batch_size, target_length]
encoder_outputs: continuous representation of input sequence.
float tensor with shape [batch_size, input_length, hidden_size]
attention_bias: float tensor with shape [batch_size, 1, 1, input_length]
Returns:
float32 tensor with shape [batch_size, target_length, vocab_size]
"""
with tf.name_scope("decode"):
# Prepare inputs to decoder layers by shifting targets, adding positional
# encoding and applying dropout.
decoder_inputs = self.embedding_softmax_layer(targets)
with tf.name_scope("shift_targets"):
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(
decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
with tf.name_scope("add_pos_encoding"):
length = tf.shape(decoder_inputs)[1]
decoder_inputs += model_utils.get_position_encoding(
length, self.params.hidden_size)
if self.train:
decoder_inputs = tf.nn.dropout(
decoder_inputs, 1 - self.params.layer_postprocess_dropout)
# Run values
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
length)
outputs = self.decoder_stack(
decoder_inputs, encoder_outputs, decoder_self_attention_bias,
attention_bias)
logits = self.embedding_softmax_layer.linear(outputs)
return logits
def _get_symbols_to_logits_fn(self, max_decode_length):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal = model_utils.get_position_encoding(
max_decode_length + 1, self.params.hidden_size)
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
max_decode_length)
def symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next potential IDs.
Args:
ids: Current decoded sequences.
int tensor with shape [batch_size * beam_size, i + 1]
i: Loop index
cache: dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
Returns:
Tuple of
(logits with shape [batch_size * beam_size, vocab_size],
updated cache values)
"""
# Set decoder input to the last generated IDs
decoder_input = ids[:, -1:]
# Preprocess decoder input by getting embeddings and adding timing signal.
decoder_input = self.embedding_softmax_layer(decoder_input)
decoder_input += timing_signal[i:i + 1]
self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
decoder_outputs = self.decoder_stack(
decoder_input, cache.get("encoder_outputs"), self_attention_bias,
cache.get("encoder_decoder_attention_bias"), cache)
logits = self.embedding_softmax_layer.linear(decoder_outputs)
logits = tf.squeeze(logits, axis=[1])
return logits, cache
return symbols_to_logits_fn
def predict(self, encoder_outputs, encoder_decoder_attention_bias):
"""Return predicted sequence."""
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
max_decode_length = input_length + self.params.extra_decode_length
symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
# Create initial set of IDs that will be passed into symbols_to_logits_fn.
initial_ids = tf.zeros([batch_size], dtype=tf.int32)
# Create cache storing decoder attention values for each layer.
cache = {
"layer_%d" % layer: {
"k": tf.zeros([batch_size, 0, self.params.hidden_size]),
"v": tf.zeros([batch_size, 0, self.params.hidden_size]),
} for layer in range(self.params.num_hidden_layers)}
# Add encoder output and attention bias to the cache.
cache["encoder_outputs"] = encoder_outputs
cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
# Use beam search to find the top beam_size sequences and scores.
decoded_ids, scores = beam_search.sequence_beam_search(
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
vocab_size=self.params.vocab_size,
beam_size=self.params.beam_size,
alpha=self.params.alpha,
max_decode_length=max_decode_length,
eos_id=EOS_ID)
# Get the top sequence for each batch element
top_decoded_ids = decoded_ids[:, 0, 1:]
top_scores = scores[:, 0]
return {"outputs": top_decoded_ids, "scores": top_scores}
class LayerNormalization(tf.layers.Layer):
"""Applies layer normalization."""
def __init__(self, hidden_size):
super(LayerNormalization, self).__init__()
self.hidden_size = hidden_size
def build(self, _):
self.scale = tf.get_variable("layer_norm_scale", [self.hidden_size],
initializer=tf.ones_initializer())
self.bias = tf.get_variable("layer_norm_bias", [self.hidden_size],
initializer=tf.zeros_initializer())
self.built = True
def call(self, x, epsilon=1e-6):
mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
return norm_x * self.scale + self.bias
class PrePostProcessingWrapper(object):
"""Wrapper class that applies layer pre-processing and post-processing."""
def __init__(self, layer, params, train):
self.layer = layer
self.postprocess_dropout = params.layer_postprocess_dropout
self.train = train
# Create normalization layer
self.layer_norm = LayerNormalization(params.hidden_size)
def __call__(self, x, *args, **kwargs):
# Preprocessing: apply layer normalization
y = self.layer_norm(x)
# Get layer output
y = self.layer(y, *args, **kwargs)
# Postprocessing: apply dropout and residual connection
if self.train:
y = tf.nn.dropout(y, 1 - self.postprocess_dropout)
return x + y
class EncoderStack(tf.layers.Layer):
"""Transformer encoder stack.
The encoder stack is made up of N identical layers. Each layer is composed
of the sublayers:
1. Self-attention layer
2. Feedforward network (which is 2 fully-connected layers)
"""
def __init__(self, params, train):
super(EncoderStack, self).__init__()
self.layers = []
for _ in range(params.num_hidden_layers):
# Create sublayers for each layer.
self_attention_layer = attention_layer.SelfAttention(
params.hidden_size, params.num_heads, params.attention_dropout, train)
feed_forward_network = ffn_layer.FeedFowardNetwork(
params.hidden_size, params.filter_size, params.relu_dropout, train)
self.layers.append([
PrePostProcessingWrapper(self_attention_layer, params, train),
PrePostProcessingWrapper(feed_forward_network, params, train)])
# Create final layer normalization layer.
self.output_normalization = LayerNormalization(params.hidden_size)
def call(self, encoder_inputs, attention_bias, inputs_padding):
"""Return the output of the encoder layer stacks.
Args:
encoder_inputs: tensor with shape [batch_size, input_length, hidden_size]
attention_bias: bias for the encoder self-attention layer.
[batch_size, 1, 1, input_length]
inputs_padding: P
Returns:
Output of encoder layer stack.
float32 tensor with shape [batch_size, input_length, hidden_size]
"""
for n, layer in enumerate(self.layers):
# Run inputs through the sublayers.
self_attention_layer = layer[0]
feed_forward_network = layer[1]
with tf.variable_scope("layer_%d" % n):
with tf.variable_scope("self_attention"):
encoder_inputs = self_attention_layer(encoder_inputs, attention_bias)
with tf.variable_scope("ffn"):
encoder_inputs = feed_forward_network(encoder_inputs, inputs_padding)
return self.output_normalization(encoder_inputs)
class DecoderStack(tf.layers.Layer):
"""Transformer decoder stack.
Like the encoder stack, the decoder stack is made up of N identical layers.
Each layer is composed of the sublayers:
1. Self-attention layer
2. Multi-headed attention layer combining encoder outputs with results from
the previous self-attention layer.
3. Feedforward network (2 fully-connected layers)
"""
def __init__(self, params, train):
super(DecoderStack, self).__init__()
self.layers = []
for _ in range(params.num_hidden_layers):
self_attention_layer = attention_layer.SelfAttention(
params.hidden_size, params.num_heads, params.attention_dropout, train)
enc_dec_attention_layer = attention_layer.Attention(
params.hidden_size, params.num_heads, params.attention_dropout, train)
feed_forward_network = ffn_layer.FeedFowardNetwork(
params.hidden_size, params.filter_size, params.relu_dropout, train)
self.layers.append([
PrePostProcessingWrapper(self_attention_layer, params, train),
PrePostProcessingWrapper(enc_dec_attention_layer, params, train),
PrePostProcessingWrapper(feed_forward_network, params, train)])
self.output_normalization = LayerNormalization(params.hidden_size)
def call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias,
attention_bias, cache=None):
"""Return the output of the decoder layer stacks.
Args:
decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
decoder_self_attention_bias: bias for decoder self-attention layer.
[1, 1, target_len, target_length]
attention_bias: bias for encoder-decoder attention layer.
[batch_size, 1, 1, input_length]
cache: (Used for fast decoding) A nested dictionary storing previous
decoder self-attention values. The items are:
{layer_n: {"k": tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]},
...}
Returns:
Output of decoder layer stack.
float32 tensor with shape [batch_size, target_length, hidden_size]
"""
for n, layer in enumerate(self.layers):
self_attention_layer = layer[0]
enc_dec_attention_layer = layer[1]
feed_forward_network = layer[2]
# Run inputs through the sublayers.
layer_name = "layer_%d" % n
layer_cache = cache[layer_name] if cache is not None else None
with tf.variable_scope(layer_name):
with tf.variable_scope("self_attention"):
decoder_inputs = self_attention_layer(
decoder_inputs, decoder_self_attention_bias, cache=layer_cache)
with tf.variable_scope("encdec_attention"):
decoder_inputs = enc_dec_attention_layer(
decoder_inputs, encoder_outputs, attention_bias)
with tf.variable_scope("ffn"):
decoder_inputs = feed_forward_network(decoder_inputs)
return self.output_normalization(decoder_inputs)
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Creates an estimator to train the Transformer model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tempfile
# pylint: disable=g-bad-import-order
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer import compute_bleu
from official.transformer import translate
from official.transformer.data_download import VOCAB_FILE
from official.transformer.model import model_params
from official.transformer.model import transformer
from official.transformer.utils import dataset
from official.transformer.utils import metrics
from official.transformer.utils import tokenizer
DEFAULT_TRAIN_EPOCHS = 10
BLEU_DIR = "bleu"
INF = int(1e9)
def model_fn(features, labels, mode, params):
"""Defines how to train, evaluate and predict from the transformer model."""
with tf.variable_scope("model"):
inputs, targets = features, labels
# Create model and get output logits.
model = transformer.Transformer(params, mode == tf.estimator.ModeKeys.TRAIN)
output = model(inputs, targets)
# When in prediction mode, the labels/targets is None. The model output
# is the prediction
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT,
predictions=output)
logits = output
# Calculate model loss.
xentropy, weights = metrics.padded_cross_entropy_loss(
logits, targets, params.label_smoothing, params.vocab_size)
loss = tf.reduce_sum(xentropy * weights) / tf.reduce_sum(weights)
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, predictions={"predictions": logits},
eval_metric_ops=metrics.get_eval_metrics(logits, labels, params))
else:
train_op = get_train_op(loss, params)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def get_learning_rate(learning_rate, hidden_size, learning_rate_warmup_steps):
"""Calculate learning rate with linear warmup and rsqrt decay."""
with tf.name_scope("learning_rate"):
warmup_steps = tf.to_float(learning_rate_warmup_steps)
step = tf.to_float(tf.train.get_or_create_global_step())
learning_rate *= (hidden_size ** -0.5)
# Apply linear warmup
learning_rate *= tf.minimum(1.0, step / warmup_steps)
# Apply rsqrt decay
learning_rate *= tf.rsqrt(tf.maximum(step, warmup_steps))
# Save learning rate value to TensorBoard summary.
tf.summary.scalar("learning_rate", learning_rate)
return learning_rate
def get_train_op(loss, params):
"""Generate training operation that updates variables based on loss."""
with tf.variable_scope("get_train_op"):
learning_rate = get_learning_rate(
params.learning_rate, params.hidden_size,
params.learning_rate_warmup_steps)
# Create optimizer. Use LazyAdamOptimizer from TF contrib, which is faster
# than the TF core Adam optimizer.
optimizer = tf.contrib.opt.LazyAdamOptimizer(
learning_rate,
beta1=params.optimizer_adam_beta1,
beta2=params.optimizer_adam_beta2,
epsilon=params.optimizer_adam_epsilon)
# Calculate and apply gradients using LazyAdamOptimizer.
global_step = tf.train.get_global_step()
tvars = tf.trainable_variables()
gradients = optimizer.compute_gradients(
loss, tvars, colocate_gradients_with_ops=True)
train_op = optimizer.apply_gradients(
gradients, global_step=global_step, name="train")
# Save gradient norm to Tensorboard
tf.summary.scalar("global_norm/gradient_norm",
tf.global_norm(list(zip(*gradients))[0]))
return train_op
def translate_and_compute_bleu(estimator, subtokenizer, bleu_source, bleu_ref):
"""Translate file and report the cased and uncased bleu scores."""
# Create temporary file to store translation.
tmp = tempfile.NamedTemporaryFile(delete=False)
tmp_filename = tmp.name
translate.translate_file(
estimator, subtokenizer, bleu_source, output_file=tmp_filename,
print_all_translations=False)
# Compute uncased and cased bleu scores.
uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
cased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, True)
os.remove(tmp_filename)
return uncased_score, cased_score
def get_global_step(estimator):
"""Return estimator's last checkpoint."""
return int(estimator.latest_checkpoint().split("-")[-1])
def evaluate_and_log_bleu(estimator, bleu_writer, bleu_source, bleu_ref):
"""Calculate and record the BLEU score."""
subtokenizer = tokenizer.Subtokenizer(
os.path.join(FLAGS.data_dir, FLAGS.vocab_file))
uncased_score, cased_score = translate_and_compute_bleu(
estimator, subtokenizer, bleu_source, bleu_ref)
print("Bleu score (uncased):", uncased_score)
print("Bleu score (cased):", cased_score)
summary = tf.Summary(value=[
tf.Summary.Value(tag="bleu/uncased", simple_value=uncased_score),
tf.Summary.Value(tag="bleu/cased", simple_value=cased_score),
])
bleu_writer.add_summary(summary, get_global_step(estimator))
bleu_writer.flush()
return uncased_score, cased_score
def train_schedule(
estimator, train_eval_iterations, single_iteration_train_steps=None,
single_iteration_train_epochs=None, bleu_source=None, bleu_ref=None,
bleu_threshold=None):
"""Train and evaluate model, and optionally compute model's BLEU score.
**Step vs. Epoch vs. Iteration**
Steps and epochs are canonical terms used in TensorFlow and general machine
learning. They are used to describe running a single process (train/eval):
- Step refers to running the process through a single or batch of examples.
- Epoch refers to running the process through an entire dataset.
E.g. training a dataset with 100 examples. The dataset is
divided into 20 batches with 5 examples per batch. A single training step
trains the model on one batch. After 20 training steps, the model will have
trained on every batch in the dataset, or, in other words, one epoch.
Meanwhile, iteration is used in this implementation to describe running
multiple processes (training and eval).
- A single iteration:
1. trains the model for a specific number of steps or epochs.
2. evaluates the model.
3. (if source and ref files are provided) compute BLEU score.
This function runs through multiple train+eval+bleu iterations.
Args:
estimator: tf.Estimator containing model to train.
train_eval_iterations: Number of times to repeat the train+eval iteration.
single_iteration_train_steps: Number of steps to train in one iteration.
single_iteration_train_epochs: Number of epochs to train in one iteration.
bleu_source: File containing text to be translated for BLEU calculation.
bleu_ref: File containing reference translations for BLEU calculation.
bleu_threshold: minimum BLEU score before training is stopped.
Raises:
ValueError: if both or none of single_iteration_train_steps and
single_iteration_train_epochs were defined.
"""
# Ensure that exactly one of single_iteration_train_steps and
# single_iteration_train_epochs is defined.
if single_iteration_train_steps is None:
if single_iteration_train_epochs is None:
raise ValueError(
"Exactly one of single_iteration_train_steps or "
"single_iteration_train_epochs must be defined. Both were none.")
else:
if single_iteration_train_epochs is not None:
raise ValueError(
"Exactly one of single_iteration_train_steps or "
"single_iteration_train_epochs must be defined. Both were defined.")
evaluate_bleu = bleu_source is not None and bleu_ref is not None
# Print out training schedule
print("Training schedule:")
if single_iteration_train_epochs is not None:
print("\t1. Train for %d epochs." % single_iteration_train_epochs)
else:
print("\t1. Train for %d steps." % single_iteration_train_steps)
print("\t2. Evaluate model.")
if evaluate_bleu:
print("\t3. Compute BLEU score.")
if bleu_threshold is not None:
print("Repeat above steps until the BLEU score reaches", bleu_threshold)
if not evaluate_bleu or bleu_threshold is None:
print("Repeat above steps %d times." % train_eval_iterations)
if evaluate_bleu:
# Set summary writer to log bleu score.
bleu_writer = tf.summary.FileWriter(
os.path.join(estimator.model_dir, BLEU_DIR))
if bleu_threshold is not None:
# Change loop stopping condition if bleu_threshold is defined.
train_eval_iterations = INF
# Loop training/evaluation/bleu cycles
for i in xrange(train_eval_iterations):
print("Starting iteration", i + 1)
# Train the model for single_iteration_train_steps or until the input fn
# runs out of examples (if single_iteration_train_steps is None).
estimator.train(dataset.train_input_fn, steps=single_iteration_train_steps)
eval_results = estimator.evaluate(dataset.eval_input_fn)
print("Evaluation results (iter %d/%d):" % (i + 1, train_eval_iterations),
eval_results)
if evaluate_bleu:
uncased_score, _ = evaluate_and_log_bleu(
estimator, bleu_writer, bleu_source, bleu_ref)
if bleu_threshold is not None and uncased_score > bleu_threshold:
bleu_writer.close()
break
def main(_):
# Set logging level to INFO to display training progress (logged by the
# estimator)
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.params == "base":
params = model_params.TransformerBaseParams
elif FLAGS.params == "big":
params = model_params.TransformerBigParams
else:
raise ValueError("Invalid parameter set defined: %s."
"Expected 'base' or 'big.'" % FLAGS.params)
# Determine training schedule based on flags.
if FLAGS.train_steps is not None and FLAGS.train_epochs is not None:
raise ValueError("Both --train_steps and --train_epochs were set. Only one "
"may be defined.")
if FLAGS.train_steps is not None:
train_eval_iterations = FLAGS.train_steps // FLAGS.steps_between_eval
single_iteration_train_steps = FLAGS.steps_between_eval
single_iteration_train_epochs = None
else:
if FLAGS.train_epochs is None:
FLAGS.train_epochs = DEFAULT_TRAIN_EPOCHS
train_eval_iterations = FLAGS.train_epochs // FLAGS.epochs_between_eval
single_iteration_train_steps = None
single_iteration_train_epochs = FLAGS.epochs_between_eval
# Make sure that the BLEU source and ref files if set
if FLAGS.bleu_source is not None and FLAGS.bleu_ref is not None:
if not tf.gfile.Exists(FLAGS.bleu_source):
raise ValueError("BLEU source file %s does not exist" % FLAGS.bleu_source)
if not tf.gfile.Exists(FLAGS.bleu_ref):
raise ValueError("BLEU source file %s does not exist" % FLAGS.bleu_ref)
# Add flag-defined parameters to params object
params.data_dir = FLAGS.data_dir
params.num_cpu_cores = FLAGS.num_cpu_cores
params.epochs_between_eval = FLAGS.epochs_between_eval
params.repeat_dataset = single_iteration_train_epochs
estimator = tf.estimator.Estimator(
model_fn=model_fn, model_dir=FLAGS.model_dir, params=params)
train_schedule(
estimator, train_eval_iterations, single_iteration_train_steps,
single_iteration_train_epochs, FLAGS.bleu_source, FLAGS.bleu_ref,
FLAGS.bleu_threshold)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir", "-dd", type=str, default="/tmp/translate_ende",
help="[default: %(default)s] Directory containing training and "
"evaluation data, and vocab file used for encoding.",
metavar="<DD>")
parser.add_argument(
"--vocab_file", "-vf", type=str, default=VOCAB_FILE,
help="[default: %(default)s] Name of vocabulary file.",
metavar="<vf>")
parser.add_argument(
"--model_dir", "-md", type=str, default="/tmp/transformer_model",
help="[default: %(default)s] Directory to save Transformer model "
"training checkpoints",
metavar="<MD>")
parser.add_argument(
"--params", "-p", type=str, default="big", choices=["base", "big"],
help="[default: %(default)s] Parameter set to use when creating and "
"training the model.",
metavar="<P>")
parser.add_argument(
"--num_cpu_cores", "-nc", type=int, default=4,
help="[default: %(default)s] Number of CPU cores to use in the input "
"pipeline.",
metavar="<NC>")
# Flags for training with epochs. (default)
parser.add_argument(
"--train_epochs", "-te", type=int, default=None,
help="The number of epochs used to train. If both --train_epochs and "
"--train_steps are not set, the model will train for %d epochs." %
DEFAULT_TRAIN_EPOCHS,
metavar="<TE>")
parser.add_argument(
"--epochs_between_eval", "-ebe", type=int, default=1,
help="[default: %(default)s] The number of training epochs to run "
"between evaluations.",
metavar="<TE>")
# Flags for training with steps (may be used for debugging)
parser.add_argument(
"--train_steps", "-ts", type=int, default=None,
help="Total number of training steps. If both --train_epochs and "
"--train_steps are not set, the model will train for %d epochs." %
DEFAULT_TRAIN_EPOCHS,
metavar="<TS>")
parser.add_argument(
"--steps_between_eval", "-sbe", type=int, default=1000,
help="[default: %(default)s] Number of training steps to run between "
"evaluations.",
metavar="<SBE>")
# BLEU score computation
parser.add_argument(
"--bleu_source", "-bs", type=str, default=None,
help="Path to source file containing text translate when calculating the "
"official BLEU score. Both --bleu_source and --bleu_ref must be "
"set. The BLEU score will be calculated during model evaluation.",
metavar="<BS>")
parser.add_argument(
"--bleu_ref", "-br", type=str, default=None,
help="Path to file containing the reference translation for calculating "
"the official BLEU score. Both --bleu_source and --bleu_ref must be "
"set. The BLEU score will be calculated during model evaluation.",
metavar="<BR>")
parser.add_argument(
"--bleu_threshold", "-bt", type=float, default=None,
help="Stop training when the uncased BLEU score reaches this value. "
"Setting this overrides the total number of steps or epochs set by "
"--train_steps or --train_epochs.",
metavar="<BT>")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Translate text or files using trained transformer model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
# pylint: disable=g-bad-import-order
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.data_download import VOCAB_FILE
from official.transformer.model import model_params
from official.transformer.utils import tokenizer
_DECODE_BATCH_SIZE = 32
_EXTRA_DECODE_LENGTH = 100
_BEAM_SIZE = 4
_ALPHA = 0.6
def _get_sorted_inputs(filename):
"""Read and sort lines from the file sorted by decreasing length.
Args:
filename: String name of file to read inputs from.
Returns:
Sorted list of inputs, and dictionary mapping original index->sorted index
of each element.
"""
with tf.gfile.Open(filename) as f:
records = f.read().split("\n")
inputs = [record.strip() for record in records]
if not inputs[-1]:
inputs.pop()
input_lens = [(i, len(line.split())) for i, line in enumerate(inputs)]
sorted_input_lens = sorted(input_lens, key=lambda x: x[1], reverse=True)
sorted_inputs = []
sorted_keys = {}
for i, (index, _) in enumerate(sorted_input_lens):
sorted_inputs.append(inputs[index])
sorted_keys[index] = i
return sorted_inputs, sorted_keys
def _encode_and_add_eos(line, subtokenizer):
"""Encode line with subtokenizer, and add EOS id to the end."""
return subtokenizer.encode(line) + [tokenizer.EOS_ID]
def _trim_and_decode(ids, subtokenizer):
"""Trim EOS and PAD tokens from ids, and decode to return a string."""
try:
index = list(ids).index(tokenizer.EOS_ID)
return subtokenizer.decode(ids[:index])
except ValueError: # No EOS found in sequence
return subtokenizer.decode(ids)
def translate_file(
estimator, subtokenizer, input_file, output_file=None,
print_all_translations=True):
"""Translate lines in file, and save to output file if specified.
Args:
estimator: tf.Estimator used to generate the translations.
subtokenizer: Subtokenizer object for encoding and decoding source and
translated lines.
input_file: file containing lines to translate
output_file: file that stores the generated translations.
print_all_translations: If true, all translations are printed to stdout.
Raises:
ValueError: if output file is invalid.
"""
batch_size = _DECODE_BATCH_SIZE
# Read and sort inputs by length. Keep dictionary (original index-->new index
# in sorted list) to write translations in the original order.
sorted_inputs, sorted_keys = _get_sorted_inputs(input_file)
num_decode_batches = (len(sorted_inputs) - 1) // batch_size + 1
def input_generator():
"""Yield encoded strings from sorted_inputs."""
for i, line in enumerate(sorted_inputs):
if i % batch_size == 0:
batch_num = (i // batch_size) + 1
print("Decoding batch %d out of %d." % (batch_num, num_decode_batches))
yield _encode_and_add_eos(line, subtokenizer)
def input_fn():
"""Created batched dataset of encoded inputs."""
ds = tf.data.Dataset.from_generator(
input_generator, tf.int64, tf.TensorShape([None]))
ds = ds.padded_batch(batch_size, [None])
return ds
translations = []
for i, prediction in enumerate(estimator.predict(input_fn)):
translation = _trim_and_decode(prediction["outputs"], subtokenizer)
translations.append(translation)
if print_all_translations:
print("Translating:")
print("\tInput: %s" % sorted_inputs[i])
print("\tOutput: %s\n" % translation)
print("=" * 100)
# Write translations in the order they appeared in the original file.
if output_file is not None:
if tf.gfile.IsDirectory(output_file):
raise ValueError("File output is a directory, will not save outputs to "
"file.")
tf.logging.info("Writing to file %s" % output_file)
with tf.gfile.Open(output_file, "w") as f:
for index in xrange(len(sorted_keys)):
f.write("%s\n" % translations[sorted_keys[index]])
def translate_text(estimator, subtokenizer, txt):
"""Translate a single string."""
encoded_txt = _encode_and_add_eos(txt, subtokenizer)
def input_fn():
ds = tf.data.Dataset.from_tensors(encoded_txt)
ds = ds.batch(_DECODE_BATCH_SIZE)
return ds
predictions = estimator.predict(input_fn)
translation = next(predictions)["outputs"]
translation = _trim_and_decode(translation, subtokenizer)
print("Translation of \"%s\": \"%s\"" % (txt, translation))
def main(unused_argv):
from official.transformer import transformer_main
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.text is None and FLAGS.file is None:
tf.logging.warn("Nothing to translate. Make sure to call this script using "
"flags --text or --file.")
return
subtokenizer = tokenizer.Subtokenizer(
os.path.join(FLAGS.data_dir, FLAGS.vocab_file))
if FLAGS.params == "base":
params = model_params.TransformerBaseParams
elif FLAGS.params == "big":
params = model_params.TransformerBigParams
else:
raise ValueError("Invalid parameter set defined: %s."
"Expected 'base' or 'big.'" % FLAGS.params)
# Set up estimator and params
params.beam_size = _BEAM_SIZE
params.alpha = _ALPHA
params.extra_decode_length = _EXTRA_DECODE_LENGTH
params.batch_size = _DECODE_BATCH_SIZE
estimator = tf.estimator.Estimator(
model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
params=params)
if FLAGS.text is not None:
tf.logging.info("Translating text: %s" % FLAGS.text)
translate_text(estimator, subtokenizer, FLAGS.text)
if FLAGS.file is not None:
input_file = os.path.abspath(FLAGS.file)
tf.logging.info("Translating file: %s" % input_file)
if not tf.gfile.Exists(FLAGS.file):
raise ValueError("File does not exist: %s" % input_file)
output_file = None
if FLAGS.file_out is not None:
output_file = os.path.abspath(FLAGS.file_out)
tf.logging.info("File output specified: %s" % output_file)
translate_file(estimator, subtokenizer, input_file, output_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Model arguments
parser.add_argument(
"--data_dir", "-dd", type=str, default="/tmp/data/translate_ende",
help="[default: %(default)s] Directory where vocab file is stored.",
metavar="<DD>")
parser.add_argument(
"--vocab_file", "-vf", type=str, default=VOCAB_FILE,
help="[default: %(default)s] Name of vocabulary file.",
metavar="<vf>")
parser.add_argument(
"--model_dir", "-md", type=str, default="/tmp/transformer_model",
help="[default: %(default)s] Directory containing Transformer model "
"checkpoints.",
metavar="<MD>")
parser.add_argument(
"--params", "-p", type=str, default="big", choices=["base", "big"],
help="[default: %(default)s] Parameter used for trained model.",
metavar="<P>")
# Flags for specifying text/file to be translated.
parser.add_argument(
"--text", "-t", type=str, default=None,
help="[default: %(default)s] Text to translate. Output will be printed "
"to console.",
metavar="<T>")
parser.add_argument(
"--file", "-f", type=str, default=None,
help="[default: %(default)s] File containing text to translate. "
"Translation will be printed to console and, if --file_out is "
"provided, saved to an output file.",
metavar="<F>")
parser.add_argument(
"--file_out", "-fo", type=str, default=None,
help="[default: %(default)s] If --file flag is specified, save "
"translation to this file.",
metavar="<FO>")
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input pipeline for the transformer model to read, filter, and batch examples.
Two things to note in the pipeline:
1. Batching scheme
The examples encoded in the TFRecord files contain data in the format:
{"inputs": [variable length array of integers],
"targets": [variable length array of integers]}
Where integers in the arrays refer to tokens in the English and German vocab
file (named `vocab.ende.32768`).
Prior to batching, elements in the dataset are grouped by length (max between
"inputs" and "targets" length). Each group is then batched such that:
group_batch_size * length <= batch_size.
Another way to view batch_size is the maximum number of tokens in each batch.
Once batched, each element in the dataset will have the shape:
{"inputs": [group_batch_size, padded_input_length],
"targets": [group_batch_size, padded_target_length]}
Lengths are padded to the longest "inputs" or "targets" sequence in the batch
(padded_input_length and padded_target_length can be different).
This batching scheme decreases the fraction of padding tokens per training
batch, thus improving the training speed significantly.
2. Shuffling
While training, the dataset is shuffled in two places in the code. The first
is the list of training files. Second, while reading records using
`parallel_interleave`, the `sloppy` argument is used to generate randomness
in the order of the examples.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
# Use the number of training files as the shuffle buffer.
_FILE_SHUFFLE_BUFFER = 100
# Buffer size for reading records from a TFRecord file. Each training file is
# 7.2 MB, so 8 MB allows an entire file to be kept in memory.
_READ_RECORD_BUFFER = 8 * 1000 * 1000
# Example grouping constants. Defines length boundaries for each group.
# These values are the defaults used in Tensor2Tensor.
_MIN_BOUNDARY = 8
_BOUNDARY_SCALE = 1.1
def _load_records(filename):
"""Read file and return a dataset of tf.Examples."""
return tf.data.TFRecordDataset(filename, buffer_size=_READ_RECORD_BUFFER)
def _parse_example(serialized_example):
"""Return inputs and targets Tensors from a serialized tf.Example."""
data_fields = {
"inputs": tf.VarLenFeature(tf.int64),
"targets": tf.VarLenFeature(tf.int64)
}
parsed = tf.parse_single_example(serialized_example, data_fields)
inputs = tf.sparse_tensor_to_dense(parsed["inputs"])
targets = tf.sparse_tensor_to_dense(parsed["targets"])
return inputs, targets
def _filter_max_length(example, max_length=256):
"""Indicates whether the example's length is lower than the maximum length."""
return tf.logical_and(tf.size(example[0]) <= max_length,
tf.size(example[1]) <= max_length)
def _get_example_length(example):
"""Returns the maximum length between the example inputs and targets."""
length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
return length
def _create_min_max_boundaries(
max_length, min_boundary=_MIN_BOUNDARY, boundary_scale=_BOUNDARY_SCALE):
"""Create min and max boundary lists up to max_length.
For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
returned values will be:
buckets_min = [0, 4, 8, 16, 24]
buckets_max = [4, 8, 16, 24, 25]
Args:
max_length: The maximum length of example in dataset.
min_boundary: Minimum length in boundary.
boundary_scale: Amount to scale consecutive boundaries in the list.
Returns:
min and max boundary lists
"""
# Create bucket boundaries list by scaling the previous boundary or adding 1
# (to ensure increasing boundary sizes).
bucket_boundaries = []
x = min_boundary
while x < max_length:
bucket_boundaries.append(x)
x = max(x + 1, int(x * boundary_scale))
# Create min and max boundary lists from the initial list.
buckets_min = [0] + bucket_boundaries
buckets_max = bucket_boundaries + [max_length + 1]
return buckets_min, buckets_max
def _batch_examples(dataset, batch_size, max_length):
"""Group examples by similar lengths, and return batched dataset.
Each batch of similar-length examples are padded to the same length, and may
have different number of elements in each batch, such that:
group_batch_size * padded_length <= batch_size.
This decreases the number of padding tokens per batch, which improves the
training speed.
Args:
dataset: Dataset of unbatched examples.
batch_size: Max number of tokens per batch of examples.
max_length: Max number of tokens in an example input or target sequence.
Returns:
Dataset of batched examples with similar lengths.
"""
# Get min and max boundary lists for each example. These are used to calculate
# the `bucket_id`, which is the index at which:
# buckets_min[bucket_id] <= len(example) < buckets_max[bucket_id]
# Note that using both min and max lists improves the performance.
buckets_min, buckets_max = _create_min_max_boundaries(max_length)
# Create list of batch sizes for each bucket_id, so that
# bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
bucket_batch_sizes = [batch_size // x for x in buckets_max]
# bucket_id will be a tensor, so convert this list to a tensor as well.
bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
def example_to_bucket_id(example_input, example_target):
"""Return int64 bucket id for this example, calculated based on length."""
seq_length = _get_example_length((example_input, example_target))
# TODO: investigate whether removing code branching improves performance.
conditions_c = tf.logical_and(
tf.less_equal(buckets_min, seq_length),
tf.less(seq_length, buckets_max))
bucket_id = tf.reduce_min(tf.where(conditions_c))
return bucket_id
def window_size_fn(bucket_id):
"""Return number of examples to be grouped when given a bucket id."""
return bucket_batch_sizes[bucket_id]
def batching_fn(bucket_id, grouped_dataset):
"""Batch and add padding to a dataset of elements with similar lengths."""
bucket_batch_size = window_size_fn(bucket_id)
# Batch the dataset and add padding so that all input sequences in the
# examples have the same length, and all target sequences have the same
# lengths as well. Resulting lengths of inputs and targets can differ.
return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))
return dataset.apply(tf.contrib.data.group_by_window(
key_func=example_to_bucket_id,
reduce_func=batching_fn,
window_size=None,
window_size_func=window_size_fn))
def _read_and_batch_from_files(
file_pattern, batch_size, max_length, num_cpu_cores, shuffle, repeat):
"""Create dataset where each item is a dict of "inputs" and "targets".
Args:
file_pattern: String used to match the input TFRecord files.
batch_size: Maximum number of tokens per batch of examples
max_length: Maximum number of tokens per example
num_cpu_cores: Number of cpu cores for parallel input processing.
shuffle: If true, randomizes order of elements.
repeat: Number of times to repeat the dataset. If None, the dataset is
repeated forever.
Returns:
tf.data.Dataset object containing examples loaded from the files.
"""
dataset = tf.data.Dataset.list_files(file_pattern)
if shuffle:
# Shuffle filenames
dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
# Read files and interleave results. When training, the order of the examples
# will be non-deterministic.
dataset = dataset.apply(
tf.contrib.data.parallel_interleave(
_load_records, sloppy=shuffle, cycle_length=num_cpu_cores))
# Parse each tf.Example into a dictionary
# TODO: Look into prefetch_input_elements for performance optimization.
dataset = dataset.map(_parse_example,
num_parallel_calls=num_cpu_cores)
# Remove examples where the input or target length exceeds the maximum length,
dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
# Batch such that each batch has examples of similar length.
dataset = _batch_examples(dataset, batch_size, max_length)
dataset = dataset.repeat(repeat)
# Prefetch the next element to improve speed of input pipeline.
dataset = dataset.prefetch(1)
return dataset
def train_input_fn(params):
"""Load and return dataset of batched examples for use during training."""
file_pattern = os.path.join(getattr(params, "data_dir", ""), "*train*")
return _read_and_batch_from_files(
file_pattern, params.batch_size, params.max_length, params.num_cpu_cores,
shuffle=True, repeat=params.repeat_dataset)
def eval_input_fn(params):
"""Load and return dataset of batched examples for use during evaluation."""
file_pattern = os.path.join(getattr(params, "data_dir", ""), "*dev*")
return _read_and_batch_from_files(
file_pattern, params.batch_size, params.max_length, params.num_cpu_cores,
shuffle=False, repeat=1)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for calculating loss, accuracy, and other model metrics.
Metrics:
- Padded loss, accuracy, and negative log perplexity. Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/metrics.py
- BLEU approximation. Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
- ROUGE score. Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/rouge.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import math
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
def _pad_tensors_to_same_length(x, y):
"""Pad x and y so that the results have the same length (second dimension)."""
with tf.name_scope("pad_to_same_length"):
x_length = tf.shape(x)[1]
y_length = tf.shape(y)[1]
max_length = tf.maximum(x_length, y_length)
x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
return x, y
def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
"""Calculate cross entropy loss while ignoring padding.
Args:
logits: Tensor of size [batch_size, length_logits, vocab_size]
labels: Tensor of size [batch_size, length_labels]
smoothing: Label smoothing constant, used to determine the on and off values
vocab_size: int size of the vocabulary
Returns:
Returns a float32 tensor with shape
[batch_size, max(length_logits, length_labels)]
"""
with tf.name_scope("loss", [logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
# Calculate smoothing cross entropy
with tf.name_scope("smoothing_cross_entropy", [logits, labels]):
confidence = 1.0 - smoothing
low_confidence = (1.0 - confidence) / tf.to_float(vocab_size - 1)
soft_targets = tf.one_hot(
tf.cast(labels, tf.int32),
depth=vocab_size,
on_value=confidence,
off_value=low_confidence)
xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
logits=logits, labels=soft_targets)
# Calculate the best (lowest) possible value of cross entropy, and
# subtract from the cross entropy loss.
normalizing_constant = -(
confidence * tf.log(confidence) + tf.to_float(vocab_size - 1) *
low_confidence * tf.log(low_confidence + 1e-20))
xentropy -= normalizing_constant
weights = tf.to_float(tf.not_equal(labels, 0))
return xentropy * weights, weights
def _convert_to_eval_metric(metric_fn):
"""Wrap a metric fn that returns scores and weights as an eval metric fn.
The input metric_fn returns values for the current batch. The wrapper
aggregates the return values collected over all of the batches evaluated.
Args:
metric_fn: function that returns scores and weights for the current batch's
logits and predicted labels.
Returns:
function that aggregates the scores and weights from metric_fn.
"""
def problem_metric_fn(*args):
"""Returns an aggregation of the metric_fn's returned values."""
(scores, weights) = metric_fn(*args)
# The tf.metrics.mean function assures correct aggregation.
return tf.metrics.mean(scores, weights)
return problem_metric_fn
def get_eval_metrics(logits, labels, params):
"""Return dictionary of model evaluation metrics."""
metrics = {
"accuracy": _convert_to_eval_metric(padded_accuracy)(logits, labels),
"accuracy_top5": _convert_to_eval_metric(padded_accuracy_top5)(
logits, labels),
"accuracy_per_sequence": _convert_to_eval_metric(
padded_sequence_accuracy)(logits, labels),
"neg_log_perplexity": _convert_to_eval_metric(padded_neg_log_perplexity)(
logits, labels, params.vocab_size),
"approx_bleu_score": _convert_to_eval_metric(bleu_score)(logits, labels),
"rouge_2_fscore": _convert_to_eval_metric(rouge_2_fscore)(logits, labels),
"rouge_L_fscore": _convert_to_eval_metric(rouge_l_fscore)(logits, labels),
}
# Prefix each of the metric names with "metrics/". This allows the metric
# graphs to display under the "metrics" category in TensorBoard.
metrics = {"metrics/%s" % k: v for k, v in six.iteritems(metrics)}
return metrics
def padded_accuracy(logits, labels):
"""Percentage of times that predictions matches labels on non-0s."""
with tf.variable_scope("padded_accuracy", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
weights = tf.to_float(tf.not_equal(labels, 0))
outputs = tf.to_int32(tf.argmax(logits, axis=-1))
padded_labels = tf.to_int32(labels)
return tf.to_float(tf.equal(outputs, padded_labels)), weights
def padded_accuracy_topk(logits, labels, k):
"""Percentage of times that top-k predictions matches labels on non-0s."""
with tf.variable_scope("padded_accuracy_topk", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
weights = tf.to_float(tf.not_equal(labels, 0))
effective_k = tf.minimum(k, tf.shape(logits)[-1])
_, outputs = tf.nn.top_k(logits, k=effective_k)
outputs = tf.to_int32(outputs)
padded_labels = tf.to_int32(labels)
padded_labels = tf.expand_dims(padded_labels, axis=-1)
padded_labels += tf.zeros_like(outputs) # Pad to same shape.
same = tf.to_float(tf.equal(outputs, padded_labels))
same_topk = tf.reduce_sum(same, axis=-1)
return same_topk, weights
def padded_accuracy_top5(logits, labels):
return padded_accuracy_topk(logits, labels, 5)
def padded_sequence_accuracy(logits, labels):
"""Percentage of times that predictions matches labels everywhere (non-0)."""
with tf.variable_scope("padded_sequence_accuracy", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
weights = tf.to_float(tf.not_equal(labels, 0))
outputs = tf.to_int32(tf.argmax(logits, axis=-1))
padded_labels = tf.to_int32(labels)
not_correct = tf.to_float(tf.not_equal(outputs, padded_labels)) * weights
axis = list(range(1, len(outputs.get_shape())))
correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
return correct_seq, tf.constant(1.0)
def padded_neg_log_perplexity(logits, labels, vocab_size):
"""Average log-perplexity excluding padding 0s. No smoothing."""
num, den = padded_cross_entropy_loss(logits, labels, 0, vocab_size)
return -num, den
def bleu_score(logits, labels):
"""Approximate BLEU score computation between labels and predictions.
An approximate BLEU scoring method since we do not glue word pieces or
decode the ids and tokenize the output. By default, we use ngram order of 4
and use brevity penalty. Also, this does not have beam search.
Args:
logits: Tensor of size [batch_size, length_logits, vocab_size]
labels: Tensor of size [batch-size, length_labels]
Returns:
bleu: int, approx bleu score
"""
predictions = tf.to_int32(tf.argmax(logits, axis=-1))
# TODO: Look into removing use of py_func
bleu = tf.py_func(compute_bleu, (labels, predictions), tf.float32)
return bleu, tf.constant(1.0)
def _get_ngrams_with_counter(segment, max_order):
"""Extracts all n-grams up to a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
Returns:
The Counter containing all n-grams upto max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in xrange(1, max_order + 1):
for i in xrange(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus, translation_corpus, max_order=4,
use_bp=True):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
use_bp: boolean, whether to apply brevity penalty.
Returns:
BLEU score.
"""
reference_length = 0
translation_length = 0
bp = 1.0
geo_mean = 0
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
precisions = []
for (references, translations) in zip(reference_corpus, translation_corpus):
reference_length += len(references)
translation_length += len(translations)
ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)
overlap = dict((ngram,
min(count, translation_ngram_counts[ngram]))
for ngram, count in ref_ngram_counts.items())
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for ngram in translation_ngram_counts:
possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
ngram]
precisions = [0] * max_order
smooth = 1.0
for i in xrange(0, max_order):
if possible_matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
if matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[
i]
else:
smooth *= 2
precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
else:
precisions[i] = 0.0
if max(precisions) > 0:
p_log_sum = sum(math.log(p) for p in precisions if p)
geo_mean = math.exp(p_log_sum / max_order)
if use_bp:
ratio = translation_length / reference_length
bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
bleu = geo_mean * bp
return np.float32(bleu)
def rouge_2_fscore(logits, labels):
"""ROUGE-2 F1 score computation between labels and predictions.
This is an approximate ROUGE scoring method since we do not glue word pieces
or decode the ids and tokenize the output.
Args:
logits: tensor, model predictions
labels: tensor, gold output.
Returns:
rouge2_fscore: approx rouge-2 f1 score.
"""
predictions = tf.to_int32(tf.argmax(logits, axis=-1))
# TODO: Look into removing use of py_func
rouge_2_f_score = tf.py_func(rouge_n, (predictions, labels), tf.float32)
return rouge_2_f_score, tf.constant(1.0)
def _get_ngrams(n, text):
"""Calculates n-grams.
Args:
n: which n-grams to calculate
text: An array of tokens
Returns:
A set of n-grams
"""
ngram_set = set()
text_length = len(text)
max_index_ngram_start = text_length - n
for i in range(max_index_ngram_start + 1):
ngram_set.add(tuple(text[i:i + n]))
return ngram_set
def rouge_n(eval_sentences, ref_sentences, n=2):
"""Computes ROUGE-N f1 score of two text collections of sentences.
Source: https://www.microsoft.com/en-us/research/publication/
rouge-a-package-for-automatic-evaluation-of-summaries/
Args:
eval_sentences: Predicted sentences.
ref_sentences: Sentences from the reference set
n: Size of ngram. Defaults to 2.
Returns:
f1 score for ROUGE-N
"""
f1_scores = []
for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
eval_ngrams = _get_ngrams(n, eval_sentence)
ref_ngrams = _get_ngrams(n, ref_sentence)
ref_count = len(ref_ngrams)
eval_count = len(eval_ngrams)
# Count the overlapping ngrams between evaluated and reference
overlapping_ngrams = eval_ngrams.intersection(ref_ngrams)
overlapping_count = len(overlapping_ngrams)
# Handle edge case. This isn't mathematically correct, but it's good enough
if eval_count == 0:
precision = 0.0
else:
precision = float(overlapping_count) / eval_count
if ref_count == 0:
recall = 0.0
else:
recall = float(overlapping_count) / ref_count
f1_scores.append(2.0 * ((precision * recall) / (precision + recall + 1e-8)))
# return overlapping_count / reference_count
return np.mean(f1_scores, dtype=np.float32)
def rouge_l_fscore(predictions, labels):
"""ROUGE scores computation between labels and predictions.
This is an approximate ROUGE scoring method since we do not glue word pieces
or decode the ids and tokenize the output.
Args:
predictions: tensor, model predictions
labels: tensor, gold output.
Returns:
rouge_l_fscore: approx rouge-l f1 score.
"""
outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
rouge_l_f_score = tf.py_func(rouge_l_sentence_level, (outputs, labels),
tf.float32)
return rouge_l_f_score, tf.constant(1.0)
def rouge_l_sentence_level(eval_sentences, ref_sentences):
"""Computes ROUGE-L (sentence level) of two collections of sentences.
Source: https://www.microsoft.com/en-us/research/publication/
rouge-a-package-for-automatic-evaluation-of-summaries/
Calculated according to:
R_lcs = LCS(X,Y)/m
P_lcs = LCS(X,Y)/n
F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
where:
X = reference summary
Y = Candidate summary
m = length of reference summary
n = length of candidate summary
Args:
eval_sentences: The sentences that have been picked by the summarizer
ref_sentences: The sentences from the reference set
Returns:
A float: F_lcs
"""
f1_scores = []
for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
m = float(len(ref_sentence))
n = float(len(eval_sentence))
lcs = _len_lcs(eval_sentence, ref_sentence)
f1_scores.append(_f_lcs(lcs, m, n))
return np.mean(f1_scores, dtype=np.float32)
def _len_lcs(x, y):
"""Returns the length of the Longest Common Subsequence between two seqs.
Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
Args:
x: sequence of words
y: sequence of words
Returns
integer: Length of LCS between x and y
"""
table = _lcs(x, y)
n, m = len(x), len(y)
return table[n, m]
def _lcs(x, y):
"""Computes the length of the LCS between two seqs.
The implementation below uses a DP programming algorithm and runs
in O(nm) time where n = len(x) and m = len(y).
Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
Args:
x: collection of words
y: collection of words
Returns:
Table of dictionary of coord and len lcs
"""
n, m = len(x), len(y)
table = dict()
for i in range(n + 1):
for j in range(m + 1):
if i == 0 or j == 0:
table[i, j] = 0
elif x[i - 1] == y[j - 1]:
table[i, j] = table[i - 1, j - 1] + 1
else:
table[i, j] = max(table[i - 1, j], table[i, j - 1])
return table
def _f_lcs(llcs, m, n):
"""Computes the LCS-based F-measure score.
Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/
rouge-working-note-v1.3.1.pdf
Args:
llcs: Length of LCS
m: number of words in reference summary
n: number of words in candidate summary
Returns:
Float. LCS-based F-measure score
"""
r_lcs = llcs / m
p_lcs = llcs / n
beta = p_lcs / (r_lcs + 1e-12)
num = (1 + (beta ** 2)) * r_lcs * p_lcs
denom = r_lcs + ((beta ** 2) * p_lcs)
f_lcs = num / (denom + 1e-12)
return f_lcs
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines Subtokenizer class to encode and decode strings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import sys
import unicodedata
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
PAD = "<pad>"
PAD_ID = 0
EOS = "<EOS>"
EOS_ID = 1
RESERVED_TOKENS = [PAD, EOS]
# Set of characters that will be used in the function _escape_token() (see func
# docstring for more details).
# This set is added to the alphabet list to ensure that all escaped tokens can
# be encoded.
_ESCAPE_CHARS = set(u"\\_u;0123456789")
# Regex for the function _unescape_token(), the inverse of _escape_token().
# This is used to find "\u", "\\", and "\###;" substrings in the token.
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
_UNDEFINED_UNICODE = u"\u3013"
# Set contains all letter and number characters.
_ALPHANUMERIC_CHAR_SET = set(
six.unichr(i) for i in xrange(sys.maxunicode)
if (unicodedata.category(six.unichr(i)).startswith("L") or
unicodedata.category(six.unichr(i)).startswith("N")))
# min_count is the minimum number of times a subtoken must appear in the data
# before before it is added to the vocabulary. The value is found using binary
# search to obtain the target vocabulary size.
_MIN_MIN_COUNT = 1 # min value to use when binary searching for min_count
_MAX_MIN_COUNT = 1000 # max value to use when binary searching for min_count
class Subtokenizer(object):
"""Encodes and decodes strings to/from integer IDs."""
def __init__(self, vocab_file, reserved_tokens=None):
"""Initializes class, creating a vocab file if data_files is provided."""
tf.logging.info("Initializing Subtokenizer from file %s." % vocab_file)
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
self.subtoken_list = _load_vocab_file(vocab_file, reserved_tokens)
self.alphabet = _generate_alphabet_dict(self.subtoken_list)
self.subtoken_to_id_dict = _list_to_index_dict(self.subtoken_list)
self.max_subtoken_length = 0
for subtoken in self.subtoken_list:
self.max_subtoken_length = max(self.max_subtoken_length, len(subtoken))
# Create cache to speed up subtokenization
self._cache_size = 2 ** 20
self._cache = [(None, None)] * self._cache_size
@staticmethod
def init_from_files(
vocab_file, files, target_vocab_size, threshold, min_count=None,
file_byte_limit=1e6, reserved_tokens=None):
"""Create subtoken vocabulary based on files, and save vocab to file.
Args:
vocab_file: String name of vocab file to store subtoken vocabulary.
files: List of file paths that will be used to generate vocabulary.
target_vocab_size: target vocabulary size to generate.
threshold: int threshold of vocabulary size to accept.
min_count: int minimum count to use for generating the vocabulary. The min
count is the minimum number of times a subtoken should appear in the
files before it is added to the vocabulary. If set to none, this value
is found using binary search.
file_byte_limit: (Default 1e6) Maximum number of bytes of sample text that
will be drawn from the files.
reserved_tokens: List of string tokens that are guaranteed to be at the
beginning of the subtoken vocabulary list.
Returns:
Subtokenizer object
"""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
if tf.gfile.Exists(vocab_file):
tf.logging.info("Vocab file already exists (%s)" % vocab_file)
else:
tf.logging.info("Begin steps to create subtoken vocabulary...")
token_counts = _count_tokens(files, file_byte_limit)
alphabet = _generate_alphabet_dict(token_counts)
subtoken_list = _generate_subtokens_with_target_vocab_size(
token_counts, alphabet, target_vocab_size, threshold, min_count,
reserved_tokens)
tf.logging.info("Generated vocabulary with %d subtokens." %
len(subtoken_list))
_save_vocab_file(vocab_file, subtoken_list)
return Subtokenizer(vocab_file)
def encode(self, raw_string, add_eos=False):
"""Encodes a string into a list of int subtoken ids."""
ret = []
tokens = _split_string_to_tokens(_native_to_unicode(raw_string))
for token in tokens:
ret.extend(self._token_to_subtoken_ids(token))
if add_eos:
ret.append(EOS_ID)
return ret
def _token_to_subtoken_ids(self, token):
"""Encode a single token into a list of subtoken ids."""
cache_location = hash(token) % self._cache_size
cache_key, cache_value = self._cache[cache_location]
if cache_key == token:
return cache_value
ret = _split_token_to_subtokens(
_escape_token(token, self.alphabet), self.subtoken_to_id_dict,
self.max_subtoken_length)
ret = [self.subtoken_to_id_dict[subtoken_id] for subtoken_id in ret]
self._cache[cache_location] = (token, ret)
return ret
def decode(self, subtokens):
"""Converts list of int subtokens ids into a string."""
if isinstance(subtokens, np.ndarray):
# Note that list(subtokens) converts subtokens to a python list, but the
# items remain as np.int32. This converts both the array and its items.
subtokens = subtokens.tolist()
if not subtokens:
return ""
assert isinstance(subtokens, list) and isinstance(subtokens[0], int), (
"Subtokens argument passed into decode() must be a list of integers.")
return _unicode_to_native(
_join_tokens_to_string(self._subtoken_ids_to_tokens(subtokens)))
def _subtoken_ids_to_tokens(self, subtokens):
"""Convert list of int subtoken ids to a list of string tokens."""
escaped_tokens = "".join([
self.subtoken_list[s] for s in subtokens
if s < len(self.subtoken_list)])
escaped_tokens = escaped_tokens.split("_")
# All tokens in the vocabulary list have been escaped (see _escape_token())
# so each token must be unescaped when decoding.
ret = []
for token in escaped_tokens:
if token:
ret.append(_unescape_token(token))
return ret
def _save_vocab_file(vocab_file, subtoken_list):
"""Save subtokens to file."""
with tf.gfile.Open(vocab_file, mode="w") as f:
for subtoken in subtoken_list:
f.write("'%s'\n" % _unicode_to_native(subtoken))
def _load_vocab_file(vocab_file, reserved_tokens=None):
"""Load vocabulary while ensuring reserved tokens are at the top."""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
subtoken_list = []
with tf.gfile.Open(vocab_file, mode="r") as f:
for line in f:
subtoken = _native_to_unicode(line.strip())
subtoken = subtoken[1:-1] # Remove surrounding single-quotes
if subtoken in reserved_tokens:
continue
subtoken_list.append(_native_to_unicode(subtoken))
return reserved_tokens + subtoken_list
def _native_to_unicode(s):
"""Convert string to unicode (required in Python 2)."""
if six.PY2:
return s if isinstance(s, unicode) else s.decode("utf-8")
else:
return s
def _unicode_to_native(s):
"""Convert string from unicode to native format (required in Python 2)."""
if six.PY2:
return s.encode("utf-8") if isinstance(s, unicode) else s
else:
return s
def _split_string_to_tokens(text):
"""Splits text to a list of string tokens."""
if not text:
return []
ret = []
token_start = 0
# Classify each character in the input string
is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text]
for pos in xrange(1, len(text)):
if is_alnum[pos] != is_alnum[pos - 1]:
token = text[token_start:pos]
if token != u" " or token_start == 0:
ret.append(token)
token_start = pos
final_token = text[token_start:]
ret.append(final_token)
return ret
def _join_tokens_to_string(tokens):
"""Join a list of string tokens into a single string."""
token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
ret = []
for i, token in enumerate(tokens):
if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
ret.append(u" ")
ret.append(token)
return "".join(ret)
def _escape_token(token, alphabet):
r"""Replace characters that aren't in the alphabet and append "_" to token.
Apply three transformations to the token:
1. Replace underline character "_" with "\u", and backslash "\" with "\\".
2. Replace characters outside of the alphabet with "\###;", where ### is the
character's Unicode code point.
3. Appends "_" to mark the end of a token.
Args:
token: unicode string to be escaped
alphabet: list of all known characters
Returns:
escaped string
"""
token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
return u"".join(ret) + "_"
def _unescape_token(token):
r"""Replaces escaped characters in the token with their unescaped versions.
Applies inverse transformations as _escape_token():
1. Replace "\u" with "_", and "\\" with "\".
2. Replace "\###;" with the unicode character the ### refers to.
Args:
token: escaped string
Returns:
unescaped string
"""
def match(m):
r"""Returns replacement string for matched object.
Matched objects contain one of the strings that matches the regex pattern:
r"\\u|\\\\|\\([0-9]+);"
The strings can be '\u', '\\', or '\###;' (### is any digit number).
m.group(0) refers to the entire matched string ('\u', '\\', or '\###;').
m.group(1) refers to the first parenthesized subgroup ('###').
m.group(0) exists for all match objects, while m.group(1) exists only for
the string '\###;'.
This function looks to see if m.group(1) exists. If it doesn't, then the
matched string must be '\u' or '\\' . In this case, the corresponding
replacement ('_' and '\') are returned. Note that in python, a single
backslash is written as '\\', and double backslash as '\\\\'.
If m.goup(1) exists, then use the integer in m.group(1) to return a
unicode character.
Args:
m: match object
Returns:
String to replace matched object with.
"""
# Check if the matched strings are '\u' or '\\'.
if m.group(1) is None:
return u"_" if m.group(0) == u"\\u" else u"\\"
# If m.group(1) exists, try and return unicode character.
try:
return six.unichr(int(m.group(1)))
except (ValueError, OverflowError) as _:
return _UNDEFINED_UNICODE
# Use match function to replace escaped substrings in the token.
return _UNESCAPE_REGEX.sub(match, token)
def _count_tokens(files, file_byte_limit=1e6):
"""Return token counts of words in the files.
Samples file_byte_limit bytes from each file, and counts the words that appear
in the samples. The samples are semi-evenly distributed across the file.
Args:
files: List of filepaths
file_byte_limit: Max number of bytes that will be read from each file.
Returns:
Dictionary mapping tokens to the number of times they appear in the sampled
lines from the files.
"""
token_counts = collections.defaultdict(int)
for filepath in files:
with tf.gfile.Open(filepath, mode="r") as reader:
file_byte_budget = file_byte_limit
counter = 0
lines_to_skip = int(reader.size() / (file_byte_budget * 2))
for line in reader:
if counter < lines_to_skip:
counter += 1
else:
if file_byte_budget < 0:
break
line = line.strip()
file_byte_budget -= len(line)
counter = 0
# Add words to token counts
for token in _split_string_to_tokens(_native_to_unicode(line)):
token_counts[token] += 1
return token_counts
def _list_to_index_dict(lst):
"""Create dictionary mapping list items to their indices in the list."""
return {item: n for n, item in enumerate(lst)}
def _split_token_to_subtokens(token, subtoken_dict, max_subtoken_length):
"""Splits a token into subtokens defined in the subtoken dict."""
ret = []
start = 0
token_len = len(token)
while start < token_len:
# Find the longest subtoken, so iterate backwards.
for end in xrange(min(token_len, start + max_subtoken_length), start, -1):
subtoken = token[start:end]
if subtoken in subtoken_dict:
ret.append(subtoken)
start = end
break
else: # Did not break
# If there is no possible encoding of the escaped token then one of the
# characters in the token is not in the alphabet. This should be
# impossible and would be indicative of a bug.
raise ValueError("Was unable to split token \"%s\" into subtokens." %
token)
return ret
def _generate_subtokens_with_target_vocab_size(
token_counts, alphabet, target_size, threshold, min_count=None,
reserved_tokens=None):
"""Generate subtoken vocabulary close to the target size."""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
if min_count is not None:
tf.logging.info("Using min_count=%d to generate vocab with target size %d" %
(min_count, target_size))
return _generate_subtokens(
token_counts, alphabet, min_count, reserved_tokens=reserved_tokens)
def bisect(min_val, max_val):
"""Recursive function to binary search for subtoken vocabulary."""
cur_count = (min_val + max_val) // 2
tf.logging.info("Binary search: trying min_count=%d (%d %d)" %
(cur_count, min_val, max_val))
subtoken_list = _generate_subtokens(
token_counts, alphabet, cur_count, reserved_tokens=reserved_tokens)
val = len(subtoken_list)
tf.logging.info("Binary search: min_count=%d resulted in %d tokens" %
(cur_count, val))
within_threshold = abs(val - target_size) < threshold
if within_threshold or min_val >= max_val or cur_count < 2:
return subtoken_list
if val > target_size:
other_subtoken_list = bisect(cur_count + 1, max_val)
else:
other_subtoken_list = bisect(min_val, cur_count - 1)
# Return vocabulary dictionary with the closest number of tokens.
other_val = len(other_subtoken_list)
if abs(other_val - target_size) < abs(val - target_size):
return other_subtoken_list
return subtoken_list
tf.logging.info("Finding best min_count to get target size of %d" %
target_size)
return bisect(_MIN_MIN_COUNT, _MAX_MIN_COUNT)
def _generate_alphabet_dict(iterable, reserved_tokens=None):
"""Create set of characters that appear in any element in the iterable."""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
alphabet = {c for token in iterable for c in token}
alphabet |= {c for token in reserved_tokens for c in token}
alphabet |= _ESCAPE_CHARS # Add escape characters to alphabet set.
return alphabet
def _count_and_gen_subtokens(
token_counts, alphabet, subtoken_dict, max_subtoken_length):
"""Count number of times subtokens appear, and generate new subtokens.
Args:
token_counts: dict mapping tokens to the number of times they appear in the
original files.
alphabet: list of allowed characters. Used to escape the tokens, which
guarantees that all tokens can be split into subtokens.
subtoken_dict: dict mapping subtokens to ids.
max_subtoken_length: maximum length of subtoken in subtoken_dict.
Returns:
A defaultdict mapping subtokens to the number of times they appear in the
tokens. The dict may contain new subtokens.
"""
subtoken_counts = collections.defaultdict(int)
for token, count in six.iteritems(token_counts):
token = _escape_token(token, alphabet)
subtokens = _split_token_to_subtokens(
token, subtoken_dict, max_subtoken_length)
# Generate new subtokens by taking substrings from token.
start = 0
for subtoken in subtokens:
for end in xrange(start + 1, len(token) + 1):
new_subtoken = token[start:end]
subtoken_counts[new_subtoken] += count
start += len(subtoken)
return subtoken_counts
def _filter_and_bucket_subtokens(subtoken_counts, min_count):
"""Return a bucketed list of subtokens that are filtered by count.
Args:
subtoken_counts: defaultdict mapping subtokens to their counts
min_count: int count used to filter subtokens
Returns:
List of subtoken sets, where subtokens in set i have the same length=i.
"""
# Create list of buckets, where subtokens in bucket i have length i.
subtoken_buckets = []
for subtoken, count in six.iteritems(subtoken_counts):
if count < min_count: # Filter out subtokens that don't appear enough
continue
while len(subtoken_buckets) <= len(subtoken):
subtoken_buckets.append(set())
subtoken_buckets[len(subtoken)].add(subtoken)
return subtoken_buckets
def _gen_new_subtoken_list(
subtoken_counts, min_count, alphabet, reserved_tokens=None):
"""Generate candidate subtokens ordered by count, and new max subtoken length.
Add subtokens to the candiate list in order of length (longest subtokens
first). When a subtoken is added, the counts of each of its prefixes are
decreased. Prefixes that don't appear much outside the subtoken are not added
to the candidate list.
For example:
subtoken being added to candidate list: 'translate'
subtoken_counts: {'translate':10, 't':40, 'tr':16, 'tra':12, ...}
min_count: 5
When 'translate' is added, subtoken_counts is updated to:
{'translate':0, 't':30, 'tr':6, 'tra': 2, ...}
The subtoken 'tra' will not be added to the candidate list, because it appears
twice (less than min_count) outside of 'translate'.
Args:
subtoken_counts: defaultdict mapping str subtokens to int counts
min_count: int minumum count requirement for subtokens
alphabet: set of characters. Each character is added to the subtoken list to
guarantee that all tokens can be encoded.
reserved_tokens: list of tokens that will be added to the beginning of the
returned subtoken list.
Returns:
List of candidate subtokens in decreasing count order, and maximum subtoken
length
"""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
# Create a list of (count, subtoken) for each candidate subtoken.
subtoken_candidates = []
# Use bucketted list to iterate through subtokens in order of length.
# subtoken_buckets[i] = set(subtokens), where each subtoken has length i.
subtoken_buckets = _filter_and_bucket_subtokens(subtoken_counts, min_count)
max_subtoken_length = len(subtoken_buckets) - 1
# Go through the list in reverse order to consider longer subtokens first.
for subtoken_len in xrange(max_subtoken_length, 0, -1):
for subtoken in subtoken_buckets[subtoken_len]:
count = subtoken_counts[subtoken]
# Possible if this subtoken is a prefix of another token.
if count < min_count:
continue
# Ignore alphabet/reserved tokens, which will be added manually later.
if subtoken not in alphabet and subtoken not in reserved_tokens:
subtoken_candidates.append((count, subtoken))
# Decrement count of the subtoken's prefixes (if a longer subtoken is
# added, its prefixes lose priority to be added).
for end in xrange(1, subtoken_len):
subtoken_counts[subtoken[:end]] -= count
# Add alphabet subtokens (guarantees that all strings are encodable).
subtoken_candidates.extend((subtoken_counts.get(a, 0), a) for a in alphabet)
# Order subtoken candidates by decreasing count.
subtoken_list = [t for _, t in sorted(subtoken_candidates, reverse=True)]
# Add reserved tokens to beginning of the list.
subtoken_list = reserved_tokens + subtoken_list
return subtoken_list, max_subtoken_length
def _generate_subtokens(
token_counts, alphabet, min_count, num_iterations=4,
reserved_tokens=None):
"""Create a list of subtokens in decreasing order of frequency.
Args:
token_counts: dict mapping str tokens -> int count
alphabet: set of characters
min_count: int minimum number of times a subtoken must appear before it is
added to the vocabulary.
num_iterations: int number of iterations to generate new tokens.
reserved_tokens: list of tokens that will be added to the beginning to the
returned subtoken list.
Returns:
Sorted list of subtokens (most frequent first)
"""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
# Use alphabet set to create initial list of subtokens
subtoken_list = reserved_tokens + list(alphabet)
max_subtoken_length = 1
# On each iteration, segment all words using the subtokens defined in
# subtoken_dict, count how often the resulting subtokens appear, and update
# the dictionary with subtokens w/ high enough counts.
for i in xrange(num_iterations):
tf.logging.info("\tGenerating subtokens: iteration %d" % i)
# Generate new subtoken->id dictionary using the new subtoken list.
subtoken_dict = _list_to_index_dict(subtoken_list)
# Create dict mapping subtoken->count, with additional subtokens created
# from substrings taken from the tokens.
subtoken_counts = _count_and_gen_subtokens(
token_counts, alphabet, subtoken_dict, max_subtoken_length)
# Generate new list of subtokens sorted by subtoken count.
subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
subtoken_counts, min_count, alphabet, reserved_tokens)
tf.logging.info("\tVocab size: %d" % len(subtoken_list))
return subtoken_list
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test Subtokenizer and string helper methods."""
import collections
import tempfile
import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer.utils import tokenizer
class SubtokenizerTest(unittest.TestCase):
def _init_subtokenizer(self, vocab_list):
temp_file = tempfile.NamedTemporaryFile(delete=False)
with tf.gfile.Open(temp_file.name, 'w') as w:
for subtoken in vocab_list:
w.write("'%s'" % subtoken)
w.write("\n")
return tokenizer.Subtokenizer(temp_file.name, reserved_tokens=[])
def test_encode(self):
vocab_list = ["123_", "test", "ing_"]
subtokenizer = self._init_subtokenizer(vocab_list)
s = "testing 123"
encoded_list = subtokenizer.encode(s)
self.assertEqual([1, 2, 0], encoded_list)
def test_decode(self):
vocab_list = ["123_", "test", "ing_"]
subtokenizer = self._init_subtokenizer(vocab_list)
encoded_list = [1, 2, 0] # testing 123
decoded_str = subtokenizer.decode(encoded_list)
self.assertEqual("testing 123", decoded_str)
def test_subtoken_ids_to_tokens(self):
vocab_list = ["123_", "test", "ing_"]
subtokenizer = self._init_subtokenizer(vocab_list)
encoded_list = [1, 2, 0] # testing 123
token_list = subtokenizer._subtoken_ids_to_tokens(encoded_list)
self.assertEqual([u"testing", u"123"], token_list)
class StringHelperTest(unittest.TestCase):
def test_split_string_to_tokens(self):
text = "test? testing 123."
tokens = tokenizer._split_string_to_tokens(text)
self.assertEqual(["test", "? ", "testing", "123", "."], tokens)
def test_join_tokens_to_string(self):
tokens = ["test", "? ", "testing", "123", "."]
s = tokenizer._join_tokens_to_string(tokens)
self.assertEqual("test? testing 123.", s)
def test_escape_token(self):
token = u"abc_\\4"
alphabet = set("abc_\\u;")
escaped_token = tokenizer._escape_token(token, alphabet)
self.assertEqual("abc\\u\\\\\\52;_", escaped_token)
def test_unescape_token(self):
escaped_token = u"Underline: \\u, Backslash: \\\\, Unicode: \\52;"
unescaped_token = tokenizer._unescape_token(escaped_token)
self.assertEqual(
"Underline: _, Backslash: \\, Unicode: 4", unescaped_token)
def test_list_to_index_dict(self):
lst = ["test", "strings"]
d = tokenizer._list_to_index_dict(lst)
self.assertDictEqual({"test": 0, "strings": 1}, d)
def test_split_token_to_subtokens(self):
token = "abc"
subtoken_dict = {"a": 0, "b": 1, "c": 2, "ab": 3}
max_subtoken_length = 2
subtokens = tokenizer._split_token_to_subtokens(
token, subtoken_dict, max_subtoken_length)
self.assertEqual(["ab", "c"], subtokens)
def test_generate_alphabet_dict(self):
s = ["testing", "123"]
reserved_tokens = ["???"]
alphabet = tokenizer._generate_alphabet_dict(s, reserved_tokens)
self.assertIn("?", alphabet)
self.assertIn("t", alphabet)
self.assertIn("e", alphabet)
self.assertIn("s", alphabet)
self.assertIn("i", alphabet)
self.assertIn("n", alphabet)
self.assertIn("g", alphabet)
self.assertIn("1", alphabet)
self.assertIn("2", alphabet)
self.assertIn("3", alphabet)
def test_count_and_gen_subtokens(self):
token_counts = {"abc": 5}
alphabet = set("abc_")
subtoken_dict = {"a": 0, "b": 1, "c": 2, "_": 3}
max_subtoken_length = 2
subtoken_counts = tokenizer._count_and_gen_subtokens(
token_counts, alphabet, subtoken_dict, max_subtoken_length)
self.assertIsInstance(subtoken_counts, collections.defaultdict)
self.assertDictEqual(
{"a": 5, "b": 5, "c": 5, "_": 5, "ab": 5, "bc": 5, "c_": 5,
"abc": 5, "bc_": 5, "abc_": 5}, subtoken_counts)
def test_filter_and_bucket_subtokens(self):
subtoken_counts = collections.defaultdict(
int, {"a": 2, "b": 4, "c": 1, "ab": 6, "ac": 3, "abbc": 5})
min_count = 3
subtoken_buckets = tokenizer._filter_and_bucket_subtokens(
subtoken_counts, min_count)
self.assertEqual(len(subtoken_buckets[0]), 0)
self.assertEqual(set("b"), subtoken_buckets[1])
self.assertEqual(set(["ab", "ac"]), subtoken_buckets[2])
self.assertEqual(len(subtoken_buckets[3]), 0)
self.assertEqual(set(["abbc"]), subtoken_buckets[4])
def test_gen_new_subtoken_list(self):
subtoken_counts = collections.defaultdict(
int, {"translate": 10, "t": 40, "tr": 16, "tra": 12})
min_count = 5
alphabet = set("translate")
reserved_tokens = ["reserved", "tokens"]
subtoken_list, max_token_length = tokenizer._gen_new_subtoken_list(
subtoken_counts, min_count, alphabet, reserved_tokens)
# Check that "tra" isn"t in the list (its count should be decremented to 2,
# so it should not be added to the canddiate list).
self.assertNotIn("tra", subtoken_list)
self.assertIn("tr", subtoken_list)
self.assertIn("t", subtoken_list)
self.assertEqual(len("translate"), max_token_length)
def test_generate_subtokens(self):
token_counts = {"ab": 1, "bc": 3, "abc": 5}
alphabet = set("abc_")
min_count = 100
num_iterations = 1
reserved_tokens = ["reserved", "tokens"]
vocab_list = tokenizer._generate_subtokens(
token_counts, alphabet, min_count, num_iterations, reserved_tokens)
# Check that reserved tokens are at the front of the list
self.assertEqual(vocab_list[:2], reserved_tokens)
# Check that each character in alphabet is in the vocab list
for c in alphabet:
self.assertIn(c, vocab_list)
if __name__ == "__main__":
unittest.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Collection of parsers which are shared among the official models.
The parsers in this module are intended to be used as parents to all arg
parsers in official models. For instance, one might define a new class:
class ExampleParser(argparse.ArgumentParser):
def __init__(self):
super(ExampleParser, self).__init__(parents=[
arg_parsers.LocationParser(data_dir=True, model_dir=True),
arg_parsers.DummyParser(use_synthetic_data=True),
])
self.add_argument(
"--application_specific_arg", "-asa", type=int, default=123,
help="[default: %(default)s] This arg is application specific.",
metavar="<ASA>"
)
Notes about add_argument():
Argparse will automatically template in default values in help messages if
the "%(default)s" string appears in the message. Using the example above:
parser = ExampleParser()
parser.set_defaults(application_specific_arg=3141592)
parser.parse_args(["-h"])
When the help text is generated, it will display 3141592 to the user. (Even
though the default was 123 when the flag was created.)
The metavar variable determines how the flag will appear in help text. If
not specified, the convention is to use name.upper(). Thus rather than:
--app_specific_arg APP_SPECIFIC_ARG, -asa APP_SPECIFIC_ARG
if metavar="<ASA>" is set, the user sees:
--app_specific_arg <ASA>, -asa <ASA>
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import tensorflow as tf
# Map string to (TensorFlow dtype, default loss scale)
DTYPE_MAP = {
"fp16": (tf.float16, 128),
"fp32": (tf.float32, 1),
}
def parse_dtype_info(flags):
"""Convert dtype string to tf dtype, and set loss_scale default as needed.
Args:
flags: namespace object returned by arg parser.
Raises:
ValueError: If an invalid dtype is provided.
"""
if flags.dtype in (i[0] for i in DTYPE_MAP.values()):
return # Make function idempotent
try:
flags.dtype, default_loss_scale = DTYPE_MAP[flags.dtype]
except KeyError:
raise ValueError("Invalid dtype: {}".format(flags.dtype))
flags.loss_scale = flags.loss_scale or default_loss_scale
class BaseParser(argparse.ArgumentParser):
"""Parser to contain flags which will be nearly universal across models.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
data_dir: Create a flag for specifying the input data directory.
model_dir: Create a flag for specifying the model file directory.
train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other
eval metric which should trigger the end of training.
batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs.
hooks: Create a flag to specify hooks for logging.
export_dir: Create a flag to specify where a SavedModel should be exported.
"""
def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_between_evals=True,
stop_threshold=True, batch_size=True, multi_gpu=True,
hooks=True, export_dir=True):
super(BaseParser, self).__init__(add_help=add_help)
if data_dir:
self.add_argument(
"--data_dir", "-dd", default="/tmp",
help="[default: %(default)s] The location of the input data.",
metavar="<DD>",
)
if model_dir:
self.add_argument(
"--model_dir", "-md", default="/tmp",
help="[default: %(default)s] The location of the model checkpoint "
"files.",
metavar="<MD>",
)
if train_epochs:
self.add_argument(
"--train_epochs", "-te", type=int, default=1,
help="[default: %(default)s] The number of epochs used to train.",
metavar="<TE>"
)
if epochs_between_evals:
self.add_argument(
"--epochs_between_evals", "-ebe", type=int, default=1,
help="[default: %(default)s] The number of training epochs to run "
"between evaluations.",
metavar="<EBE>"
)
if stop_threshold:
self.add_argument(
"--stop_threshold", "-st", type=float, default=None,
help="[default: %(default)s] If passed, training will stop at "
"the earlier of train_epochs and when the evaluation metric is "
"greater than or equal to stop_threshold.",
metavar="<ST>"
)
if batch_size:
self.add_argument(
"--batch_size", "-bs", type=int, default=32,
help="[default: %(default)s] Batch size for training and evaluation.",
metavar="<BS>"
)
if multi_gpu:
self.add_argument(
"--multi_gpu", action="store_true",
help="If set, run across all available GPUs."
)
if hooks:
self.add_argument(
"--hooks", "-hk", nargs="+", default=["LoggingTensorHook"],
help="[default: %(default)s] A list of strings to specify the names "
"of train hooks. "
"Example: --hooks LoggingTensorHook ExamplesPerSecondHook. "
"Allowed hook names (case-insensitive): LoggingTensorHook, "
"ProfilerHook, ExamplesPerSecondHook, LoggingMetricHook."
"See official.utils.logs.hooks_helper for details.",
metavar="<HK>"
)
if export_dir:
self.add_argument(
"--export_dir", "-ed",
help="[default: %(default)s] If set, a SavedModel serialization of "
"the model will be exported to this directory at the end of "
"training. See the README for more details and relevant links.",
metavar="<ED>"
)
class PerformanceParser(argparse.ArgumentParser):
"""Default parser for specifying performance tuning arguments.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
num_parallel_calls: Create a flag to specify parallelism of data loading.
inter_op: Create a flag to allow specification of inter op threads.
intra_op: Create a flag to allow specification of intra op threads.
"""
def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True, max_train_steps=True,
dtype=True):
super(PerformanceParser, self).__init__(add_help=add_help)
if num_parallel_calls:
self.add_argument(
"--num_parallel_calls", "-npc",
type=int, default=5,
help="[default: %(default)s] The number of records that are "
"processed in parallel during input processing. This can be "
"optimized per data set but for generally homogeneous data "
"sets, should be approximately the number of available CPU "
"cores.",
metavar="<NPC>"
)
if inter_op:
self.add_argument(
"--inter_op_parallelism_threads", "-inter",
type=int, default=0,
help="[default: %(default)s Number of inter_op_parallelism_threads "
"to use for CPU. See TensorFlow config.proto for details.",
metavar="<INTER>"
)
if intra_op:
self.add_argument(
"--intra_op_parallelism_threads", "-intra",
type=int, default=0,
help="[default: %(default)s Number of intra_op_parallelism_threads "
"to use for CPU. See TensorFlow config.proto for details.",
metavar="<INTRA>"
)
if use_synthetic_data:
self.add_argument(
"--use_synthetic_data", "-synth",
action="store_true",
help="If set, use fake data (zeroes) instead of a real dataset. "
"This mode is useful for performance debugging, as it removes "
"input processing steps, but will not learn anything."
)
if max_train_steps:
self.add_argument(
"--max_train_steps", "-mts", type=int, default=None,
help="[default: %(default)s] The model will stop training if the "
"global_step reaches this value. If not set, training will run"
"until the specified number of epochs have run as usual. It is"
"generally recommended to set --train_epochs=1 when using this"
"flag.",
metavar="<MTS>"
)
if dtype:
self.add_argument(
"--dtype", "-dt",
default="fp32",
choices=list(DTYPE_MAP.keys()),
help="[default: %(default)s] {%(choices)s} The TensorFlow datatype "
"used for calculations. Variables may be cast to a higher"
"precision on a case-by-case basis for numerical stability.",
metavar="<DT>"
)
self.add_argument(
"--loss_scale", "-ls",
type=int,
help="[default: %(default)s] The amount to scale the loss by when "
"the model is run. Before gradients are computed, the loss is "
"multiplied by the loss scale, making all gradients loss_scale "
"times larger. To adjust for this, gradients are divided by the "
"loss scale before being applied to variables. This is "
"mathematically equivalent to training without a loss scale, "
"but the loss scale helps avoid some intermediate gradients "
"from underflowing to zero. If not provided the default for "
"fp16 is 128 and 1 for all other dtypes.",
)
class ImageModelParser(argparse.ArgumentParser):
"""Default parser for specification image specific behavior.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
data_format: Create a flag to specify image axis convention.
"""
def __init__(self, add_help=False, data_format=True):
super(ImageModelParser, self).__init__(add_help=add_help)
if data_format:
self.add_argument(
"--data_format", "-df",
default=None,
choices=["channels_first", "channels_last"],
help="A flag to override the data format used in the model. "
"channels_first provides a performance boost on GPU but is not "
"always compatible with CPU. If left unspecified, the data "
"format will be chosen automatically based on whether TensorFlow"
"was built for CPU or GPU.",
metavar="<CF>"
)
class BenchmarkParser(argparse.ArgumentParser):
"""Default parser for benchmark logging.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
benchmark_log_dir: Create a flag to specify location for benchmark logging.
"""
def __init__(self, add_help=False, benchmark_log_dir=True,
bigquery_uploader=True):
super(BenchmarkParser, self).__init__(add_help=add_help)
if benchmark_log_dir:
self.add_argument(
"--benchmark_log_dir", "-bld", default=None,
help="[default: %(default)s] The location of the benchmark logging.",
metavar="<BLD>"
)
if bigquery_uploader:
self.add_argument(
"--gcp_project", "-gp", default=None,
help="[default: %(default)s] The GCP project name where the benchmark"
" will be uploaded.",
metavar="<GP>"
)
self.add_argument(
"--bigquery_data_set", "-bds", default="test_benchmark",
help="[default: %(default)s] The Bigquery dataset name where the"
" benchmark will be uploaded.",
metavar="<BDS>"
)
self.add_argument(
"--bigquery_run_table", "-brt", default="benchmark_run",
help="[default: %(default)s] The Bigquery table name where the"
" benchmark run information will be uploaded.",
metavar="<BRT>"
)
self.add_argument(
"--bigquery_metric_table", "-bmt", default="benchmark_metric",
help="[default: %(default)s] The Bigquery table name where the"
" benchmark metric information will be uploaded.",
metavar="<BMT>"
)
class EagerParser(BaseParser):
"""Remove options not relevant for Eager from the BaseParser."""
def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, batch_size=True):
super(EagerParser, self).__init__(
add_help=add_help, data_dir=data_dir, model_dir=model_dir,
train_epochs=train_epochs, epochs_between_evals=False,
stop_threshold=False, batch_size=batch_size, multi_gpu=False,
hooks=False)
# Adding Abseil (absl) flags quickstart
## Defining a flag
absl flag definitions are similar to argparse, although they are defined on a global namespace.
For instance defining a string flag looks like:
```$xslt
from absl import flags
flags.DEFINE_string(
name="my_flag",
default="a_sensible_default",
help="Here is what this flag does."
)
```
All three arguments are required, but default may be `None`. A common optional argument is
short_name for defining abreviations. Certain `DEFINE_*` methods will have other required arguments.
For instance `DEFINE_enum` requires the `enum_values` argument to be specified.
## Key Flags
absl has the concept of a key flag. Any flag defined in `__main__` is considered a key flag by
default. Key flags are displayed in `--help`, others only appear in `--helpfull`. In order to
handle key flags that are defined outside the module in question, absl provides the
`flags.adopt_module_key_flags()` method. This adds the key flags of a different module to one's own
key flags. For example:
```$xslt
File: flag_source.py
---------------------------------------
from absl import flags
flags.DEFINE_string(name="my_flag", default="abc", help="a flag.")
```
```$xslt
File: my_module.py
---------------------------------------
from absl import app as absl_app
from absl import flags
import flag_source
flags.adopt_module_key_flags(flag_source)
def main(_):
pass
absl_app.run(main, [__file__, "-h"]
```
when `my_module.py` is run it will show the help text for `my_flag`. Because not all flags defined
in a file are equally important, `official/utils/flags/core.py` (generally imported as flags_core)
provides an abstraction for handling key flag declaration in an easy way through the
`register_key_flags_in_core()` function, which allows a module to make a single
`adopt_key_flags(flags_core)` call when using the util flag declaration functions.
## Validators
Often the constraints on a flag are complicated. absl provides the validator decorator to allow
one to mark a function as a flag validation function. Suppose we want users to provide a flag
which is a palindrome.
```$xslt
from absl import flags
flags.DEFINE_string(name="pal_flag", short_name="pf", default="", help="Give me a palindrome")
@flags.validator("pal_flag")
def _check_pal(provided_pal_flag):
return provided_pal_flag == provided_pal_flag[::-1]
```
Validators take the form that returning True (truthy) passes, and all others
(False, None, exception) fail.
## Common Flags
Common flags (i.e. batch_size, model_dir, etc.) are provided by various flag definition functions,
and channeled through `official.utils.flags.core`. For instance to define common supervised
learning parameters one could use the following code:
```$xslt
from absl import app as absl_app
from absl import flags
from official.utils.flags import core as flags_core
def define_flags():
flags_core.define_base()
flags.adopt_key_flags(flags_core)
def main(flags_obj):
pass
if __name__ == "__main__"
absl_app.run(main)
```
## Testing
To test using absl, simply declare flags in the setupClass method of TensorFlow's TestCase.
```$xslt
from absl import flags
import tensorflow as tf
def define_flags():
flags.DEFINE_string(name="test_flag", default="abc", help="an example flag")
class BaseTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(BaseTester, cls).setUpClass()
define_flags()
def test_trivial(self):
flags_core.parse_flags([__file__, "test_flag", "def"])
self.AssertEqual(flags.FLAGS.test_flag, "def")
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment