Commit 097a8296 authored by Frederick Liu, committed by A. Unique TensorFlower

[nlp] Support norm_first option for Bert/Bigbird/Kernel encoders.

PiperOrigin-RevId: 382960239
parent c2582c3e
@@ -46,6 +46,8 @@ class BertEncoderConfig(hyperparams.Config):
   embedding_size: Optional[int] = None
   output_range: Optional[int] = None
   return_all_encoder_outputs: bool = False
+  # Pre/Post-LN Transformer
+  norm_first: bool = False


 @dataclasses.dataclass
@@ -132,6 +134,8 @@ class BigBirdEncoderConfig(hyperparams.Config):
   intermediate_size: int = 3072
   dropout_rate: float = 0.1
   attention_dropout_rate: float = 0.1
+  # Pre/Post-LN Transformer
+  norm_first: bool = False
   max_position_embeddings: int = 4096
   num_rand_blocks: int = 3
   block_size: int = 64
@@ -152,6 +156,8 @@ class KernelEncoderConfig(hyperparams.Config):
   intermediate_size: int = 3072
   dropout_rate: float = 0.1
   attention_dropout_rate: float = 0.1
+  # Pre/Post-LN Transformer
+  norm_first: bool = False
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
@@ -340,6 +346,7 @@ def build_encoder(config: EncoderConfig,
             encoder_cfg.hidden_activation),
         dropout_rate=encoder_cfg.dropout_rate,
         attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        norm_first=encoder_cfg.norm_first,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range),
         attention_cls=layers.BigBirdAttention,
@@ -387,6 +394,7 @@ def build_encoder(config: EncoderConfig,
             encoder_cfg.hidden_activation),
         dropout_rate=encoder_cfg.dropout_rate,
         attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        norm_first=encoder_cfg.norm_first,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range),
         attention_cls=layers.KernelAttention,
@@ -447,4 +455,5 @@ def build_encoder(config: EncoderConfig,
       embedding_width=encoder_cfg.embedding_size,
       embedding_layer=embedding_layer,
       return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
-      dict_outputs=True)
+      dict_outputs=True,
+      norm_first=encoder_cfg.norm_first)
@@ -69,6 +69,9 @@ class BertEncoder(tf.keras.Model):
       smaller than 'hidden_size').
     embedding_layer: An optional Layer instance which will be called to
       generate embeddings for the input word IDs.
+    norm_first: Whether to normalize the inputs to the attention and
+      intermediate dense layers. If set to False, the outputs of the
+      attention and intermediate dense layers are normalized instead.
   """

   def __init__(
@@ -87,6 +90,7 @@ class BertEncoder(tf.keras.Model):
       output_range=None,
       embedding_width=None,
       embedding_layer=None,
+      norm_first=False,
       **kwargs):
     activation = tf.keras.activations.get(inner_activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -162,6 +166,7 @@ class BertEncoder(tf.keras.Model):
           inner_activation=inner_activation,
           output_dropout=output_dropout,
           attention_dropout=attention_dropout,
+          norm_first=norm_first,
           output_range=transformer_output_range,
           kernel_initializer=initializer,
           name='transformer/layer_%d' % i)
@@ -211,6 +216,7 @@ class BertEncoder(tf.keras.Model):
         'output_range': output_range,
         'embedding_width': embedding_width,
         'embedding_layer': embedding_layer,
+        'norm_first': norm_first,
     }
     # We are storing the config dict as a namedtuple here to ensure checkpoint
...
@@ -205,7 +205,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
         initializer="glorot_uniform",
         output_range=-1,
         embedding_width=16,
-        embedding_layer=None)
+        embedding_layer=None,
+        norm_first=False)
     network = bert_encoder.BertEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["inner_activation"] = tf.keras.activations.serialize(
...
@@ -77,6 +77,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
       parameter is originally added for ELECTRA model which needs to tie the
       generator embeddings with the discriminator embeddings.
     dict_outputs: Whether to use a dictionary as the model outputs.
+    norm_first: Whether to normalize the inputs to the attention and
+      intermediate dense layers. If set to False, the outputs of the
+      attention and intermediate dense layers are normalized instead.
   """

   def __init__(self,
@@ -97,6 +100,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
                embedding_width=None,
                embedding_layer=None,
                dict_outputs=False,
+               norm_first=False,
                **kwargs):
     # b/164516224
@@ -120,7 +124,8 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
         initializer=initializer,
         output_range=output_range,
         embedding_width=embedding_width,
-        embedding_layer=embedding_layer)
+        embedding_layer=embedding_layer,
+        norm_first=norm_first)
     self._embedding_layer_instance = embedding_layer
...
@@ -226,7 +226,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
         output_range=-1,
         embedding_width=16,
         dict_outputs=True,
-        embedding_layer=None)
+        embedding_layer=None,
+        norm_first=False)
     network = bert_encoder.BertEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["activation"] = tf.keras.activations.serialize(
...