Commit baf94acc authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 480424926
parent 09c0b474
@@ -17,7 +17,7 @@
 Includes configurations and factory methods.
 """
 import dataclasses
-from typing import Optional
+from typing import Optional, Sequence
 import gin
 import tensorflow as tf
@@ -242,6 +242,29 @@ class QueryBertConfig(hyperparams.Config):
   norm_first: bool = False
+@dataclasses.dataclass
+class FNetEncoderConfig(hyperparams.Config):
+  """FNet encoder configuration."""
+  vocab_size: int = 30522
+  hidden_size: int = 768
+  num_layers: int = 12
+  num_attention_heads: int = 12
+  inner_activation: str = "gelu"
+  inner_dim: int = 3072
+  output_dropout: float = 0.1
+  attention_dropout: float = 0.1
+  max_sequence_length: int = 512
+  type_vocab_size: int = 2
+  initializer_range: float = 0.02
+  embedding_width: Optional[int] = None
+  output_range: Optional[int] = None
+  return_all_encoder_outputs: bool = False
+  # Pre/Post-LN Transformer
+  norm_first: bool = False
+  use_fft: bool = False
+  attention_layers: Sequence[int] = ()
+
+
 @dataclasses.dataclass
 class EncoderConfig(hyperparams.OneOfConfig):
   """Encoder configuration."""
@@ -255,6 +278,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
   reuse: ReuseEncoderConfig = ReuseEncoderConfig()
   xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
   query_bert: QueryBertConfig = QueryBertConfig()
+  fnet: FNetEncoderConfig = FNetEncoderConfig()
   # If `any` is used, the encoder building relies on any.BUILDER.
   any: hyperparams.Config = hyperparams.Config()
@@ -562,6 +586,27 @@ def build_encoder(config: EncoderConfig,
         dict_outputs=True,
         norm_first=encoder_cfg.norm_first)
+  if encoder_type == "fnet":
+    return networks.FNet(
+        vocab_size=encoder_cfg.vocab_size,
+        hidden_size=encoder_cfg.hidden_size,
+        num_layers=encoder_cfg.num_layers,
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        inner_dim=encoder_cfg.inner_dim,
+        inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
+        output_dropout=encoder_cfg.output_dropout,
+        attention_dropout=encoder_cfg.attention_dropout,
+        max_sequence_length=encoder_cfg.max_sequence_length,
+        type_vocab_size=encoder_cfg.type_vocab_size,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        output_range=encoder_cfg.output_range,
+        embedding_width=encoder_cfg.embedding_width,
+        embedding_layer=embedding_layer,
+        norm_first=encoder_cfg.norm_first,
+        use_fft=encoder_cfg.use_fft,
+        attention_layers=encoder_cfg.attention_layers)
+
   bert_encoder_cls = networks.BertEncoder
   if encoder_type == "bert_v2":
     bert_encoder_cls = networks.BertEncoderV2
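Note on usage (not part of the change above): the new fnet entry in EncoderConfig lets the FNet encoder be selected through the usual one-of mechanism and constructed via build_encoder. The sketch below is untested and illustrative only; the import path follows the usual Model Garden layout, and the override values are arbitrary assumptions.

from official.nlp.configs import encoders

# Select the FNet encoder; keep self-attention in the final two layers and use
# FFT-based mixing elsewhere (values chosen purely for illustration).
encoder_config = encoders.EncoderConfig(
    type="fnet",
    fnet=encoders.FNetEncoderConfig(
        num_layers=12,
        max_sequence_length=512,
        use_fft=True,
        attention_layers=(10, 11)))
encoder = encoders.build_encoder(encoder_config)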
@@ -43,7 +43,7 @@ class FNet(tf.keras.layers.Layer):
   This implementation defaults to the canonical FNet Base model, but the network
   also supports more general mixing models (e.g. 'Linear', 'HNet') and hybrid
   models (e.g. 'FNet-Hybrid') models that use both mixing and self-attention
-  layers.
+  layers. The input length is fixed to 'max_sequence_length'.
 
   Args:
     vocab_size: The size of the token vocabulary.
@@ -61,8 +61,9 @@ class FNet(tf.keras.layers.Layer):
       good rule of thumb is to place them in the final few layers.
     num_attention_heads: The number of attention heads for each transformer. The
       hidden size must be divisible by the number of attention heads.
-    max_sequence_length: The maximum sequence length that this encoder can
-      consume. This determines the variable shape for positional embeddings.
+    max_sequence_length: The only sequence length that this encoder can
+      consume. This determines the variable shape for positional embeddings and
+      the size of the mixing matrices.
     type_vocab_size: The number of types that the 'type_ids' input can take.
     inner_dim: The output dimension of the first Dense layer in a two-layer
       feedforward network for each transformer.
@@ -220,19 +221,28 @@ class FNet(tf.keras.layers.Layer):
     if with_dense_inputs:
       self.inputs = dict(
-          input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
-          input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
-          input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
+          input_word_ids=tf.keras.Input(
+              shape=(max_sequence_length,), dtype=tf.int32),
+          input_mask=tf.keras.Input(
+              shape=(max_sequence_length,), dtype=tf.int32),
+          input_type_ids=tf.keras.Input(
+              shape=(max_sequence_length,), dtype=tf.int32),
           dense_inputs=tf.keras.Input(
-              shape=(None, embedding_width), dtype=tf.float32),
-          dense_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
-          dense_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
+              shape=(max_sequence_length, embedding_width), dtype=tf.float32),
+          dense_mask=tf.keras.Input(
+              shape=(max_sequence_length,), dtype=tf.int32),
+          dense_type_ids=tf.keras.Input(
+              shape=(max_sequence_length,), dtype=tf.int32),
       )
     else:
       self.inputs = dict(
-          input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
-          input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
-          input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
+          input_word_ids=tf.keras.Input(
+              shape=(max_sequence_length,), dtype=tf.int32),
+          input_mask=tf.keras.Input(
+              shape=(max_sequence_length,), dtype=tf.int32),
+          input_type_ids=tf.keras.Input(
+              shape=(max_sequence_length,), dtype=tf.int32))
+
+    self._max_sequence_length = max_sequence_length
 
   def call(self, inputs):
     word_embeddings = None
@@ -258,6 +268,12 @@ class FNet(tf.keras.layers.Layer):
       type_ids = tf.concat([type_ids, dense_type_ids], axis=1)
       mask = tf.concat([mask, dense_mask], axis=1)
 
+    seq_length = word_embeddings.shape[1]
+    if seq_length != self._max_sequence_length:
+      raise ValueError('FNet: Sequence length must be the same as '
+                       '`max_sequence_length` ({}), but it is {}.'.format(
+                           self._max_sequence_length, seq_length))
+
     # Absolute position embeddings.
     position_embeddings = self._position_embedding_layer(word_embeddings)
     type_embeddings = self._type_embedding_layer(type_ids)
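Note on behavior (not part of the change above): with the inputs declared as shape (max_sequence_length,) and the new check in call(), FNet only accepts sequences of exactly max_sequence_length. A minimal, untested sketch of that behavior; the import path and the constructor values below are assumptions.

import tensorflow as tf
from official.nlp.modeling import networks

seq_len = 16
encoder = networks.FNet(
    vocab_size=100,
    hidden_size=32,
    num_layers=2,
    num_attention_heads=2,
    max_sequence_length=seq_len)

inputs = dict(
    input_word_ids=tf.random.uniform((2, seq_len), maxval=100, dtype=tf.int32),
    input_mask=tf.ones((2, seq_len), dtype=tf.int32),
    input_type_ids=tf.zeros((2, seq_len), dtype=tf.int32))
outputs = encoder(inputs)  # OK: the length matches max_sequence_length.

# Any other length now fails fast, e.g. feeding tensors of length 8 raises
# ValueError: FNet: Sequence length must be the same as `max_sequence_length`.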
@@ -47,6 +47,7 @@ class FNetTest(parameterized.TestCase, tf.test.TestCase):
         vocab_size=100,
         hidden_size=hidden_size,
         num_attention_heads=2,
+        max_sequence_length=sequence_length,
         num_layers=num_layers,
         mixing_mechanism=mixing_mechanism,
         attention_layers=attention_layers)
@@ -81,6 +82,7 @@ class FNetTest(parameterized.TestCase, tf.test.TestCase):
         vocab_size=100,
         hidden_size=hidden_size,
         num_attention_heads=2,
+        max_sequence_length=sequence_length,
         num_layers=3)
     # Create the inputs (note that the first dimension is implicit).