Commit 097a8296 authored by Frederick Liu, committed by A. Unique TensorFlower

[nlp] Support norm_first option for Bert/Bigbird/Kernel encoders.

PiperOrigin-RevId: 382960239
parent c2582c3e
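
For context, norm_first selects between the two standard layer-norm placements in a Transformer block: the post-LN ordering used by the original BERT (normalize each sublayer's output) and the pre-LN ordering (normalize each sublayer's input), which is commonly reported to train more stably for deep stacks. A minimal sketch of the two orderings; attn, ffn, norm1, and norm2 are hypothetical sublayer callables, not the model garden's actual TransformerEncoderBlock internals:

def post_ln_block(x, attn, ffn, norm1, norm2):
  # norm_first=False (the default): normalize sublayer outputs.
  x = norm1(x + attn(x))
  return norm2(x + ffn(x))

def pre_ln_block(x, attn, ffn, norm1, norm2):
  # norm_first=True: normalize sublayer inputs; the residual
  # stream itself is left unnormalized.
  x = x + attn(norm1(x))
  return x + ffn(norm2(x))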
@@ -46,6 +46,8 @@ class BertEncoderConfig(hyperparams.Config):
embedding_size: Optional[int] = None
output_range: Optional[int] = None
return_all_encoder_outputs: bool = False
# Pre/Post-LN Transformer
norm_first: bool = False
@dataclasses.dataclass
@@ -132,6 +134,8 @@ class BigBirdEncoderConfig(hyperparams.Config):
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
# Pre/Post-LN Transformer
norm_first: bool = False
max_position_embeddings: int = 4096
num_rand_blocks: int = 3
block_size: int = 64
@@ -152,6 +156,8 @@ class KernelEncoderConfig(hyperparams.Config):
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
# Pre/Post-LN Transformer
norm_first: bool = False
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
@@ -340,6 +346,7 @@ def build_encoder(config: EncoderConfig,
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
norm_first=encoder_cfg.norm_first,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
attention_cls=layers.BigBirdAttention,
@@ -387,6 +394,7 @@ def build_encoder(config: EncoderConfig,
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
norm_first=encoder_cfg.norm_first,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
attention_cls=layers.KernelAttention,
@@ -447,4 +455,5 @@ def build_encoder(config: EncoderConfig,
embedding_width=encoder_cfg.embedding_size,
embedding_layer=embedding_layer,
return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True)
dict_outputs=True,
norm_first=encoder_cfg.norm_first)
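
The hunks above route norm_first from each encoder config into build_encoder. A minimal usage sketch, assuming the model garden's official.nlp.configs.encoders module; the import path and the one-of EncoderConfig construction are assumptions outside this diff:

from official.nlp.configs import encoders

# Hypothetical usage: enable the pre-LN variant through the
# config plumbing added in this commit.
encoder_cfg = encoders.EncoderConfig(
    type='bert',
    bert=encoders.BertEncoderConfig(norm_first=True))
encoder = encoders.build_encoder(encoder_cfg)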
@@ -69,6 +69,9 @@ class BertEncoder(tf.keras.Model):
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to
generate embeddings for the input word IDs.
norm_first: Whether to normalize the inputs to the attention and
intermediate dense layers (pre-LN). If set to False, the outputs of the
attention and intermediate dense layers are normalized instead (post-LN).
"""
def __init__(
@@ -87,6 +90,7 @@ class BertEncoder(tf.keras.Model):
output_range=None,
embedding_width=None,
embedding_layer=None,
norm_first=False,
**kwargs):
activation = tf.keras.activations.get(inner_activation)
initializer = tf.keras.initializers.get(initializer)
@@ -162,6 +166,7 @@ class BertEncoder(tf.keras.Model):
inner_activation=inner_activation,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=transformer_output_range,
kernel_initializer=initializer,
name='transformer/layer_%d' % i)
@@ -211,6 +216,7 @@ class BertEncoder(tf.keras.Model):
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
}
# We are storing the config dict as a namedtuple here to ensure checkpoint
......
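
Because norm_first is recorded in the config dict above, it should round-trip through get_config() like the other constructor arguments. A hedged sketch; the import path and all constructor arguments other than norm_first are illustrative assumptions:

from official.nlp.keras_nlp.encoders import bert_encoder

# Hypothetical check: the new argument appears in the serialized config.
network = bert_encoder.BertEncoder(
    vocab_size=100, num_layers=2, norm_first=True)
assert network.get_config()['norm_first'] is True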
@@ -205,7 +205,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
initializer="glorot_uniform",
output_range=-1,
embedding_width=16,
embedding_layer=None)
embedding_layer=None,
norm_first=False)
network = bert_encoder.BertEncoder(**kwargs)
expected_config = dict(kwargs)
expected_config["inner_activation"] = tf.keras.activations.serialize(
......
@@ -77,6 +77,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
parameter is originally added for ELECTRA model which needs to tie the
generator embeddings with the discriminator embeddings.
dict_outputs: Whether to use a dictionary as the model outputs.
norm_first: Whether to normalize the inputs to the attention and
intermediate dense layers (pre-LN). If set to False, the outputs of the
attention and intermediate dense layers are normalized instead (post-LN).
"""
def __init__(self,
@@ -97,6 +100,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
embedding_width=None,
embedding_layer=None,
dict_outputs=False,
norm_first=False,
**kwargs):
# b/164516224
@@ -120,7 +124,8 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
initializer=initializer,
output_range=output_range,
embedding_width=embedding_width,
embedding_layer=embedding_layer)
embedding_layer=embedding_layer,
norm_first=norm_first)
self._embedding_layer_instance = embedding_layer
......
@@ -226,7 +226,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
output_range=-1,
embedding_width=16,
dict_outputs=True,
embedding_layer=None)
embedding_layer=None,
norm_first=False)
network = bert_encoder.BertEncoder(**kwargs)
expected_config = dict(kwargs)
expected_config["activation"] = tf.keras.activations.serialize(
......
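
Finally, a direct-constructor sketch of the wrapped modeling-layer BertEncoder with dict outputs; the import path and all sizes here are illustrative assumptions, not part of this commit:

import tensorflow as tf
from official.nlp.modeling.networks import bert_encoder

network = bert_encoder.BertEncoder(
    vocab_size=100, hidden_size=32, num_layers=2,
    num_attention_heads=2, dict_outputs=True, norm_first=True)

# Inputs follow the usual (word_ids, mask, type_ids) convention.
word_ids = tf.zeros((2, 8), dtype=tf.int32)
mask = tf.ones((2, 8), dtype=tf.int32)
type_ids = tf.zeros((2, 8), dtype=tf.int32)
outputs = network([word_ids, mask, type_ids])
print(outputs['sequence_output'].shape)  # (2, 8, 32)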