Commit baf94acc authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 480424926
parent 09c0b474
......@@ -17,7 +17,7 @@
Includes configurations and factory methods.
"""
import dataclasses
from typing import Optional
from typing import Optional, Sequence
import gin
import tensorflow as tf
......@@ -242,6 +242,29 @@ class QueryBertConfig(hyperparams.Config):
norm_first: bool = False
@dataclasses.dataclass
class FNetEncoderConfig(hyperparams.Config):
"""FNet encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
inner_activation: str = "gelu"
inner_dim: int = 3072
output_dropout: float = 0.1
attention_dropout: float = 0.1
max_sequence_length: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_width: Optional[int] = None
output_range: Optional[int] = None
return_all_encoder_outputs: bool = False
# Pre/Post-LN Transformer
norm_first: bool = False
use_fft: bool = False
attention_layers: Sequence[int] = ()
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
"""Encoder configuration."""
......@@ -255,6 +278,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
reuse: ReuseEncoderConfig = ReuseEncoderConfig()
xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
query_bert: QueryBertConfig = QueryBertConfig()
fnet: FNetEncoderConfig = FNetEncoderConfig()
# If `any` is used, the encoder building relies on any.BUILDER.
any: hyperparams.Config = hyperparams.Config()
......@@ -562,6 +586,27 @@ def build_encoder(config: EncoderConfig,
dict_outputs=True,
norm_first=encoder_cfg.norm_first)
if encoder_type == "fnet":
return networks.FNet(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
inner_dim=encoder_cfg.inner_dim,
inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
output_dropout=encoder_cfg.output_dropout,
attention_dropout=encoder_cfg.attention_dropout,
max_sequence_length=encoder_cfg.max_sequence_length,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_width,
embedding_layer=embedding_layer,
norm_first=encoder_cfg.norm_first,
use_fft=encoder_cfg.use_fft,
attention_layers=encoder_cfg.attention_layers)
bert_encoder_cls = networks.BertEncoder
if encoder_type == "bert_v2":
bert_encoder_cls = networks.BertEncoderV2
......
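Taken together, the hunks above define FNetEncoderConfig, register it under the `fnet` key of the one-of EncoderConfig, and teach build_encoder to construct networks.FNet from it. The following is a minimal usage sketch, not part of this commit: it assumes the module path official.nlp.configs.encoders and keyword construction of the config dataclasses, and all hyperparameter values are illustrative.

# Hypothetical sketch only; assumes official.nlp.configs.encoders exposes the
# classes added above, with illustrative hyperparameter values.
from official.nlp.configs import encoders

config = encoders.EncoderConfig(
    type="fnet",  # select the new FNetEncoderConfig branch of the one-of config
    fnet=encoders.FNetEncoderConfig(
        num_layers=12,
        use_fft=True,               # Fourier mixing via FFT rather than matmul
        attention_layers=(10, 11),  # hybrid model: self-attention in the final two layers
    ))

# build_encoder dispatches on config.type and returns a networks.FNet instance.
encoder = encoders.build_encoder(config)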
......@@ -43,7 +43,7 @@ class FNet(tf.keras.layers.Layer):
This implementation defaults to the canonical FNet Base model, but the network
also supports more general mixing models (e.g. 'Linear', 'HNet') and hybrid
models (e.g. 'FNet-Hybrid') models that use both mixing and self-attention
-layers.
+layers. The input length is fixed to 'max_sequence_length'.
Args:
vocab_size: The size of the token vocabulary.
......@@ -61,8 +61,9 @@ class FNet(tf.keras.layers.Layer):
good rule of thumb is to place them in the final few layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
-max_sequence_length: The maximum sequence length that this encoder can
-consume. This determines the variable shape for positional embeddings.
+max_sequence_length: The only sequence length that this encoder can
+consume. This determines the variable shape for positional embeddings and
+the size of the mixing matrices.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
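As the updated docstring notes, the encoder now consumes exactly one sequence length. Below is a minimal sketch of the resulting contract, assuming the constructor arguments shown in this diff and the module path official.nlp.modeling.networks; the values are illustrative and mirror the small configuration used in the tests.

# Hypothetical sketch only; not part of this commit.
import tensorflow as tf
from official.nlp.modeling import networks

encoder = networks.FNet(
    vocab_size=100,
    hidden_size=32,
    num_attention_heads=2,
    num_layers=2,
    max_sequence_length=16)  # the one and only sequence length accepted

inputs = dict(
    input_word_ids=tf.ones((4, 16), dtype=tf.int32),
    input_mask=tf.ones((4, 16), dtype=tf.int32),
    input_type_ids=tf.zeros((4, 16), dtype=tf.int32))
outputs = encoder(inputs)  # OK: length 16 matches max_sequence_length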
......@@ -220,19 +221,28 @@ class FNet(tf.keras.layers.Layer):
if with_dense_inputs:
self.inputs = dict(
-input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
-input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
-input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
+input_word_ids=tf.keras.Input(
+shape=(max_sequence_length,), dtype=tf.int32),
+input_mask=tf.keras.Input(
+shape=(max_sequence_length,), dtype=tf.int32),
+input_type_ids=tf.keras.Input(
+shape=(max_sequence_length,), dtype=tf.int32),
dense_inputs=tf.keras.Input(
-shape=(None, embedding_width), dtype=tf.float32),
-dense_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
-dense_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
+shape=(max_sequence_length, embedding_width), dtype=tf.float32),
+dense_mask=tf.keras.Input(
+shape=(max_sequence_length,), dtype=tf.int32),
+dense_type_ids=tf.keras.Input(
+shape=(max_sequence_length,), dtype=tf.int32),
)
else:
self.inputs = dict(
-input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
-input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
-input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
+input_word_ids=tf.keras.Input(
+shape=(max_sequence_length,), dtype=tf.int32),
+input_mask=tf.keras.Input(
+shape=(max_sequence_length,), dtype=tf.int32),
+input_type_ids=tf.keras.Input(
+shape=(max_sequence_length,), dtype=tf.int32))
+self._max_sequence_length = max_sequence_length
def call(self, inputs):
word_embeddings = None
......@@ -258,6 +268,12 @@ class FNet(tf.keras.layers.Layer):
type_ids = tf.concat([type_ids, dense_type_ids], axis=1)
mask = tf.concat([mask, dense_mask], axis=1)
seq_length = word_embeddings.shape[1]
if seq_length != self._max_sequence_length:
raise ValueError('FNet: Sequence length must be the same as '
'`max_sequence_length` ({}), but it is {}.'.format(
self._max_sequence_length, seq_length))
# Absolute position embeddings.
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = self._type_embedding_layer(type_ids)
......
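Conversely, under the same assumptions as the sketch above, a batch whose length differs from max_sequence_length now fails fast in call() with an explicit error:

# Hypothetical sketch only; continues from the encoder built in the previous
# snippet (max_sequence_length=16).
short_inputs = dict(
    input_word_ids=tf.ones((4, 8), dtype=tf.int32),
    input_mask=tf.ones((4, 8), dtype=tf.int32),
    input_type_ids=tf.zeros((4, 8), dtype=tf.int32))
try:
  encoder(short_inputs)
except ValueError as err:
  # "FNet: Sequence length must be the same as `max_sequence_length` (16), but it is 8."
  print(err)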
......@@ -47,6 +47,7 @@ class FNetTest(parameterized.TestCase, tf.test.TestCase):
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
max_sequence_length=sequence_length,
num_layers=num_layers,
mixing_mechanism=mixing_mechanism,
attention_layers=attention_layers)
......@@ -81,6 +82,7 @@ class FNetTest(parameterized.TestCase, tf.test.TestCase):
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
max_sequence_length=sequence_length,
num_layers=3)
# Create the inputs (note that the first dimension is implicit).
......