Commit 88253ce5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 326286926
parent 52371ffe
......@@ -29,8 +29,10 @@ _NEG_INF_FP32 = -1e9
_NEG_INF_FP16 = np.finfo(np.float16).min
def get_position_encoding(
length, hidden_size, min_timescale=1.0, max_timescale=1.0e4):
def get_position_encoding(length,
hidden_size,
min_timescale=1.0,
max_timescale=1.0e4):
"""Return positional encoding.
Calculates the position encoding as a mix of sine and cosine functions with
......@@ -77,8 +79,8 @@ def get_decoder_self_attention_bias(length, dtype=tf.float32):
"""
neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32
with tf.name_scope("decoder_self_attention_bias"):
valid_locs = tf.linalg.band_part(tf.ones([length, length], dtype=dtype),
-1, 0)
valid_locs = tf.linalg.band_part(
tf.ones([length, length], dtype=dtype), -1, 0)
valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
decoder_bias = neg_inf * (1.0 - valid_locs)
return decoder_bias
......
......@@ -40,22 +40,19 @@ class ModelUtilsTest(tf.test.TestCase):
bias_shape = tf.shape(bias)
flattened_bias = tf.reshape(bias, [3, 5])
self.assertAllEqual([[0, NEG_INF, NEG_INF, NEG_INF, 0],
[0, 0, NEG_INF, NEG_INF, NEG_INF],
[NEG_INF, 0, 0, NEG_INF, 0]],
flattened_bias)
self.assertAllEqual(
[[0, NEG_INF, NEG_INF, NEG_INF, 0], [0, 0, NEG_INF, NEG_INF, NEG_INF],
[NEG_INF, 0, 0, NEG_INF, 0]], flattened_bias)
self.assertAllEqual([3, 1, 1, 5], bias_shape)
def test_get_decoder_self_attention_bias(self):
length = 5
bias = model_utils.get_decoder_self_attention_bias(length)
self.assertAllEqual([[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
[0, 0, NEG_INF, NEG_INF, NEG_INF],
[0, 0, 0, NEG_INF, NEG_INF],
[0, 0, 0, 0, NEG_INF],
[0, 0, 0, 0, 0]]]],
bias)
self.assertAllEqual(
[[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
[0, 0, NEG_INF, NEG_INF, NEG_INF], [0, 0, 0, NEG_INF, NEG_INF],
[0, 0, 0, 0, NEG_INF], [0, 0, 0, 0, 0]]]], bias)
if __name__ == "__main__":
......
......@@ -31,7 +31,6 @@ from official.nlp.transformer import metrics
from official.nlp.transformer import model_utils
from official.nlp.transformer.utils.tokenizer import EOS_ID
# Disable the not-callable lint error, since it claims many objects are not
# callable when they actually are.
# pylint: disable=not-callable
......@@ -49,11 +48,12 @@ def create_model(params, is_train):
label_smoothing = params["label_smoothing"]
if params["enable_metrics_in_training"]:
logits = metrics.MetricLayer(vocab_size)([logits, targets])
logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
dtype=tf.float32)(logits)
logits = tf.keras.layers.Lambda(
lambda x: x, name="logits", dtype=tf.float32)(
logits)
model = tf.keras.Model([inputs, targets], logits)
loss = metrics.transformer_loss(
logits, targets, label_smoothing, vocab_size)
loss = metrics.transformer_loss(logits, targets, label_smoothing,
vocab_size)
model.add_loss(loss)
return model
......@@ -130,9 +130,7 @@ class Transformer(tf.keras.Model):
"Padded decoding on CPU/GPUs is not supported.")
decode_batch_size = int(self.params["decode_batch_size"] /
self.params["num_replicas"])
inputs.set_shape([
decode_batch_size, self.params["decode_max_length"]
])
inputs.set_shape([decode_batch_size, self.params["decode_max_length"]])
# Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well.
......@@ -314,15 +312,13 @@ class Transformer(tf.keras.Model):
cache = {
"layer_%d" % layer: {
"k":
tf.zeros([
batch_size, init_decode_length, num_heads, dim_per_head
],
dtype=self.params["dtype"]),
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=self.params["dtype"]),
"v":
tf.zeros([
batch_size, init_decode_length, num_heads, dim_per_head
],
dtype=self.params["dtype"])
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=self.params["dtype"])
} for layer in range(self.params["num_hidden_layers"])
}
# pylint: enable=g-complex-comprehension
......@@ -512,15 +508,14 @@ class DecoderStack(tf.keras.layers.Layer):
"""Return the output of the decoder layer stacks.
Args:
decoder_inputs: A tensor with shape
[batch_size, target_length, hidden_size].
encoder_outputs: A tensor with shape
[batch_size, input_length, hidden_size]
decoder_self_attention_bias: A tensor with shape
[1, 1, target_len, target_length], the bias for decoder self-attention
layer.
attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
the bias for encoder-decoder attention layer.
decoder_inputs: A tensor with shape [batch_size, target_length,
hidden_size].
encoder_outputs: A tensor with shape [batch_size, input_length,
hidden_size]
decoder_self_attention_bias: A tensor with shape [1, 1, target_len,
target_length], the bias for decoder self-attention layer.
attention_bias: A tensor with shape [batch_size, 1, 1, input_length], the
bias for encoder-decoder attention layer.
training: A bool, whether in training mode or not.
cache: (Used for fast decoding) A nested dictionary storing previous
decoder self-attention values. The items are:
......
......@@ -34,11 +34,12 @@ class TransformerLayersTest(tf.test.TestCase):
dropout = 0.5
dim_per_head = hidden_size // num_heads
layer = attention_layer.SelfAttention(hidden_size, num_heads, dropout)
self.assertDictEqual(layer.get_config(), {
"hidden_size": hidden_size,
"num_heads": num_heads,
"attention_dropout": dropout,
})
self.assertDictEqual(
layer.get_config(), {
"hidden_size": hidden_size,
"num_heads": num_heads,
"attention_dropout": dropout,
})
length = 2
x = tf.ones([1, length, hidden_size])
bias = tf.ones([1])
......@@ -47,9 +48,23 @@ class TransformerLayersTest(tf.test.TestCase):
"v": tf.zeros([1, 0, num_heads, dim_per_head]),
}
y = layer(x, bias, training=True, cache=cache)
self.assertEqual(y.shape, (1, length, 64,))
self.assertEqual(cache["k"].shape, (1, length, num_heads, dim_per_head,))
self.assertEqual(cache["v"].shape, (1, length, num_heads, dim_per_head,))
self.assertEqual(y.shape, (
1,
length,
64,
))
self.assertEqual(cache["k"].shape, (
1,
length,
num_heads,
dim_per_head,
))
self.assertEqual(cache["v"].shape, (
1,
length,
num_heads,
dim_per_head,
))
def test_embedding_shared_weights(self):
vocab_size = 50
......@@ -63,25 +78,38 @@ class TransformerLayersTest(tf.test.TestCase):
idx = tf.ones([1, length], dtype="int32")
y = layer(idx)
self.assertEqual(y.shape, (1, length, hidden_size,))
self.assertEqual(y.shape, (
1,
length,
hidden_size,
))
x = tf.ones([1, length, hidden_size])
output = layer(x, "linear")
self.assertEqual(output.shape, (1, length, vocab_size,))
self.assertEqual(output.shape, (
1,
length,
vocab_size,
))
def test_feed_forward_network(self):
hidden_size = 64
filter_size = 32
relu_dropout = 0.5
layer = ffn_layer.FeedForwardNetwork(hidden_size, filter_size, relu_dropout)
self.assertDictEqual(layer.get_config(), {
"hidden_size": hidden_size,
"filter_size": filter_size,
"relu_dropout": relu_dropout,
})
self.assertDictEqual(
layer.get_config(), {
"hidden_size": hidden_size,
"filter_size": filter_size,
"relu_dropout": relu_dropout,
})
length = 2
x = tf.ones([1, length, hidden_size])
y = layer(x, training=True)
self.assertEqual(y.shape, (1, length, hidden_size,))
self.assertEqual(y.shape, (
1,
length,
hidden_size,
))
def test_metric_layer(self):
vocab_size = 50
......@@ -90,7 +118,11 @@ class TransformerLayersTest(tf.test.TestCase):
name="logits")
targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
output_logits = metrics.MetricLayer(vocab_size)([logits, targets])
self.assertEqual(output_logits.shape.as_list(), [None, None, vocab_size,])
self.assertEqual(output_logits.shape.as_list(), [
None,
None,
vocab_size,
])
if __name__ == "__main__":
......
......@@ -24,6 +24,7 @@ from __future__ import print_function
import os
import tempfile
# Import libraries
from absl import app
from absl import flags
from absl import logging
......
......@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import libraries
from absl import logging
import numpy as np
import tensorflow as tf
......
......@@ -22,6 +22,7 @@ import collections
import re
import sys
import unicodedata
from absl import logging
import numpy as np
......@@ -29,7 +30,6 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# pylint: disable=g-complex-comprehension
PAD = "<pad>"
PAD_ID = 0
......
......@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Utilities for pre-processing classification data."""
from absl import logging
from official.nlp.xlnet import data_utils
......
......@@ -22,12 +22,12 @@ from __future__ import print_function
import collections
import json
import os
from absl import logging
import numpy as np
import tensorflow as tf
special_symbols = {
"<unk>": 0,
"<s>": 1,
......@@ -51,10 +51,10 @@ SEG_ID_Q = 1
SEG_ID_CLS = 2
SEG_ID_PAD = 3
OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [
"sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words",
"min_num_words"])
"min_num_words"
])
def file_based_input_fn_builder(input_file, name_to_features, batch_size,
......@@ -253,20 +253,14 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
"""Turn beg and end indices into actual mask."""
non_func_mask = tf.logical_and(
tf.not_equal(inputs, SEP_ID),
tf.not_equal(inputs, CLS_ID))
all_indices = tf.where(
non_func_mask,
tf.range(tgt_len, dtype=tf.int64),
tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
all_indices = tf.where(non_func_mask, tf.range(tgt_len, dtype=tf.int64),
tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
candidate_matrix = tf.cast(
tf.logical_and(
all_indices[None, :] >= beg_indices[:, None],
all_indices[None, :] < end_indices[:, None]),
tf.float32)
tf.logical_and(all_indices[None, :] >= beg_indices[:, None],
all_indices[None, :] < end_indices[:, None]), tf.float32)
cumsum_matrix = tf.reshape(
tf.cumsum(tf.reshape(candidate_matrix, [-1])),
[-1, tgt_len])
tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, tgt_len])
masked_matrix = tf.cast(cumsum_matrix <= num_predict, tf.float32)
target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
is_masked = tf.cast(target_mask, tf.bool)
......@@ -274,8 +268,8 @@ def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
return is_masked, target_mask
def _word_span_mask(inputs, tgt_len, num_predict, min_num_words,
max_num_words, boundary):
def _word_span_mask(inputs, tgt_len, num_predict, min_num_words, max_num_words,
boundary):
"""Sample whole word spans as prediction targets."""
# Note: 1.2 is the token-to-word ratio
mask_alpha = tgt_len / num_predict / 1.2
......@@ -283,7 +277,7 @@ def _word_span_mask(inputs, tgt_len, num_predict, min_num_words,
# Sample span lengths from a zipf distribution
span_len_seq = np.arange(min_num_words, max_num_words + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)
logits = tf.constant(np.log(probs), dtype=tf.float32)
......@@ -302,8 +296,8 @@ def _word_span_mask(inputs, tgt_len, num_predict, min_num_words,
left_ctx_len = round_to_int(left_ctx_len)
right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
beg_indices = (tf.cumsum(left_ctx_len) +
tf.cumsum(right_offset, exclusive=True))
beg_indices = (
tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
end_indices = beg_indices + span_lens
# Remove out of range indices
......@@ -333,7 +327,7 @@ def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
# Sample span lengths from a zipf distribution
span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)
logits = tf.constant(np.log(probs), dtype=tf.float32)
......@@ -353,8 +347,8 @@ def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
# Get the actual begin and end indices
beg_indices = (tf.cumsum(left_ctx_len) +
tf.cumsum(right_offset, exclusive=True))
beg_indices = (
tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
end_indices = beg_indices + span_lens
# Remove out of range indices
......@@ -387,8 +381,7 @@ def _single_token_mask(inputs, tgt_len, num_predict):
"""Sample individual tokens as prediction targets."""
all_indices = tf.range(tgt_len, dtype=tf.int64)
non_func_mask = tf.logical_and(
tf.not_equal(inputs, SEP_ID),
tf.not_equal(inputs, CLS_ID))
tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
masked_pos = tf.random.shuffle(non_func_indices)
......@@ -404,7 +397,10 @@ def _single_token_mask(inputs, tgt_len, num_predict):
return is_masked, target_mask
def _online_sample_masks(inputs, tgt_len, num_predict, online_masking_config,
def _online_sample_masks(inputs,
tgt_len,
num_predict,
online_masking_config,
boundary=None):
"""Sample target positions to predict."""
logging.info("Online sample with strategy: `%s`.",
......@@ -422,8 +418,7 @@ def _online_sample_masks(inputs, tgt_len, num_predict, online_masking_config,
assert boundary is not None, "word span sampling requires `boundary`"
return _word_span_mask(inputs, tgt_len, num_predict,
online_masking_config.min_num_words,
online_masking_config.max_num_words,
boundary)
online_masking_config.max_num_words, boundary)
else:
raise NotImplementedError
......@@ -529,10 +524,11 @@ def create_pretrain_dataset(file_names,
example["target"] = tf.reshape(target, [num_predict])
##### target mask
target_mask = tf.concat(
[tf.ones([actual_num_predict], dtype=tf.float32),
tf.zeros([pad_len], dtype=tf.float32)],
axis=0)
target_mask = tf.concat([
tf.ones([actual_num_predict], dtype=tf.float32),
tf.zeros([pad_len], dtype=tf.float32)
],
axis=0)
example["target_mask"] = tf.reshape(target_mask, [num_predict])
else:
example["target"] = tf.reshape(target, [seq_len])
......@@ -562,7 +558,11 @@ def create_pretrain_dataset(file_names,
return dataset
def format_filename(prefix, suffix, bsz_per_host, seq_len, reuse_len=None,
def format_filename(prefix,
suffix,
bsz_per_host,
seq_len,
reuse_len=None,
uncased=False):
"""Generates input file name pattern."""
if reuse_len is not None and reuse_len > 0:
......@@ -577,8 +577,8 @@ def format_filename(prefix, suffix, bsz_per_host, seq_len, reuse_len=None,
else:
case_str = "uncased."
file_name = "{}.seq-{}.{}{}{}{}".format(
prefix, seq_len, reuse_str, bsz_str, case_str, suffix)
file_name = "{}.seq-{}.{}{}{}{}".format(prefix, seq_len, reuse_str, bsz_str,
case_str, suffix)
return file_name
......@@ -722,9 +722,7 @@ def parse_files_to_dataset(parser,
# even more randomness to the training pipeline.
dataset = dataset.apply(
tf.data.experimental.parallel_interleave(
tf.data.TFRecordDataset,
sloppy=True,
cycle_length=cycle_length))
tf.data.TFRecordDataset, sloppy=True, cycle_length=cycle_length))
buffer_size = 2048
logging.info("Perform sample-level shuffle with size %d", buffer_size)
dataset = dataset.shuffle(buffer_size=buffer_size)
......@@ -778,9 +776,8 @@ def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio):
index = tf.reshape(tf.transpose(index), [-1])
# non-functional tokens
non_func_tokens = tf.logical_not(tf.logical_or(
tf.equal(inputs, SEP_ID),
tf.equal(inputs, CLS_ID)))
non_func_tokens = tf.logical_not(
tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
masked_tokens = tf.logical_and(is_masked, non_func_tokens)
non_masked_or_func_tokens = tf.logical_not(masked_tokens)
......
......@@ -21,6 +21,7 @@ import collections
import csv
import os
# Import libraries
from absl import app
from absl import flags
from absl import logging
......
......@@ -23,6 +23,7 @@ import json
import os
import random
# Import libraries
from absl import app
from absl import flags
import absl.logging as _logging # pylint: disable=unused-import
......
......@@ -21,6 +21,7 @@ from __future__ import print_function
import os
import random
# Import libraries
from absl import app
from absl import flags
from absl import logging
......
......@@ -21,7 +21,6 @@ import unicodedata
import six
SPIECE_UNDERLINE = '▁'
......@@ -95,8 +94,8 @@ def encode_pieces(sp_model, text, return_unicode=True, sample=False):
new_pieces = []
for piece in pieces:
if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
cur_pieces = sp_model.EncodeAsPieces(
piece[:-1].replace(SPIECE_UNDERLINE, ''))
cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(
SPIECE_UNDERLINE, ''))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
if len(cur_pieces[0]) == 1:
cur_pieces = cur_pieces[1:]
......
......@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import functools
# Import libraries
from absl import app
from absl import flags
from absl import logging
......
......@@ -22,6 +22,7 @@ from __future__ import print_function
import functools
import os
# Import libraries
from absl import app
from absl import flags
from absl import logging
......
......@@ -24,6 +24,7 @@ import json
import os
import pickle
# Import libraries
from absl import app
from absl import flags
from absl import logging
......
......@@ -38,12 +38,13 @@ def create_run_config(is_training, is_finetune, flags):
clamp_len=flags.clamp_len)
if not is_finetune:
kwargs.update(dict(
mem_len=flags.mem_len,
reuse_len=flags.reuse_len,
bi_data=flags.bi_data,
clamp_len=flags.clamp_len,
same_length=flags.same_length))
kwargs.update(
dict(
mem_len=flags.mem_len,
reuse_len=flags.reuse_len,
bi_data=flags.bi_data,
clamp_len=flags.clamp_len,
same_length=flags.same_length))
return RunConfig(**kwargs)
......@@ -80,8 +81,10 @@ class XLNetConfig(object):
assert FLAGS is not None or json_path is not None or args_dict is not None
self.keys = ['n_layer', 'd_model', 'n_head', 'd_head', 'd_inner',
'ff_activation', 'untie_r', 'n_token']
self.keys = [
'n_layer', 'd_model', 'n_head', 'd_head', 'd_inner', 'ff_activation',
'untie_r', 'n_token'
]
if FLAGS is not None:
self.init_from_flags(FLAGS)
......@@ -152,17 +155,17 @@ class RunConfig(object):
init_method: str, the initialization scheme, either "normal" or "uniform".
init_range: float, initialize the parameters with a uniform distribution
in [-init_range, init_range]. Only effective when init="uniform".
init_std: float, initialize the parameters with a normal distribution
with mean 0 and stddev init_std. Only effective when init="normal".
init_std: float, initialize the parameters with a normal distribution with
mean 0 and stddev init_std. Only effective when init="normal".
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the currect batch to be cached
and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length
for each token.
reuse_len: int, the number of tokens in the currect batch to be cached and
reused in the future.
bi_data: bool, whether to use bidirectional input pipeline. Usually set to
True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len. -1
means no clamping.
same_length: bool, whether to use the same attention length for each
token.
use_cls_mask: bool, whether to introduce cls mask.
"""
......
......@@ -48,5 +48,6 @@ class PositionalEmbeddingLayerTest(tf.test.TestCase):
logging.info(pos_emb)
self.assertAllClose(pos_emb, target)
if __name__ == "__main__":
tf.test.main()
......@@ -65,7 +65,7 @@ CACHE_INVALIDATION_SEC = 3600 * 24
# == Data Generation ===========================================================
# ==============================================================================
CYCLES_TO_BUFFER = 3 # The number of train cycles worth of data to "run ahead"
# of the main training loop.
# of the main training loop.
# Number of batches to run per epoch when using synthetic data. At high batch
# sizes, we run for more batches than with real data, which is good since
......
......@@ -21,6 +21,7 @@ from __future__ import print_function
import json
# pylint: disable=g-bad-import-order
# Import libraries
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment