Commit 842eb979 authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 374640999
parent c035325f
...@@ -140,6 +140,28 @@ class BigBirdEncoderConfig(hyperparams.Config):
  use_gradient_checkpointing: bool = False


@dataclasses.dataclass
class KernelEncoderConfig(hyperparams.Config):
  """Linear encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  feature_transform: str = "exp"
  num_random_features: int = 256
  redraw: bool = False
  is_short_seq: bool = False
  begin_kernel: int = 0


@dataclasses.dataclass
class XLNetEncoderConfig(hyperparams.Config):
  """XLNet encoder configuration."""
...@@ -172,6 +194,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
  albert: AlbertEncoderConfig = AlbertEncoderConfig()
  bert: BertEncoderConfig = BertEncoderConfig()
  bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
  kernel: KernelEncoderConfig = KernelEncoderConfig()
  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
  xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
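For reference, a minimal sketch of selecting the new encoder through this one-of config; the snippet is illustrative only (not part of the commit) and the override values are arbitrary:

from official.nlp.configs import encoders

# Illustrative only: choose the kernel encoder and override a couple of fields.
encoder_config = encoders.EncoderConfig(
    type="kernel",
    kernel=encoders.KernelEncoderConfig(
        num_layers=6,
        num_random_features=128))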
...@@ -317,6 +340,51 @@ def build_encoder(config: EncoderConfig,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)
if encoder_type == "kernel":
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate)
attention_cfg = dict(
num_heads=encoder_cfg.num_attention_heads,
key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads),
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
feature_transform=encoder_cfg.feature_transform,
num_random_features=encoder_cfg.num_random_features,
redraw=encoder_cfg.redraw,
is_short_seq=encoder_cfg.is_short_seq,
begin_kernel=encoder_cfg.begin_kernel,
)
hidden_cfg = dict(
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
intermediate_activation=tf_utils.get_activation(
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
attention_cls=layers.KernelAttention,
attention_cfg=attention_cfg)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cls=layers.TransformerScaffold,
hidden_cfg=hidden_cfg,
num_hidden_instances=encoder_cfg.num_layers,
mask_cls=layers.KernelMask,
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=False,
dict_outputs=True,
layer_idx_as_attention_seed=True)
return networks.EncoderScaffold(**kwargs)
if encoder_type == "xlnet": if encoder_type == "xlnet":
return networks.XLNetBase( return networks.XLNetBase(
vocab_size=encoder_cfg.vocab_size, vocab_size=encoder_cfg.vocab_size,
......
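As a rough end-to-end sketch (not from this commit; the config values, batch size, sequence length, and the BERT-style input feature names are assumptions), the new "kernel" branch of build_encoder can be exercised like this:

import tensorflow as tf
from official.nlp.configs import encoders

# Illustrative only: build a small kernel encoder and run a dummy batch.
config = encoders.EncoderConfig(
    type="kernel",
    kernel=encoders.KernelEncoderConfig(
        num_layers=2,
        hidden_size=64,
        num_attention_heads=2,
        intermediate_size=128))
encoder = encoders.build_encoder(config)

batch_size, seq_len = 2, 16
# Standard BERT-style encoder inputs (assumed names): word ids, padding mask,
# and segment/type ids.
outputs = encoder(dict(
    input_word_ids=tf.ones((batch_size, seq_len), tf.int32),
    input_mask=tf.ones((batch_size, seq_len), tf.int32),
    input_type_ids=tf.zeros((batch_size, seq_len), tf.int32)))
# Because dict_outputs=True above, `outputs` is a dict with keys such as
# "sequence_output" ([2, 16, 64]) and "pooled_output" ([2, 64]).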
...@@ -25,6 +25,7 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.gaussian_process import RandomFeatureGaussianProcess
from official.nlp.modeling.layers.kernel_attention import KernelAttention
from official.nlp.modeling.layers.kernel_attention import KernelMask
from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.mat_mul_with_margin import MatMulWithMargin
...
...@@ -21,6 +21,24 @@ import tensorflow as tf
_NUMERIC_STABLER = 1e-6


class KernelMask(tf.keras.layers.Layer):
  """Creates a kernel attention mask.

  inputs: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
  mask: a Tensor of shape [batch_size, from_seq_length] indicating which
    positions of the inputs should not be attended to.

  Returns:
    float Tensor of shape [batch_size, from_seq_length] that KernelAttention
    takes as its mask.
  """

  def call(self, inputs, mask):
    mask = tf.cast(mask, inputs.dtype)
    return mask


def create_projection_matrix(m, d, seed=None):
  r"""Constructs the matrix of random projections.
...@@ -248,7 +266,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
      short or long sequences; usually short sequence is defined as having
      length L <= 1024.
    attention_mask: a boolean mask of shape `[B, S]`, that prevents
      attending to masked positions. Note that the mask is only applied to
      the keys. User may want to mask the output if query contains pads.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (doing nothing).
...@@ -305,8 +323,23 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
           value,
           key=None,
           attention_mask=None,
           training=False):
    """Compute attention with kernel mechanism.

    Args:
      query: Query `Tensor` of shape `[B, T, dim]`.
      value: Value `Tensor` of shape `[B, S, dim]`.
      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
        use `value` for both `key` and `value`, which is the most common case.
      attention_mask: a boolean mask of shape `[B, S]`, that prevents
        attending to masked positions. Note that the mask is only applied to
        the keys. User may want to mask the output if query contains pads.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      Multi-headed outputs of attention computation.
    """
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
...
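To round this out, a hedged sketch of calling KernelAttention directly with the signature documented above; the constructor arguments mirror the attention_cfg built in the encoder config change earlier in this commit, while the head count, dimensions, and tensor shapes are assumptions:

import tensorflow as tf
from official.nlp.modeling import layers

attention = layers.KernelAttention(
    num_heads=2,
    key_dim=8,
    feature_transform="exp",
    num_random_features=32)

query = tf.random.normal((2, 10, 16))   # [B, T, dim]
value = tf.random.normal((2, 12, 16))   # [B, S, dim]
# [B, S] mask over the keys: attend only to the first 8 key positions.
attention_mask = tf.concat([tf.ones((2, 8)), tf.zeros((2, 4))], axis=-1)

output = attention(query, value, attention_mask=attention_mask)
print(output.shape)  # (2, 10, 16) -- projected back to the query dimension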