Commit 6ff62233 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 404558260
parent 56d36d5c
......@@ -23,3 +23,12 @@ respectively.
* [`DualEncoder`](dual_encoder.py) implements a dual encoder model, suitbale for
retrieval tasks.
* [`Seq2SeqTransformer`](seq2seq_transformer.py) implements the original
Transformer model for seq-to-seq tasks.
* [`T5Transformer`](t5.py) implements a standalone T5 model for seq-to-seq
tasks. The models are compatible with released T5 architecture and converted
checkpoints. The modules are implemented as `tf.Module`. To use with Keras,
users can wrap them within Keras customized layers, i.e. we can define the
modules inside the `__init__` of Keras layer and call the modules in `call`.
......@@ -24,6 +24,8 @@ from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifi
from official.nlp.modeling.models.dual_encoder import DualEncoder
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
from official.nlp.modeling.models.seq2seq_transformer import *
from official.nlp.modeling.models.t5 import T5Transformer
from official.nlp.modeling.models.t5 import T5TransformerParams
from official.nlp.modeling.models.xlnet import XLNetClassifier
from official.nlp.modeling.models.xlnet import XLNetPretrainer
from official.nlp.modeling.models.xlnet import XLNetSpanLabeler
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement T5 Transformer model by TF official NLP library.
Model paper: https://arxiv.org/pdf/1910.10683.pdf
T5TransformerParams and T5Transformer are public interfaces.
Other modules are implementation details, so users should never build libraries
depending on them.
To use with Keras, users can wrap them within Keras customized layers.
"""
import dataclasses
import functools
import math
from typing import Callable, Dict, Optional, Sequence, Text, Union
import numpy as np
import tensorflow as tf
from official.modeling import tf_utils
ShapeLike = Union[int, Sequence[int], tf.TensorShape]
Initializer = Callable[..., tf.Tensor]
class Module(tf.Module):
"""The nn Module extends from the tf.Module."""
def __init__(self, dtype: tf.DType = tf.float32, name: Optional[Text] = None):
"""Initializes the nn Module.
Args:
dtype: the variable allocation dtype.
name: a string for the module name.
"""
super().__init__(name=name)
self.dtype = dtype
def create_variable(self,
name: Text,
shape: ShapeLike,
initializer: Initializer,
dtype: tf.DType = tf.float32,
**kwargs):
return tf.Variable(initializer(shape, dtype=dtype, **kwargs), name=name)
def read_variable(self,
variable: tf.Variable,
as_dtype: Optional[tf.DType] = None):
if as_dtype is not None:
variable = tf.cast(variable, dtype=as_dtype)
return variable
@tf.custom_gradient
def dense_gradient(x: tf.Tensor):
"""Identity operation whose gradient is converted to a ``tf.Tensor``.
>>> embedding = tf.Variable(tf.random.normal([3, 3]))
>>> with tf.GradientTape() as tape:
... y = tf.nn.embedding_lookup(dense_gradient(embedding), [1])
>>> tape.gradient(y, embedding).numpy()
array([[ 0., 0., 0.],
[ 1., 1., 1.],
[ 0., 0., 0.]], dtype=float32)
Args:
x: A ``tf.Tensor``.
Returns:
The input ``tf.Tensor`` and a dense identity gradient function.
"""
def grad(dy):
if isinstance(dy, tf.IndexedSlices):
return tf.convert_to_tensor(dy)
else:
return dy
return x, grad
def make_attention_mask(query_input,
key_input,
pairwise_fn=tf.multiply,
dtype=tf.float32):
"""Mask-making helper for attention weights.
In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the
attention weights will be `[batch..., heads, len_q, len_kv]` and this
function will produce `[batch..., 1, len_q, len_kv]`.
Args:
query_input: a batched, flat input of query_length size
key_input: a batched, flat input of key_length size
pairwise_fn: broadcasting elementwise comparison function
dtype: mask return dtype
Returns:
A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention.
"""
mask = pairwise_fn(
tf.expand_dims(query_input, axis=-1), tf.expand_dims(key_input, axis=-2))
mask = tf.expand_dims(mask, axis=-3)
return tf.cast(mask, dtype=dtype)
def make_causal_mask(x, dtype=tf.float32):
"""Make a causal mask for self-attention.
In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights
will be `[batch..., heads, len, len]` and this function will produce a
causal mask of shape `[batch..., 1, len, len]`.
Args:
x: input array of shape `[batch..., len]`
dtype: mask return dtype
Returns:
A `[batch..., 1, len, len]` shaped causal mask for 1d attention.
"""
x_shape = tf.shape(x)
idxs = tf.broadcast_to(tf.range(x_shape[-1], dtype=tf.int32), x_shape)
return make_attention_mask(idxs, idxs, tf.greater_equal, dtype=dtype)
class Embed(Module):
"""Embedding Module.
A parameterized function from integers [0, n) to d-dimensional vectors.
"""
def __init__(self,
vocab_size: int,
features: int,
embeddings_initializer: Optional[Initializer] = None,
compute_dtype: tf.DType = tf.float32,
**kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.features = features
self.compute_dtype = compute_dtype
if embeddings_initializer:
self.embed_init = embeddings_initializer
else:
self.embed_init = tf.keras.initializers.TruncatedNormal(stddev=1.0)
with self.name_scope:
self.embeddings = self.create_variable(
"embedding", [self.vocab_size, self.features],
self.embed_init,
dtype=self.dtype)
@tf.Module.with_name_scope
def __call__(self, inputs: tf.Tensor, one_hot: bool = True):
"""Embeds the inputs along the last dimension.
Args:
inputs: input data, the last dimension is to embed.
one_hot: whether to use one-hot matmul to gather embeddings.
Returns:
The output shape follows the input, with an additional `features`
dimension appended.
"""
if one_hot:
flat_inputs = tf.reshape(inputs, [-1])
one_hot_data = tf.one_hot(
flat_inputs, depth=self.vocab_size, dtype=self.compute_dtype)
embeddings = tf.matmul(
one_hot_data,
self.read_variable(self.embeddings, as_dtype=self.compute_dtype))
input_shape = tf_utils.get_shape_list(inputs)
embeddings = tf.reshape(embeddings, input_shape + [self.features])
return embeddings
else:
return tf.nn.embedding_lookup(
dense_gradient(
self.read_variable(self.embeddings, as_dtype=self.compute_dtype)),
inputs)
def attend(self, query):
"""Attends over the embedding using a query tensor.
Args:
query: array with last dimension equal the feature depth `features` of the
embedding.
Returns:
An tensor with final dim `num_embeddings` corresponding to the batched
inner-product of the array of query vectors against each embedding.
Commonly used for weight-sharing between embeddings and logit transform
in NLP models.
"""
return tf.matmul(
query,
self.read_variable(self.embeddings, as_dtype=query.dtype),
transpose_b=True)
class RMSNorm(Module):
"""A layernorm module in the T5 style.
No bias and no subtraction of mean.
"""
def __init__(self, hidden_size: int, epsilon: float = 1e-6, **kwargs):
super().__init__(**kwargs)
self.variance_epsilon = epsilon
with self.name_scope:
self.weight = self.create_variable(
"scale", [hidden_size],
dtype=self.dtype,
initializer=tf.keras.initializers.Ones())
@tf.Module.with_name_scope
def __call__(self, x):
# Keeps the computation inside the layer norm to be float32.
compute_dtype = x.dtype
x = tf.cast(x, dtype=tf.float32)
variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
x = x * tf.math.rsqrt(variance + self.variance_epsilon)
x = tf.cast(x, dtype=compute_dtype)
return self.read_variable(self.weight, as_dtype=compute_dtype) * x
class Linear(Module):
"""Linear module, optionally including bias."""
def __init__(self,
in_features: int,
out_features: int,
use_bias: bool = True,
w_init: Optional[Initializer] = None,
b_init: Optional[Initializer] = None,
**kwargs):
"""Constructs a `Linear` module."""
super().__init__(**kwargs)
self.in_features = in_features
self.out_features = out_features
self.use_bias = use_bias
self.w_init = w_init
if self.use_bias:
self.b_init = b_init if b_init else tf.keras.initializers.Zeros()
elif b_init is not None:
raise ValueError("When not using a bias the b_init must be None.")
with self.name_scope:
if self.w_init is None:
stddev = 1 / math.sqrt(self.in_features)
self.w_init = tf.keras.initializers.HeNormal()
self.w = self.create_variable(
"kernel", [self.in_features, self.out_features],
initializer=self.w_init,
dtype=self.dtype)
if self.use_bias:
self.b = self.create_variable(
"bias", [self.out_features],
initializer=self.b_init,
dtype=self.dtype)
@tf.Module.with_name_scope
def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
outputs = tf.matmul(inputs,
self.read_variable(self.w, as_dtype=inputs.dtype))
if self.use_bias:
outputs = tf.add(outputs,
self.read_variable(self.b, as_dtype=inputs.dtype))
return outputs
class Linear3D(Module):
"""Linear3D module, optionally including bias.
Kernel stored as 2d parameter for compatibility with Adafactor optimizer.
"""
def __init__(self,
in_features: int,
out_features: int,
num_heads: int,
use_bias: bool = True,
to_3d: bool = True,
w_init: Optional[Initializer] = None,
b_init: Optional[Initializer] = None,
**kwargs):
"""Constructs a `Linear3D` module."""
super().__init__(**kwargs)
self.in_features = in_features
self.out_features = out_features
self.num_heads = num_heads
self.use_bias = use_bias
self.to_3d = to_3d
self.w_init = w_init
if self.to_3d:
self.kernel_2d_shape = (self.in_features,
self.num_heads * self.out_features)
self.kernel_3d_shape = (self.in_features, self.num_heads,
self.out_features)
self.bias_shape = (self.num_heads, self.out_features)
bias_rank = 2
else:
self.kernel_2d_shape = (self.in_features * self.num_heads,
self.out_features)
self.kernel_3d_shape = (self.num_heads, self.in_features,
self.out_features)
self.bias_shape = (self.out_features,)
bias_rank = 1
if self.use_bias:
self.b_init = b_init or tf.keras.initializers.Zeros()
elif b_init is not None:
raise ValueError("When not using a bias the b_init must be None.")
with self.name_scope:
if self.w_init is None:
self.w_init = tf.keras.initializers.HeNormal()
self.w = self.create_variable(
"kernel",
self.kernel_2d_shape,
initializer=self.w_init,
dtype=self.dtype)
if self.use_bias:
self.b = self.create_variable(
"bias", self.bias_shape, initializer=self.b_init, dtype=self.dtype)
@tf.Module.with_name_scope
def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
# B: batch size
# S: From Sequence length
# D: dimension
# N: Number of heads
# H: head size
compute_dtype = inputs.dtype
w = self.read_variable(self.w, as_dtype=compute_dtype)
w = tf.reshape(w, self.kernel_3d_shape)
if self.to_3d:
outputs = tf.einsum("BSD,DNH->BSNH", inputs, w)
else:
outputs = tf.einsum("BSNH,NHD->BSD", inputs, w)
if self.use_bias:
outputs = tf.add(outputs,
self.read_variable(self.b, as_dtype=compute_dtype))
return outputs
class Dropout(Module):
"""Randomly drop units in the input at a given rate."""
def __init__(self, rate: float, **kwargs):
"""Constructs a Dropout module.
Args:
rate: Probability that each element of x is discarded. Must be a scalar in
the range `[0, 1)`.
**kwargs: other keyword args.
"""
super().__init__(**kwargs)
self._rate = rate
@tf.Module.with_name_scope
def __call__(self,
x: tf.Tensor,
training: bool,
noise_shape: Optional[ShapeLike] = None) -> tf.Tensor:
"""call method for the Dropout module.
Args:
x: the input tensor.
training: whether it is performing training pass.
noise_shape: (Optional) Shape vector controlling the shape of the random
noise used to apply dropout. If not set this will be the shape of the
input. If set it should be broadcastable to the input shape.
Returns:
A tensor after applying dropout.
"""
if not training:
return x
return tf.nn.dropout(x, rate=self._rate, noise_shape=noise_shape)
class FFN(Module):
"""Feed-forward Network. No layer norm, output dropout, or skip connection."""
activation_map = {
"relu": tf.nn.relu,
"gelu": functools.partial(tf.nn.gelu, approximate=True),
"swish": tf.nn.silu,
"silu": tf.nn.silu,
}
def __init__(self,
d_model: int,
d_ff: int,
activations: Sequence[str],
use_bias: bool = False,
dropout_rate: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-6,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
self.use_bias = use_bias
with self.name_scope:
self.wi = []
self.activations = activations
for idx, act_fn in enumerate(activations):
if (act_fn is not None and act_fn != "linear" and
act_fn not in self.activation_map):
raise ValueError("Invalid activation function string is passed: %s" %
act_fn)
dense_name = "wi" if len(activations) == 1 else f"wi_{idx}"
self.wi.append(
Linear(
d_model,
d_ff,
use_bias=self.use_bias,
w_init=weight_initializer,
b_init=bias_initializer,
dtype=self.dtype,
name=dense_name))
self.wo = Linear(
d_ff,
d_model,
use_bias=self.use_bias,
w_init=weight_initializer,
b_init=bias_initializer,
dtype=self.dtype,
name="wo")
self.dropout = Dropout(rate=dropout_rate)
@tf.Module.with_name_scope
def __call__(self,
hidden_states: tf.Tensor,
training: bool = False) -> tf.Tensor:
h = hidden_states
factors = []
for wi, act_fn in zip(self.wi, self.activations):
if act_fn is None or act_fn == "linear":
factors.append(wi(h))
else:
factors.append(self.activation_map[act_fn](wi(h)))
h = functools.reduce(tf.math.multiply, factors)
h_shape = tf_utils.get_shape_list(h)
h_shape[-2] = 1
h = self.dropout(h, noise_shape=h_shape, training=training)
h = self.wo(h)
return h
class RelativePositionEmbedding(Module):
"""Relative position embeddings of T5 style."""
def __init__(self,
num_heads: int,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
bidirectional: bool = True,
embeddings_initializer: Optional[Initializer] = None,
compute_dtype: tf.DType = tf.float32,
**kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
self.bidirectional = bidirectional
self.relative_attention_max_distance = relative_attention_max_distance
with self.name_scope:
self.relative_attention_bias = Embed(
vocab_size=self.relative_attention_num_buckets,
features=self.num_heads,
embeddings_initializer=embeddings_initializer,
dtype=self.dtype,
compute_dtype=compute_dtype,
name="rel_embedding")
@staticmethod
def _relative_position_bucket(relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128):
"""Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e.
the distance in tokens from the attending position to the attended-to
position.
If bidirectional=False, then positive relative positions are invalid.
We use smaller buckets for small absolute relative_position and larger
buckets for larger absolute relative_positions.
All relative positions >=max_distance map to the same bucket.
All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences
than the model has been trained on.
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
if bidirectional:
num_buckets //= 2
ret += tf.cast(tf.math.less(n, 0), tf.int32) * num_buckets
n = tf.math.abs(n)
else:
n = tf.math.maximum(n, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = tf.math.less(n, max_exact)
val_if_large = max_exact + tf.dtypes.cast(
tf.math.log(
tf.cast(n, tf.float32) / max_exact + np.finfo(np.float32).eps) /
math.log(max_distance / max_exact) * (num_buckets - max_exact),
tf.int32,
)
val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
ret += tf.where(is_small, n, val_if_large)
return ret
@tf.Module.with_name_scope
def __call__(self, qlen, klen):
context_position = tf.range(qlen)[:, None]
memory_position = tf.range(klen)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance)
values = self.relative_attention_bias(rp_bucket)
values = tf.expand_dims(
tf.transpose(values, [2, 0, 1]),
axis=0) # shape (1, num_heads, qlen, klen)
return values
class MultiHeadAttention(Module):
"""T5 Attention from Mesh TensorFlow."""
def __init__(self,
d_model: int,
d_kv: int,
num_heads: int,
use_bias: bool = False,
dropout_rate: Optional[float] = 0.0,
rescale_query: bool = False,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
with self.name_scope:
self.d_model = d_model
self.d_kv = d_kv
self.num_heads = num_heads
self.rescale_query = rescale_query
self.use_bias = use_bias
if rescale_query or weight_initializer is None:
query_w_init = weight_initializer
else:
init_std_rescaling = tf.math.sqrt(tf.cast(self.d_kv, dtype=self.dtype))
query_w_init = (
lambda *args, **kwargs: ( # pylint: disable=g-long-lambda
weight_initializer(*args, **kwargs) / init_std_rescaling))
self.q = Linear3D(
self.d_model,
self.d_kv,
num_heads=self.num_heads,
use_bias=self.use_bias,
w_init=query_w_init,
b_init=bias_initializer,
dtype=self.dtype,
name="q")
self.k = Linear3D(
self.d_model,
self.d_kv,
num_heads=self.num_heads,
use_bias=self.use_bias,
w_init=weight_initializer,
b_init=bias_initializer,
dtype=self.dtype,
name="k")
self.v = Linear3D(
self.d_model,
self.d_kv,
num_heads=self.num_heads,
use_bias=self.use_bias,
w_init=weight_initializer,
b_init=bias_initializer,
dtype=self.dtype,
name="v")
self.o = Linear3D(
self.d_kv,
self.d_model,
num_heads=self.num_heads,
use_bias=self.use_bias,
to_3d=False,
w_init=weight_initializer,
b_init=bias_initializer,
dtype=self.dtype,
name="o")
self.dropout = Dropout(dropout_rate)
def _update_cache(self, key, value, cache, decode_position):
"""Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values.
# TPU one-hot handling.
key_seq_dim = cache["key"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_position, key_seq_dim, dtype=key.dtype),
[1, key_seq_dim, 1, 1])
key = cache["key"] + key * indices
value_seq_dim = cache["value"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_position, value_seq_dim, dtype=value.dtype),
[1, value_seq_dim, 1, 1])
value = cache["value"] + value * indices
# Update cache
cache["key"] = key
cache["value"] = value
return key, value
@tf.Module.with_name_scope
def __call__(self,
query,
mask=None,
kv=None,
position_bias=None,
cache: Optional[Dict[str, tf.Tensor]] = None,
decode_position=None,
training=False):
"""MultiHeadAttention at work.
Args:
query: Tensor of shape (bs, qlen, d_model).
mask: None or Tensor of shape (bs, n_heads, qlen, klen).
kv: None or Tensor of shape (bs, klen, d_model).
position_bias: None or Tensor of shape (bs, n_heads, qlen, klen).
cache: If not None, cache["key"] and cache["value"] are Tensors of shape
(bs, klen, n_heads, d_kv).
decode_position: If not None, which position of the sequence we are
decoding for. Ranges from 0 to klen - 1.
training: Effects the behavior of dropout.
Returns:
A dictionary, output["context"] is the output after attention,
output["cache"] contains updated cache for the next round of
autoregressive decoding.
"""
# Input is (bs, qlen, d_model)
use_cache = cache is not None
if kv is None:
kv = query
q = self.q(query)
if self.rescale_query:
q /= tf.math.sqrt(tf.cast(self.d_kv, dtype=q.dtype))
k = self.k(kv)
v = self.v(kv)
if use_cache:
k, v = self._update_cache(k, v, cache, decode_position)
# NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(q_dim)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor.
scores = tf.einsum("bqnd,bknd->bnqk", q, k) # (bs, n_heads, qlen, klen)
if position_bias is not None:
# If position_bias is None, the input embedings should already include
# position embeddings.
if use_cache:
bias_shape = position_bias.shape.as_list()
position_bias = tf.slice(
position_bias, [0, 0, decode_position, 0],
[bias_shape[0], bias_shape[1], 1, bias_shape[3]])
scores += position_bias
if mask is not None:
scores += mask # (bs, n_heads, qlen, klen)
weights = tf.nn.softmax(tf.cast(scores, tf.float32), axis=-1)
# weights shape = (bs, n_heads, qlen, klen)
weights = tf.cast(weights, scores.dtype)
weight_shape = tf_utils.get_shape_list(weights)
# NOTE: T5 broadcasts along the "length" dim, but unclear which one that
# corresponds to. We assume it is the query dimension.
# (bs, n_heads, qlen, klen)
weight_shape[-2] = 1
weights = self.dropout(weights, training=training, noise_shape=weight_shape)
c = tf.einsum("bnqk,bknd->bqnd", weights, v)
c = self.o(c)
outputs = dict(context=c)
if cache:
outputs["cache"] = cache
return outputs
class SelfAttention(Module):
"""Self attention block including residual connection."""
def __init__(self,
d_model: int,
d_kv: int,
num_heads: int,
dropout_rate: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-6,
rescale_query: bool = False,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
with self.name_scope:
self.self_attention = MultiHeadAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
rescale_query=rescale_query,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="attention")
self.layer_norm = RMSNorm(
hidden_size=d_model,
epsilon=layer_norm_epsilon,
dtype=self.dtype,
name="layer_norm")
self.dropout = Dropout(dropout_rate)
@tf.Module.with_name_scope
def __call__(self,
hidden_states,
attention_mask=None,
position_bias=None,
cache=None,
decode_position=None,
training=False):
norm_x = self.layer_norm(hidden_states)
attention_outputs = self.self_attention(
query=norm_x,
mask=attention_mask,
position_bias=position_bias,
cache=cache,
decode_position=decode_position,
training=training)
y = attention_outputs.pop("context")
tensor_shape = tf_utils.get_shape_list(y)
tensor_shape[-2] = 1
y = self.dropout(y, noise_shape=tensor_shape, training=training)
layer_output = hidden_states + y
attention_outputs["layer_output"] = layer_output
return attention_outputs
class CrossAttention(Module):
"""Cross attention block including residual connection."""
def __init__(self,
d_model: int,
d_kv: int,
num_heads: int,
dropout_rate: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-6,
rescale_query: bool = False,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
with self.name_scope:
self.cross_attention = MultiHeadAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
rescale_query=rescale_query,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="attention")
self.layer_norm = RMSNorm(
hidden_size=d_model,
epsilon=layer_norm_epsilon,
dtype=self.dtype,
name="layer_norm")
self.dropout = Dropout(dropout_rate)
@tf.Module.with_name_scope
def __call__(self,
hidden_states,
kv,
attention_mask=None,
position_bias=None,
cache=None,
training=False):
norm_x = self.layer_norm(hidden_states)
attention_outputs = self.cross_attention(
query=norm_x,
kv=kv,
mask=attention_mask,
position_bias=position_bias,
cache=cache,
training=training)
y = attention_outputs.pop("context")
tensor_shape = tf_utils.get_shape_list(y)
tensor_shape[-2] = 1
y = self.dropout(y, noise_shape=tensor_shape, training=training)
layer_output = hidden_states + y
attention_outputs["layer_output"] = layer_output
return attention_outputs
class EncoderBlock(Module):
"""Transformer Encoder Block with only self attention."""
def __init__(self,
d_model: int,
d_kv: int,
num_heads: int,
d_ff: int,
ffn_activations: Sequence[str] = ("relu",),
dropout_rate: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-6,
rescale_query: bool = False,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
with self.name_scope:
self.self_attention = SelfAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
rescale_query=rescale_query,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="self_attention")
self.ffn_layer_norm = RMSNorm(
hidden_size=d_model,
epsilon=layer_norm_epsilon,
dtype=self.dtype,
name="ffn_layer_norm")
self.ffn = FFN(
d_model=d_model,
d_ff=d_ff,
dropout_rate=dropout_rate,
activations=ffn_activations,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="ffn")
self.ffn_output_dropout = Dropout(dropout_rate)
@tf.Module.with_name_scope
def __call__(self,
hidden_states,
attention_mask=None,
position_bias=None,
training=False):
attention_outputs = self.self_attention(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
training=training)
attn_output = attention_outputs["layer_output"]
ffn_output = self.ffn_layer_norm(attn_output)
ffn_output = self.ffn(ffn_output, training=training)
tensor_shape = tf_utils.get_shape_list(ffn_output)
tensor_shape[-2] = 1
ffn_output = self.ffn_output_dropout(
ffn_output, noise_shape=tensor_shape, training=training)
ffn_output = attn_output + ffn_output
return ffn_output
class EncDecoderBlock(Module):
"""Transformer Decoder Block with enc-decoder cross attention."""
def __init__(self,
d_model: int,
d_kv: int,
num_heads: int,
d_ff: int,
ffn_activations: Sequence[str] = ("relu",),
dropout_rate: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-6,
rescale_query: bool = False,
weight_initializer: Optional[Initializer] = None,
bias_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
with self.name_scope:
self.self_attention = SelfAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
rescale_query=rescale_query,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="self_attention")
self.cross_attention = CrossAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
rescale_query=rescale_query,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="cross_attention")
self.ffn_layer_norm = RMSNorm(
hidden_size=d_model,
epsilon=layer_norm_epsilon,
dtype=self.dtype,
name="ffn_layer_norm")
self.ffn = FFN(
d_model=d_model,
d_ff=d_ff,
dropout_rate=dropout_rate,
activations=ffn_activations,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=self.dtype,
name="ffn")
self.ffn_output_dropout = Dropout(dropout_rate,)
@tf.Module.with_name_scope
def __call__(self,
hidden_states,
encoder_hidden_states,
attention_mask=None,
encoder_decoder_mask=None,
position_bias=None,
cache=None,
decode_position=None,
training=False):
self_attention_outputs = self.self_attention(
hidden_states,
attention_mask=attention_mask,
decode_position=decode_position,
position_bias=position_bias,
cache=cache,
training=training)
if "cache" in self_attention_outputs:
cache = self_attention_outputs["cache"]
# No relative position bias is used for encoder-decoder cross attention.
cross_attention_outputs = self.cross_attention(
self_attention_outputs["layer_output"],
kv=encoder_hidden_states,
attention_mask=encoder_decoder_mask,
training=training)
attn_output = cross_attention_outputs["layer_output"]
ffn_output = self.ffn_layer_norm(attn_output)
ffn_output = self.ffn(ffn_output, training=training)
tensor_shape = tf_utils.get_shape_list(ffn_output)
tensor_shape[-2] = 1
ffn_output = self.ffn_output_dropout(
ffn_output, noise_shape=tensor_shape, training=training)
ffn_output = attn_output + ffn_output
return ffn_output, cache
@dataclasses.dataclass
class T5TransformerParams:
"""Transformer parameters."""
num_layers: int
d_model: int
d_kv: int
num_heads: int
d_ff: int
vocab_size: int
dropout_rate: float = 0.0
layer_norm_epsilon: float = 1e-6
shared_embedding: bool = False
vocab_embeddings_initializer: Optional[Initializer] = None
relative_attention_num_buckets: int = 32
relative_attention_max_distance: int = 128
relative_embeddings_initializer: Optional[Initializer] = None
weight_initializer: Optional[Initializer] = (tf.keras.initializers.HeNormal())
bias_initializer: Optional[Initializer] = None
rescale_query: bool = False
bidirectional: bool = True
ffn_activations: Sequence[str] = ("relu",)
logits_via_embedding: bool = True
num_decoder_layers: Optional[int] = None
one_hot_embedding: bool = True
layer_sharing: bool = False
class Encoder(Module):
"""Transformer Model Encoder for sequence to sequence."""
def __init__(self,
config: T5TransformerParams,
shared_embedding: Optional[tf.Variable] = None,
compute_dtype: tf.DType = tf.float32,
**kwargs):
super().__init__(**kwargs)
self.config = config
self.compute_dtype = compute_dtype
self.embed_dim = config.d_model
with self.name_scope:
# Input Embedding.
if shared_embedding is None:
self.input_embed = Embed(
vocab_size=self.config.vocab_size,
features=self.config.d_model,
embeddings_initializer=self.config.vocab_embeddings_initializer,
dtype=self.dtype,
compute_dtype=self.compute_dtype,
name="input_embedding")
else:
self.input_embed = shared_embedding
# Creates an alias to the input embed for encoder-only models.
self.word_embed = self.input_embed
self.relative_embedding = RelativePositionEmbedding(
num_heads=self.config.num_heads,
relative_attention_num_buckets=self.config
.relative_attention_num_buckets,
relative_attention_max_distance=self.config
.relative_attention_max_distance,
bidirectional=self.config.bidirectional,
embeddings_initializer=self.config.relative_embeddings_initializer,
dtype=self.dtype,
compute_dtype=self.compute_dtype,
name="relative_posemb")
self.input_dropout = Dropout(self.config.dropout_rate,)
self.encoder_layers = []
for layer_idx in range(self.config.num_layers):
if self.config.layer_sharing and layer_idx > 0:
self.encoder_layers.append(self.encoder_layers[0])
else:
self.encoder_layers.append(
EncoderBlock(
d_model=self.config.d_model,
d_kv=self.config.d_kv,
num_heads=self.config.num_heads,
d_ff=self.config.d_ff,
dropout_rate=self.config.dropout_rate,
ffn_activations=self.config.ffn_activations,
rescale_query=self.config.rescale_query,
weight_initializer=self.config.weight_initializer,
bias_initializer=self.config.bias_initializer,
dtype=self.dtype,
name="encoder_block_%d" % layer_idx))
self.output_norm = RMSNorm(
hidden_size=self.config.d_model,
epsilon=self.config.layer_norm_epsilon,
dtype=self.dtype,
name="final_layer_norm")
self.output_dropout = Dropout(self.config.dropout_rate,)
@tf.Module.with_name_scope
def __call__(self, inputs, encoder_mask=None, training=False):
"""Applies Transformer model on the inputs.
Args:
inputs: input data
encoder_mask: the encoder self-attention mask.
training: whether it is training pass, affecting dropouts.
Returns:
output of a transformer encoder.
"""
# Casts inputs to the dtype.
if encoder_mask is not None:
encoder_mask = tf.cast(encoder_mask, self.compute_dtype)
cfg = self.config
x = self.input_embed(inputs, one_hot=cfg.one_hot_embedding)
tensor_shape = tf_utils.get_shape_list(x)
tensor_shape[-2] = 1
x = self.input_dropout(x, noise_shape=tensor_shape, training=training)
input_length = tf_utils.get_shape_list(inputs)[1]
position_bias = self.relative_embedding(input_length, input_length)
for i in range(cfg.num_layers):
x = self.encoder_layers[i](
x,
attention_mask=encoder_mask,
position_bias=position_bias,
training=training)
encoded = self.output_norm(x)
encoded = self.output_dropout(encoded, training=training)
return encoded
class Decoder(Module):
"""Transformer Model Decoder for sequence to sequence."""
def __init__(self,
config: T5TransformerParams,
shared_embedding: Optional[tf.Variable] = None,
compute_dtype: tf.DType = tf.float32,
**kwargs):
super().__init__(**kwargs)
self.config = config
self.compute_dtype = compute_dtype
if self.config.num_decoder_layers is None:
self.config.num_decoder_layers = self.config.num_layers
with self.name_scope:
# Target Embedding.
if shared_embedding is None:
self.target_embed = Embed(
vocab_size=self.config.vocab_size,
features=self.config.d_model,
embeddings_initializer=self.config.vocab_embeddings_initializer,
dtype=self.dtype,
compute_dtype=self.compute_dtype,
name="target_embedding")
else:
self.target_embed = shared_embedding
self.target_dropout = Dropout(self.config.dropout_rate,)
# Position bias for the target self attention.
self.relative_embedding = RelativePositionEmbedding(
num_heads=self.config.num_heads,
relative_attention_num_buckets=self.config
.relative_attention_num_buckets,
relative_attention_max_distance=self.config
.relative_attention_max_distance,
bidirectional=self.config.bidirectional,
embeddings_initializer=self.config.relative_embeddings_initializer,
dtype=self.dtype,
compute_dtype=self.compute_dtype,
name="relative_posemb")
self.decoder_layers = []
for layer_idx in range(self.config.num_decoder_layers):
if self.config.layer_sharing and layer_idx > 0:
self.decoder_layers.append(self.decoder_layers[0])
else:
self.decoder_layers.append(
EncDecoderBlock(
d_model=self.config.d_model,
d_kv=self.config.d_kv,
num_heads=self.config.num_heads,
d_ff=self.config.d_ff,
dropout_rate=self.config.dropout_rate,
ffn_activations=self.config.ffn_activations,
rescale_query=self.config.rescale_query,
weight_initializer=self.config.weight_initializer,
bias_initializer=self.config.bias_initializer,
dtype=self.dtype,
name="decoder_block_%d" % layer_idx))
self.output_norm = RMSNorm(
hidden_size=self.config.d_model,
epsilon=self.config.layer_norm_epsilon,
dtype=self.dtype,
name="final_layer_norm")
self.output_dropout = Dropout(self.config.dropout_rate,)
if not self.config.logits_via_embedding:
self.logits_dense = Linear(
in_features=self.config.d_model,
out_features=self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
name="logits")
@tf.Module.with_name_scope
def __call__(self,
decoder_input_tokens,
encoded,
decoder_mask=None,
encoder_decoder_mask=None,
decode=False,
decode_position=None,
cache=None,
max_decode_len=None,
training=False):
"""Applies Transformer model on the inputs.
Args:
decoder_input_tokens: the decoder input tokens.
encoded: the encoder outputs.
decoder_mask: the decoder self-attention mask.
encoder_decoder_mask: the cross-attention mask.
decode: Whether to perform autoaggressive decoding.
decode_position: integer, the position to decode.
cache: The cache dictionary of key, value tensors.
max_decode_len: An optional integer specifying the maximum decoding
length. Note that this is only used for defining the relative position
embedding parameters.
training: Whether it is training pass, affecting dropouts.
Returns:
output of a transformer encoder.
"""
cfg = self.config
# Casts inputs to the dtype.
encoded = tf.cast(encoded, self.compute_dtype)
if decoder_mask is not None:
decoder_mask = tf.cast(decoder_mask, self.compute_dtype)
if encoder_decoder_mask is not None:
encoder_decoder_mask = tf.cast(encoder_decoder_mask, self.compute_dtype)
x = self.target_embed(decoder_input_tokens, one_hot=cfg.one_hot_embedding)
tensor_shape = tf_utils.get_shape_list(x)
tensor_shape[-2] = 1
x = self.target_dropout(x, noise_shape=tensor_shape, training=training)
if cache is not None:
position_bias = self.relative_embedding(max_decode_len, max_decode_len)
else:
input_length = tf_utils.get_shape_list(decoder_input_tokens)[1]
position_bias = self.relative_embedding(input_length, input_length)
for i in range(cfg.num_decoder_layers):
if cache is None:
x, _ = self.decoder_layers[i](
x,
encoder_hidden_states=encoded,
attention_mask=decoder_mask,
encoder_decoder_mask=encoder_decoder_mask,
position_bias=position_bias,
training=training)
else:
x, cache[i] = self.decoder_layers[i](
x,
encoder_hidden_states=encoded,
attention_mask=decoder_mask,
encoder_decoder_mask=encoder_decoder_mask,
position_bias=position_bias,
decode_position=decode_position,
cache=cache[i],
training=training)
output = self.output_norm(x)
tensor_shape = tf_utils.get_shape_list(output)
tensor_shape[-2] = 1
output = self.target_dropout(
output, noise_shape=tensor_shape, training=training)
if self.config.logits_via_embedding:
logits = self.target_embed.attend(output)
logits = logits / math.sqrt(cfg.d_model)
else:
logits = self.logits_dense(output)
return logits, cache
class T5Transformer(Module):
"""Transformer Encoder+Decoder for sequence to sequence."""
def __init__(self,
config: T5TransformerParams,
compute_dtype: tf.DType = tf.float32,
**kwargs):
super().__init__(**kwargs)
# Builds the model components.
shared_embedding = config.shared_embedding
self.compute_dtype = compute_dtype
self.decoder_cfg = dataclasses.replace(config, bidirectional=False)
if self.decoder_cfg.num_decoder_layers is None:
self.decoder_cfg.num_decoder_layers = self.decoder_cfg.num_layers
self.encoder_cfg = dataclasses.replace(config, bidirectional=True)
with self.name_scope:
if shared_embedding:
self.shared_embedding = Embed(
vocab_size=config.vocab_size,
features=config.d_model,
embeddings_initializer=config.vocab_embeddings_initializer,
dtype=self.dtype,
compute_dtype=self.compute_dtype,
name="shared")
else:
self.shared_embedding = None
self.encoder = Encoder(
self.encoder_cfg,
self.shared_embedding,
dtype=self.dtype,
compute_dtype=self.compute_dtype)
self.decoder = Decoder(
self.decoder_cfg,
self.shared_embedding,
dtype=self.dtype,
compute_dtype=self.compute_dtype)
def encode(self,
encoder_input_tokens,
encoder_segment_ids=None,
training=False):
eligible_positions = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
encoder_mask = make_attention_mask(
eligible_positions, eligible_positions, dtype=tf.bool)
if encoder_segment_ids is not None:
segment_mask = make_attention_mask(
encoder_segment_ids, encoder_segment_ids, tf.equal, dtype=tf.bool)
encoder_mask = tf.math.logical_and(encoder_mask, segment_mask)
encoder_mask = (1.0 - tf.cast(encoder_mask, self.compute_dtype)) * -1e9
return self.encoder(encoder_input_tokens, encoder_mask, training=training)
def decode(
self,
encoded,
decoder_target_tokens,
encoder_input_tokens, # only used for masks
decoder_input_tokens=None,
encoder_segment_ids=None,
decoder_segment_ids=None,
decode_position=None,
cache=None,
max_decode_len=None,
decode=False,
training=False):
if decode:
# For decoding, the decoder_input_tokens is the decoder_target_tokens.
decoder_input_tokens = decoder_target_tokens
# fast autoregressive decoding uses only a special encoder-decoder mask
decoder_mask = None
encoder_decoder_mask = make_attention_mask(
tf.cast(
tf.not_equal(tf.ones_like(decoder_target_tokens), 0),
self.compute_dtype),
tf.cast(tf.not_equal(encoder_input_tokens, 0), self.compute_dtype),
dtype=tf.bool)
else:
# Note that, masks should be created using decoder_target_tokens.
eligible_targets = tf.cast(
tf.not_equal(decoder_target_tokens, 0), self.compute_dtype)
eligible_inputs = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
decoder_mask = tf.math.logical_and(
make_attention_mask(
eligible_targets, eligible_targets, dtype=tf.bool),
make_causal_mask(decoder_target_tokens, dtype=tf.bool))
encoder_decoder_mask = make_attention_mask(
eligible_targets, eligible_inputs, dtype=tf.bool)
if encoder_segment_ids is not None:
if decoder_mask is not None:
decoder_mask = tf.math.logical_and(
decoder_mask,
make_attention_mask(
decoder_segment_ids,
decoder_segment_ids,
tf.equal,
dtype=tf.bool))
encoder_decoder_mask = tf.math.logical_and(
encoder_decoder_mask,
make_attention_mask(
decoder_segment_ids,
encoder_segment_ids,
tf.equal,
dtype=tf.bool))
if decoder_mask is not None:
decoder_mask = (1.0 - tf.cast(decoder_mask, self.compute_dtype)) * -1e9
encoder_decoder_mask = (
1.0 - tf.cast(encoder_decoder_mask, self.compute_dtype)) * -1e9
logits, cache = self.decoder(
decoder_input_tokens,
encoded,
decode_position=decode_position,
decoder_mask=decoder_mask,
encoder_decoder_mask=encoder_decoder_mask,
cache=cache,
max_decode_len=max_decode_len,
decode=decode,
training=training)
return dict(logits=logits, encoded=encoded, cache=cache)
@tf.Module.with_name_scope
def __call__(self,
encoder_input_tokens,
decoder_target_tokens,
decoder_input_tokens=None,
encoder_segment_ids=None,
decoder_segment_ids=None,
training=False):
"""Applies Transformer model on the inputs.
Args:
encoder_input_tokens: input tokens to the encoder.
decoder_target_tokens: target tokens to the decoder.
decoder_input_tokens: input tokens to the decoder, only required for
training.
encoder_segment_ids: input segmentation info for packed examples.
decoder_segment_ids: target segmentation info for packed examples.
training: whether it is training pass, affecting dropouts.
Returns:
a dictionary of logits/cache.
"""
encoded = self.encode(
encoder_input_tokens,
encoder_segment_ids=encoder_segment_ids,
training=training)
outputs = self.decode(
encoded=encoded,
decoder_target_tokens=decoder_target_tokens,
encoder_input_tokens=encoder_input_tokens, # only used for masks.
decoder_input_tokens=decoder_input_tokens,
encoder_segment_ids=encoder_segment_ids,
decoder_segment_ids=decoder_segment_ids,
training=training)
outputs["encoded"] = encoded
return outputs
@property
def checkpoint_items(self):
return dict(encoder=self.encoder, decoder=self.decoder)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.models import t5
def _create_cache(batch_size,
init_decode_length,
num_heads,
head_size,
dtype=tf.float32):
if num_heads is None:
kv_shape = [batch_size, init_decode_length, head_size]
else:
kv_shape = [batch_size, init_decode_length, num_heads, head_size]
return {
"key": tf.zeros(kv_shape, dtype=dtype),
"value": tf.zeros(kv_shape, dtype=dtype)
}
class ModulesTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_embed(self, dtype):
l = t5.Embed(vocab_size=5, features=4, compute_dtype=dtype, name="foo")
inputs = np.array([[2, 3], [1, 2]], dtype=np.int32)
inputs = tf.convert_to_tensor(inputs)
one_hot_outputs = l(inputs, one_hot=True)
gather_outputs = l(inputs, one_hot=False)
self.assertEqual(one_hot_outputs.shape, (2, 2, 4))
self.assertLen(l.trainable_variables, 1)
self.assertAllClose(one_hot_outputs, gather_outputs)
outputs = l.attend(query=tf.zeros((2, 2, 4), dtype))
self.assertEqual(outputs.shape, (2, 2, 5))
# Test initializers.
l = t5.Embed(
vocab_size=5,
features=4,
compute_dtype=dtype,
name="foo",
embeddings_initializer=tf.keras.initializers.Zeros())
self.assertAllClose(l(inputs), tf.zeros((2, 2, 4), dtype))
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_rms_norm(self, dtype):
l = t5.RMSNorm(hidden_size=4, epsilon=0.0, name="foo")
inputs = tf.ones((2, 4), dtype=dtype)
outputs = l(inputs)
self.assertAllEqual(l(inputs), inputs)
self.assertEqual(outputs.dtype, dtype)
self.assertLen(l.trainable_variables, 1)
self.assertIn("foo/scale", l.trainable_variables[0].name)
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_linear(self, dtype):
l = t5.Linear(
in_features=4,
out_features=4,
w_init=tf.keras.initializers.Ones(),
name="foo")
inputs = tf.ones((2, 4), dtype=dtype)
outputs = l(inputs)
self.assertEqual(outputs.shape, inputs.shape)
self.assertEqual(outputs.dtype, dtype)
self.assertLen(l.trainable_variables, 2)
def test_linear3d(self):
batch_size = 2
l = t5.Linear3D(
in_features=4,
out_features=4,
num_heads=2,
to_3d=True,
w_init=tf.keras.initializers.Ones(),
name="foo")
inputs = np.ones((batch_size, 2, 4), dtype=np.float32)
self.assertEqual(l(inputs).shape, (batch_size, 2, 2, 4))
l = t5.Linear3D(
in_features=2,
out_features=4,
num_heads=2,
to_3d=False,
w_init=tf.keras.initializers.Ones(),
name="foo")
inputs = np.ones((batch_size, 2, 2, 2), dtype=np.float32)
self.assertEqual(l(inputs).shape, (batch_size, 2, 4))
def test_ffn(self):
inputs = np.ones((2, 4), dtype=np.float32)
for activation in ["relu", "linear", "gelu", "swish"]:
l = t5.FFN(
d_model=4,
d_ff=8,
use_bias=True,
dropout_rate=0.1,
activations=[activation],
name="foo")
self.assertEqual(l(inputs).shape, inputs.shape)
self.assertLen(l.trainable_variables, 4)
l = t5.FFN(
d_model=4,
d_ff=8,
dropout_rate=0.1,
activations=["linear", "gelu"],
name="bar")
self.assertLen(l.trainable_variables, 3)
self.assertEqual(l(inputs).shape, inputs.shape)
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_relative_position(self, dtype):
l = t5.RelativePositionEmbedding(
num_heads=4,
bidirectional=False,
embeddings_initializer=tf.keras.initializers.Ones(),
compute_dtype=dtype,
name="foo")
self.assertEqual(l(4, 2).shape, (1, 4, 4, 2))
l = t5.RelativePositionEmbedding(
num_heads=4,
bidirectional=True,
embeddings_initializer=tf.keras.initializers.Ones(),
compute_dtype=dtype,
name="bar")
outputs = l(4, 2)
self.assertEqual(outputs.shape, (1, 4, 4, 2))
self.assertEqual(outputs.dtype, dtype)
def test_masks(self):
causal_mask = t5.make_causal_mask(np.zeros((2, 5)))
self.assertEqual(causal_mask.shape, (2, 1, 5, 5))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
],
mode="eager"))
def test_attention(self, distribution):
num_heads, head_size = 2, 4
from_seq_length, to_seq_length = 4, 6
batch_size = 2
pos_embed = t5.RelativePositionEmbedding(
num_heads=4,
bidirectional=False,
embeddings_initializer=tf.keras.initializers.Ones(),
name="pos_embed")
position_bias = pos_embed(from_seq_length, from_seq_length)
l = t5.MultiHeadAttention(d_model=4, d_kv=2, num_heads=4, dropout_rate=0.1)
query = tf.convert_to_tensor(
np.ones((batch_size, from_seq_length, 4), dtype=np.float32))
self.assertEqual(
l(query, position_bias=position_bias)["context"].shape, query.shape)
kv = tf.convert_to_tensor(
np.ones((batch_size, to_seq_length, 4), dtype=np.float32))
position_bias = pos_embed(from_seq_length, to_seq_length)
outputs = l(query, kv=kv, position_bias=position_bias)
self.assertEqual(outputs["context"].shape, query.shape)
with distribution.scope():
l = t5.MultiHeadAttention(
d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
@tf.function
def step(inputs):
def _step_fn(inputs):
cache = _create_cache(batch_size, from_seq_length, num_heads,
head_size)
mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
return l(
query=inputs,
mask=mask,
cache=cache,
decode_position=decode_position)
outputs = distribution.run(_step_fn, args=(inputs,))
return tf.nest.map_structure(distribution.experimental_local_results,
outputs)
decode_position = 2
query = tf.convert_to_tensor(np.ones((2, 1, 4), dtype=np.float32))
local_outputs = step(query)
self.assertEqual(local_outputs["context"][0].shape, (2, 1, 4))
self.assertNotEqual(
np.sum(local_outputs["cache"]["key"][0][:, decode_position,
...].numpy()), 0.0)
class T5Test(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
],
mode="eager"))
def test_attention_layers(self, distribution):
num_heads, head_size = 2, 2
from_seq_length = 4
# TPU decoding should pre-allocate the entire sequence.
batch_size = 2
with distribution.scope():
pos_embed = t5.RelativePositionEmbedding(
num_heads=head_size,
bidirectional=False,
embeddings_initializer=tf.keras.initializers.Ones(),
name="pos_embed")
l = t5.SelfAttention(
d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
decode_position = 2
@tf.function
def step(inputs):
def _step_fn(inputs):
cache = _create_cache(batch_size, from_seq_length, num_heads,
head_size)
mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
position_bias = pos_embed(from_seq_length, from_seq_length)
return l(
hidden_states=inputs,
cache=cache,
attention_mask=mask,
decode_position=decode_position,
position_bias=position_bias)
outputs = distribution.run(_step_fn, args=(inputs,))
return tf.nest.map_structure(distribution.experimental_local_results,
outputs)
query = tf.convert_to_tensor(np.ones((2, 1, 4), dtype=np.float32))
local_outputs = step(query)
self.assertEqual(local_outputs["layer_output"][0].shape, (2, 1, 4))
self.assertNotEqual(
np.sum(
local_outputs["cache"]["key"][0][:,
decode_position, :, :].numpy()),
0.0)
l = t5.CrossAttention(
d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
to_seq_length = 6
query = tf.convert_to_tensor(
np.ones((2, from_seq_length, 4), dtype=np.float32))
kv = tf.convert_to_tensor(
np.ones((2, to_seq_length, 4), dtype=np.float32))
@tf.function
def step_cross_attn(inputs):
def _step_fn(inputs):
query, kv = inputs
mask = t5.make_attention_mask(
tf.ones((batch_size, from_seq_length)),
tf.ones((batch_size, to_seq_length)))
return l(hidden_states=query, kv=kv, attention_mask=mask)
outputs = distribution.run(_step_fn, args=(inputs,))
return tf.nest.map_structure(distribution.experimental_local_results,
outputs)
local_outputs = step_cross_attn((query, kv))
self.assertEqual(local_outputs["layer_output"][0].shape,
(2, from_seq_length, 4))
def test_encoder_block(self):
batch_size = 2
from_seq_length = 5
d_model = 4
l = t5.EncoderBlock(d_model=4, d_kv=3, num_heads=2, d_ff=8, name="foo")
pos_embed = t5.RelativePositionEmbedding(
num_heads=2,
bidirectional=True,
embeddings_initializer=tf.keras.initializers.Ones(),
name="bar")
attention_mask = t5.make_attention_mask(
tf.ones((batch_size, from_seq_length)),
tf.ones((batch_size, from_seq_length)))
position_bias = pos_embed(from_seq_length, from_seq_length)
inputs = tf.ones((batch_size, from_seq_length, d_model), dtype=tf.float32)
outputs = l(
inputs, attention_mask=attention_mask, position_bias=position_bias)
self.assertEqual(outputs.shape, (batch_size, from_seq_length, d_model))
def test_encdec_block(self):
batch_size = 2
from_seq_length = 5
to_seq_length = 3
d_model = 4
l = t5.EncDecoderBlock(d_model=4, d_kv=3, num_heads=2, d_ff=8, name="foo")
pos_embed = t5.RelativePositionEmbedding(
num_heads=2,
bidirectional=True,
embeddings_initializer=tf.keras.initializers.Ones(),
name="bar")
encoder_decoder_mask = t5.make_attention_mask(
tf.ones((batch_size, from_seq_length)),
tf.ones((batch_size, to_seq_length)))
position_bias = pos_embed(from_seq_length, from_seq_length)
inputs = tf.ones((batch_size, from_seq_length, d_model), dtype=tf.float32)
encoder_hidden_states = tf.ones((batch_size, to_seq_length, d_model),
dtype=tf.float32)
outputs = l(
inputs,
encoder_hidden_states,
encoder_decoder_mask=encoder_decoder_mask,
position_bias=position_bias)
self.assertEqual(outputs[0].shape, (batch_size, from_seq_length, d_model))
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_encoder(self, dtype):
config = t5.T5TransformerParams(
num_layers=2,
d_model=4,
d_kv=3,
num_heads=4,
d_ff=16,
vocab_size=10,
vocab_embeddings_initializer=tf.keras.initializers.Ones(),
relative_embeddings_initializer=tf.keras.initializers.Ones())
encoder = t5.Encoder(config, compute_dtype=dtype)
encoded = encoder(tf.zeros((4, 8), dtype=tf.int32))
self.assertEqual(encoded.shape, (4, 8, config.d_model))
def test_decoder(self):
max_decode_len = 10
config = t5.T5TransformerParams(
num_layers=2,
d_model=4,
d_kv=3,
num_heads=4,
d_ff=16,
vocab_size=10,
vocab_embeddings_initializer=tf.keras.initializers.Ones(),
relative_embeddings_initializer=tf.keras.initializers.Ones())
decoder = t5.Decoder(config)
batch_size = 4
targets = tf.zeros((4, 8), dtype=tf.int32)
encoded = tf.zeros((4, 8, config.d_model), dtype=tf.float32)
logits, cache = decoder(targets, encoded)
self.assertEqual(logits.shape, (4, 8, config.vocab_size))
cache = {}
cache[0] = _create_cache(batch_size, max_decode_len, config.num_heads,
config.d_kv)
cache[1] = _create_cache(batch_size, max_decode_len, config.num_heads,
config.d_kv)
targets = tf.zeros((4, 1), dtype=tf.int32)
logits, cache = decoder(
targets,
encoded,
decode_position=2,
cache=cache,
decode=True,
max_decode_len=max_decode_len)
self.assertEqual(logits.shape, (batch_size, 1, config.vocab_size))
for entry in cache.values():
for tensor in entry.values():
self.assertNotAllEqual(tensor.numpy()[:, 2, :, :], 0.0)
@parameterized.named_parameters(
("t5_10", ("relu",), True, 26, False, tf.float32),
("t5_11", ("gelu", "linear"), False, 29, False, tf.float32),
("t5_10_bfloat16", ("relu",), True, 26, False, tf.bfloat16),
("t5_11_bfloat16", ("gelu", "linear"), False, 29, False, tf.bfloat16),
("t5_10_layer_sharing", ("relu",), True, 26, True, tf.float32),
("t5_11_layer_sharing", ("gelu", "linear"), False, 29, True, tf.float32),
("t5_10_bfloat16_layer_sharing", ("relu",), True, 26, True, tf.bfloat16),
("t5_11_bfloat16_layer_sharing",
("gelu", "linear"), False, 29, True, tf.bfloat16))
def test_transformer(self, ffn_activations, logits_via_embedding,
expect_num_variables, layer_sharing, dtype):
max_decode_len = 10
config = t5.T5TransformerParams(
num_layers=1,
d_model=8,
d_kv=4,
num_heads=4,
d_ff=32,
vocab_size=10,
shared_embedding=True,
layer_sharing=layer_sharing,
ffn_activations=ffn_activations,
logits_via_embedding=logits_via_embedding)
transformer = t5.T5Transformer(config, compute_dtype=dtype)
self.assertLen(transformer.trainable_variables, expect_num_variables)
inputs = tf.convert_to_tensor(
np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
segments = tf.convert_to_tensor(
np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
outputs = transformer(
encoder_input_tokens=inputs,
decoder_input_tokens=inputs,
decoder_target_tokens=inputs,
encoder_segment_ids=segments,
decoder_segment_ids=segments)
cache = {}
batch_size = 2
cache[0] = _create_cache(
batch_size, max_decode_len, config.num_heads, config.d_kv, dtype=dtype)
outputs = transformer.decode(
encoder_input_tokens=inputs,
encoded=outputs["encoded"],
decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
decode_position=1,
decode=True,
max_decode_len=max_decode_len,
cache=cache)
self.assertEqual(outputs["logits"].shape,
(batch_size, 1, config.vocab_size))
for v in transformer.trainable_variables:
print(v.name, v.shape)
self.assertEqual(v.dtype, tf.float32)
@parameterized.named_parameters(
("t5_10", ("relu",), True, 39, tf.float32, 2),
("t5_10_bfloat16", ("relu",), True, 39, tf.bfloat16, 2))
def test_transformer_different_num_decoder_layers(self, ffn_activations,
logits_via_embedding,
expect_num_variables, dtype,
num_decoder_layers):
max_decode_len = 10
config = t5.T5TransformerParams(
num_decoder_layers=num_decoder_layers,
num_layers=1,
d_model=8,
d_kv=4,
num_heads=4,
d_ff=32,
vocab_size=10,
shared_embedding=True,
ffn_activations=ffn_activations,
logits_via_embedding=logits_via_embedding)
transformer = t5.T5Transformer(config, compute_dtype=dtype)
self.assertLen(transformer.trainable_variables, expect_num_variables)
inputs = tf.convert_to_tensor(
np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
segments = tf.convert_to_tensor(
np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
outputs = transformer(
encoder_input_tokens=inputs,
decoder_input_tokens=inputs,
decoder_target_tokens=inputs,
encoder_segment_ids=segments,
decoder_segment_ids=segments)
cache = {}
batch_size = 2
for i in range(num_decoder_layers):
cache[i] = _create_cache(
batch_size,
max_decode_len,
config.num_heads,
config.d_kv,
dtype=dtype)
outputs = transformer.decode(
encoder_input_tokens=inputs,
encoded=outputs["encoded"],
decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
decode_position=1,
decode=True,
max_decode_len=max_decode_len,
cache=cache)
self.assertEqual(outputs["logits"].shape,
(batch_size, 1, config.vocab_size))
for v in transformer.trainable_variables:
print(v.name, v.shape)
self.assertEqual(v.dtype, tf.float32)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment