Commit 23db25e9 authored by Frederick Liu, committed by A. Unique TensorFlower

[efficient] Promote bigbird to modeling/layers.

PiperOrigin-RevId: 374267447
parent 856622d3
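
In practical terms, downstream code that pulled the BigBird layers from the research project path can now import them from the shared modeling layers package. A minimal before/after sketch of that import change, based only on the module paths visible in the diffs below:

    # Before: BigBird attention lived under the research project.
    # from official.nlp.projects.bigbird import attention as bigbird_attention
    # attention_cls = bigbird_attention.BigBirdAttention
    # mask_cls = bigbird_attention.BigBirdMasks

    # After: the same classes are re-exported from official.nlp.modeling.layers.
    from official.nlp.modeling import layers

    attention_cls = layers.BigBirdAttention
    mask_cls = layers.BigBirdMasks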
@@ -26,7 +26,6 @@ from official.modeling import hyperparams
 from official.modeling import tf_utils
 from official.nlp.modeling import layers
 from official.nlp.modeling import networks
-from official.nlp.projects.bigbird import attention as bigbird_attention


 @dataclasses.dataclass
@@ -301,14 +300,14 @@ def build_encoder(config: EncoderConfig,
         attention_dropout_rate=encoder_cfg.attention_dropout_rate,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range),
-        attention_cls=bigbird_attention.BigBirdAttention,
+        attention_cls=layers.BigBirdAttention,
         attention_cfg=attention_cfg)
     kwargs = dict(
         embedding_cfg=embedding_cfg,
         hidden_cls=layers.TransformerScaffold,
         hidden_cfg=hidden_cfg,
         num_hidden_instances=encoder_cfg.num_layers,
-        mask_cls=bigbird_attention.BigBirdMasks,
+        mask_cls=layers.BigBirdMasks,
         mask_cfg=dict(block_size=encoder_cfg.block_size),
         pooled_output_dim=encoder_cfg.hidden_size,
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
......
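
The `kwargs` assembled in the hunk above feed the generic scaffold network that `build_encoder` returns for the BigBird encoder type. A hedged sketch of the continuation that sits below the truncated hunk, with BigBird-specific behavior injected only through `attention_cls`, `attention_cfg`, `mask_cls`, and `mask_cfg`:

    # Assumed continuation of the bigbird branch of build_encoder: the kwargs
    # above are handed to the generic EncoderScaffold network; nothing else in
    # the builder needs to know about BigBird internals.
    return networks.EncoderScaffold(**kwargs)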
@@ -8,6 +8,10 @@ assemble new `tf.keras` layers or models.
     ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
     `from_tensor` and `to_tensor` are the same, then this is self-attention.
+*   [BigBirdAttention](bigbird_attention.py) implements a sparse attention
+    mechanism that reduces this quadratic dependency to linear, as described in
+    ["Big Bird: Transformers for Longer Sequences"](https://arxiv.org/abs/2007.14062).
 *   [CachedAttention](attention.py) implements an attention layer with cache
     used for auto-regressive decoding.
@@ -80,20 +84,20 @@ assemble new `tf.keras` layers or models.
 *   [MultiHeadRelativeAttention](relative_attention.py) implements a variant
     of multi-head attention with support for relative position encodings as
-    described in "Transformer-XL: Attentive Language Models Beyond a
-    Fixed-Length Context"(https://arxiv.org/abs/1901.02860). This also has
+    described in ["Transformer-XL: Attentive Language Models Beyond a
+    Fixed-Length Context"](https://arxiv.org/abs/1901.02860). This also has
     extended support for segment-based attention, a re-parameterization
-    introduced in "XLNet: Generalized Autoregressive Pretraining for Language
-    Understanding" (https://arxiv.org/abs/1906.08237).
+    introduced in ["XLNet: Generalized Autoregressive Pretraining for Language
+    Understanding"](https://arxiv.org/abs/1906.08237).
 *   [TwoStreamRelativeAttention](relative_attention.py) implements a variant
-    of multi-head relative attention as described in "XLNet: Generalized
-    Autoregressive Pretraining for Language Understanding"
+    of multi-head relative attention as described in ["XLNet: Generalized
+    Autoregressive Pretraining for Language Understanding"]
     (https://arxiv.org/abs/1906.08237). This takes in a query and content
     stream and applies self attention.
 *   [TransformerXL](transformer_xl.py) implements Transformer XL introduced in
-    "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    ["Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"]
     (https://arxiv.org/abs/1901.02860). This contains `TransformerXLBlock`, a
     block containing either one or two stream relative self-attention as well as
     subsequent feedforward networks. It also contains `TransformerXL`, which
......
@@ -18,6 +18,8 @@ They can be used to assemble new `tf.keras` layers or models.
 """
 # pylint: disable=wildcard-import
 from official.nlp.modeling.layers.attention import *
+from official.nlp.modeling.layers.bigbird_attention import BigBirdAttention
+from official.nlp.modeling.layers.bigbird_attention import BigBirdMasks
 from official.nlp.modeling.layers.cls_head import *
 from official.nlp.modeling.layers.dense_einsum import DenseEinsum
 from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
......
@@ -16,7 +16,7 @@
 import tensorflow as tf
-from official.nlp.projects.bigbird import attention
+from official.nlp.modeling.layers import bigbird_attention as attention


 class BigbirdAttentionTest(tf.test.TestCase):
......
@@ -20,11 +20,13 @@ import tensorflow as tf
 from official.modeling import activations
 from official.nlp import keras_nlp
 from official.nlp.modeling import layers
-from official.nlp.projects.bigbird import attention
 from official.nlp.projects.bigbird import recompute_grad
 from official.nlp.projects.bigbird import recomputing_dropout
+
+_MAX_SEQ_LEN = 4096
+

 class RecomputeTransformerLayer(layers.TransformerScaffold):
   """Transformer layer that recomputes the forward pass during backpropagation."""
@@ -86,7 +88,7 @@ class BigBirdEncoder(tf.keras.Model):
                hidden_size=768,
                num_layers=12,
                num_attention_heads=12,
-               max_position_embeddings=attention.MAX_SEQ_LEN,
+               max_position_embeddings=_MAX_SEQ_LEN,
                type_vocab_size=16,
                intermediate_size=3072,
                block_size=64,
@@ -177,7 +179,8 @@ class BigBirdEncoder(tf.keras.Model):
     self._transformer_layers = []
     data = embeddings
-    masks = attention.BigBirdMasks(block_size=block_size)(data, mask)
+    masks = layers.BigBirdMasks(block_size=block_size)(
+        data, mask)
     encoder_outputs = []
     attn_head_dim = hidden_size // num_attention_heads
     for i in range(num_layers):
@@ -185,7 +188,7 @@ class BigBirdEncoder(tf.keras.Model):
           num_attention_heads,
           intermediate_size,
           activation,
-          attention_cls=attention.BigBirdAttention,
+          attention_cls=layers.BigBirdAttention,
           attention_cfg=dict(
               num_heads=num_attention_heads,
               key_dim=attn_head_dim,
......
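
For users of `BigBirdEncoder`, construction is unchanged by the promotion; only the module the BigBird layers resolve from differs. A hedged instantiation sketch using the defaults visible in the hunks above (`vocab_size` is illustrative, not taken from the diff):

    # Hypothetical instantiation; every value except vocab_size is a default
    # visible in the diff above. max_position_embeddings now defaults to the
    # local _MAX_SEQ_LEN = 4096 instead of attention.MAX_SEQ_LEN.
    encoder = BigBirdEncoder(
        vocab_size=30522,           # illustrative vocabulary size
        hidden_size=768,
        num_layers=12,
        num_attention_heads=12,
        max_position_embeddings=4096,
        type_vocab_size=16,
        intermediate_size=3072,
        block_size=64)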