"vscode:/vscode.git/clone" did not exist on "dde4b02c18d7695bb23ad0eef3d14e006a52c0a1"
Commit ad0c466e authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 374640999
parent 7d5b6be3
@@ -140,6 +140,28 @@ class BigBirdEncoderConfig(hyperparams.Config):
  use_gradient_checkpointing: bool = False


@dataclasses.dataclass
class KernelEncoderConfig(hyperparams.Config):
  """Linear encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  feature_transform: str = "exp"
  num_random_features: int = 256
  redraw: bool = False
  is_short_seq: bool = False
  begin_kernel: int = 0


@dataclasses.dataclass
class XLNetEncoderConfig(hyperparams.Config):
  """XLNet encoder configuration."""
@@ -172,6 +194,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
  albert: AlbertEncoderConfig = AlbertEncoderConfig()
  bert: BertEncoderConfig = BertEncoderConfig()
  bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
  kernel: KernelEncoderConfig = KernelEncoderConfig()
  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
  xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
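With `kernel` registered on the `OneOfConfig` above, the kernel encoder can be selected by name like any other encoder. A minimal sketch, assuming the standard model-garden `hyperparams` API; the override values are illustrative, not part of this change:

    from official.nlp.configs import encoders

    # Select the kernel encoder and override a couple of its fields.
    config = encoders.EncoderConfig(
        type="kernel",
        kernel=encoders.KernelEncoderConfig(
            num_layers=6,
            num_random_features=128))

    encoder_cfg = config.get()  # Resolves to the active KernelEncoderConfig.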
@@ -317,6 +340,51 @@ def build_encoder(config: EncoderConfig,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "kernel":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    attention_cfg = dict(
        num_heads=encoder_cfg.num_attention_heads,
        key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        feature_transform=encoder_cfg.feature_transform,
        num_random_features=encoder_cfg.num_random_features,
        redraw=encoder_cfg.redraw,
        is_short_seq=encoder_cfg.is_short_seq,
        begin_kernel=encoder_cfg.begin_kernel,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        attention_cls=layers.KernelAttention,
        attention_cfg=attention_cfg)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.TransformerScaffold,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        mask_cls=layers.KernelMask,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "xlnet":
    return networks.XLNetBase(
        vocab_size=encoder_cfg.vocab_size,
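Given such a config, `build_encoder` now returns an `EncoderScaffold` whose transformer layers run `KernelAttention` inside `TransformerScaffold`, with `KernelMask` as the mask class. A hedged usage sketch; the `[word_ids, mask, type_ids]` input ordering follows the usual model-garden encoder convention, and the shapes are illustrative:

    import tensorflow as tf
    from official.nlp.configs import encoders

    # Build a kernel encoder with default hyperparameters.
    encoder = encoders.build_encoder(encoders.EncoderConfig(type="kernel"))

    batch_size, seq_length = 2, 128
    word_ids = tf.zeros((batch_size, seq_length), dtype=tf.int32)
    mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
    type_ids = tf.zeros((batch_size, seq_length), dtype=tf.int32)

    # dict_outputs=True above, so the encoder returns a dict of tensors,
    # e.g. a [batch, seq, hidden] "sequence_output".
    outputs = encoder([word_ids, mask, type_ids])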
@@ -25,6 +25,7 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.gaussian_process import RandomFeatureGaussianProcess
from official.nlp.modeling.layers.kernel_attention import KernelAttention
from official.nlp.modeling.layers.kernel_attention import KernelMask
from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.mat_mul_with_margin import MatMulWithMargin
@@ -21,6 +21,24 @@ import tensorflow as tf
_NUMERIC_STABLER = 1e-6


class KernelMask(tf.keras.layers.Layer):
  """Creates kernel attention mask.

  inputs: 2D or 3D `Tensor` of shape [batch_size, from_seq_length, ...].
  mask: a `Tensor` of shape [batch_size, from_seq_length] indicating which
    positions in the input should not be attended to.

  Returns:
    float `Tensor` of shape [batch_size, from_seq_length] that KernelAttention
    takes as mask.
  """

  def call(self, inputs, mask):
    # Cast the padding mask to the compute dtype of the inputs.
    mask = tf.cast(mask, inputs.dtype)
    return mask
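Unlike the usual self-attention mask, which expands the padding mask into a `[batch, from_seq, to_seq]` matrix, `KernelMask` only casts the per-position mask, since `KernelAttention` consumes a `[batch, seq]` mask directly. A small sketch of what the layer does (values illustrative):

    import tensorflow as tf
    from official.nlp.modeling import layers

    embeddings = tf.random.uniform((2, 4, 8))  # [batch, seq, hidden]
    padding = tf.constant([[1, 1, 1, 0],
                           [1, 1, 0, 0]])      # [batch, seq]
    kernel_mask = layers.KernelMask()(embeddings, mask=padding)
    # kernel_mask has the embeddings' float dtype and shape [2, 4].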
def create_projection_matrix(m, d, seed=None):
  r"""Constructs the matrix of random projections.
@@ -248,7 +266,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
        short or long sequences; usually a short sequence is defined as having
        length L <= 1024.
      attention_mask: a boolean mask of shape `[B, S]` that prevents attending
        to masked positions. Note that the mask is only applied to the keys.
        The user may want to mask the output if the query contains pads.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
@@ -305,8 +323,23 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
           value,
           key=None,
           attention_mask=None,
           training=False):
    """Compute attention with kernel mechanism.

    Args:
      query: Query `Tensor` of shape `[B, T, dim]`.
      value: Value `Tensor` of shape `[B, S, dim]`.
      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
        `value` for both `key` and `value`, which is the most common case.
      attention_mask: a boolean mask of shape `[B, S]` that prevents attending
        to masked positions. Note that the mask is only applied to the keys.
        The user may want to mask the output if the query contains pads.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      Multi-headed outputs of attention computation.
    """
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
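The layer can also be exercised standalone. A hedged sketch using only constructor arguments that appear in this change, on top of the usual `MultiHeadAttention` arguments (values illustrative):

    import tensorflow as tf
    from official.nlp.modeling import layers

    attention = layers.KernelAttention(
        num_heads=4,
        key_dim=16,
        feature_transform="exp",
        num_random_features=64,
        is_short_seq=False)

    query = tf.random.uniform((2, 10, 64))  # [B, T, dim]
    value = tf.random.uniform((2, 12, 64))  # [B, S, dim]
    mask = tf.ones((2, 12))                 # [B, S]; applied to keys only.
    output = attention(query, value, attention_mask=mask)  # [B, T, dim]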