Unverified commit abc400b0, authored by Thomas Wang and committed by GitHub

Add final_layer_norm to OPT model (#17785)



* Add final_layer_norm to OPT model

* Add JAX and TF version

* Fix Keras name

* Woops

* Allow for non-breaking change

* Apply suggestions from code review

* add tests
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 52404cba
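
For context, here is a minimal sketch (not part of this diff) of how the new flag is meant to be used from the public API; the `final_layer_norm` attribute name follows the decoder changes below:

```python
from transformers import OPTConfig, OPTModel

# Default behaviour after this change: the decoder ends with a final LayerNorm.
config = OPTConfig(do_layer_norm_before=True)
model = OPTModel(config)
print(model.decoder.final_layer_norm)  # LayerNorm(...)

# Opt back into the pre-v4.20.1 behaviour for checkpoints fine-tuned without the norm.
legacy_config = OPTConfig(do_layer_norm_before=True, _remove_final_layer_norm=True)
legacy_model = OPTModel(legacy_config)
print(legacy_model.decoder.final_layer_norm)  # None
```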
@@ -102,6 +102,7 @@ class OPTConfig(PretrainedConfig):
         ffn_dim=3072,
         max_position_embeddings=2048,
         do_layer_norm_before=True,
+        _remove_final_layer_norm=False,
         word_embed_proj_dim=None,
         dropout=0.1,
         attention_dropout=0.0,
@@ -137,3 +138,8 @@ class OPTConfig(PretrainedConfig):
         self.layerdrop = layerdrop
         self.use_cache = use_cache
         self.do_layer_norm_before = do_layer_norm_before
+        # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        self._remove_final_layer_norm = _remove_final_layer_norm
@@ -37,8 +37,6 @@ def load_checkpoint(checkpoint_path):
     # pop unnecessary weights
     keys_to_delete = [
         "decoder.version",
-        "decoder.layer_norm.weight",
-        "decoder.layer_norm.bias",
         "decoder.output_projection.weight",
     ]
     for key in keys_to_delete:
@@ -48,6 +46,8 @@ def load_checkpoint(checkpoint_path):
     keys_to_rename = {
         "decoder.project_in_dim.weight": "decoder.project_in.weight",
         "decoder.project_out_dim.weight": "decoder.project_out.weight",
+        "decoder.layer_norm.weight": "decoder.final_layer_norm.weight",
+        "decoder.layer_norm.bias": "decoder.final_layer_norm.bias",
     }
     for old_key, new_key in keys_to_rename.items():
         if old_key in sd:
...
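
To make the conversion step above easier to follow, here is a hedged, self-contained sketch of the delete/rename pass over a metaseq state dict; the helper name `remap_metaseq_keys` and the toy dict are illustrative, only the key names come from this diff:

```python
def remap_metaseq_keys(sd):
    # Standalone version of the logic above: drop unused weights, then rename
    # the metaseq final layer norm keys to the names the transformers model expects.
    keys_to_delete = [
        "decoder.version",
        "decoder.output_projection.weight",
    ]
    keys_to_rename = {
        "decoder.project_in_dim.weight": "decoder.project_in.weight",
        "decoder.project_out_dim.weight": "decoder.project_out.weight",
        "decoder.layer_norm.weight": "decoder.final_layer_norm.weight",
        "decoder.layer_norm.bias": "decoder.final_layer_norm.bias",
    }
    for key in keys_to_delete:
        sd.pop(key, None)
    for old_key, new_key in keys_to_rename.items():
        if old_key in sd:
            sd[new_key] = sd.pop(old_key)
    return sd

# Toy usage with strings standing in for tensors.
print(remap_metaseq_keys({"decoder.version": "3", "decoder.layer_norm.weight": "w"}))
# {'decoder.final_layer_norm.weight': 'w'}
```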
@@ -452,6 +452,14 @@ class FlaxOPTDecoder(nn.Module):
             self.project_in = None
             self.project_out = None

+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm:
+            self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+        else:
+            self.final_layer_norm = None
+
         self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)

     def __call__(
@@ -487,6 +495,9 @@ class FlaxOPTDecoder(nn.Module):
             output_hidden_states=output_hidden_states,
         )

+        if self.final_layer_norm is not None:
+            hidden_state = self.final_layer_norm(hidden_state)
+
         if self.project_out is not None:
             hidden_state = self.project_out(hidden_state)
...
@@ -492,7 +492,14 @@ class OPTDecoder(OPTPreTrainedModel):
         else:
             self.project_in = None

-        self.layer_norm = None
+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
+            self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+        else:
+            self.final_layer_norm = None

         self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.gradient_checkpointing = False
@@ -688,6 +695,9 @@ class OPTDecoder(OPTPreTrainedModel):
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
...
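
As a standalone illustration of the PyTorch change, here is a minimal toy module (not the real `OPTDecoder`) showing how the conditional final layer norm is created in `__init__` and applied at the end of the forward pass:

```python
import torch
from torch import nn

class ToyDecoder(nn.Module):
    # Toy stand-in for OPTDecoder, reduced to the final-layer-norm logic above.
    def __init__(self, hidden_size, do_layer_norm_before=True, _remove_final_layer_norm=False):
        super().__init__()
        if do_layer_norm_before and not _remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(hidden_size)
        else:
            # Backward-compatible path for checkpoints fine-tuned without the norm.
            self.final_layer_norm = None

    def forward(self, hidden_states):
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states

hidden_states = torch.randn(1, 4, 8)
print(ToyDecoder(8)(hidden_states).shape)                                 # normalized output, shape (1, 4, 8)
print(ToyDecoder(8, _remove_final_layer_norm=True)(hidden_states).shape)  # unchanged input, same shape
```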
@@ -506,6 +506,14 @@ class TFOPTDecoder(tf.keras.layers.Layer):
             name="embed_positions",
         )

+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
+            self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+        else:
+            self.final_layer_norm = None
+
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
             self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
@@ -681,6 +689,9 @@ class TFOPTDecoder(tf.keras.layers.Layer):
             if output_attentions:
                 all_self_attns += (layer_self_attn,)

+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
...
@@ -292,10 +292,10 @@ class FlaxOPTGenerationTest(unittest.TestCase):
         model_id = "facebook/opt-125m"

         EXPECTED_OUTPUTS = [
-            "Today is a beautiful day and I want everyone",
-            "In the city of Rome Canaver Canaver Canaver Canaver",
-            "Paris is the capital of France and Parisdylib",
-            "Computers and mobile phones have taken precedence over",
+            "Today is a beautiful day and I want to",
+            "In the city of New York, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
         ]

         predicted_outputs = []
...
@@ -344,10 +344,10 @@ class OPTGenerationTest(unittest.TestCase):
         model_id = "facebook/opt-125m"

         EXPECTED_OUTPUTS = [
-            "Today is a beautiful day and I want everyone",
-            "In the city of Rome Canaver Canaver Canaver Canaver",
-            "Paris is the capital of France and Parisdylib",
-            "Computers and mobile phones have taken precedence over",
+            "Today is a beautiful day and I want to",
+            "In the city of New York, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
         ]

         predicted_outputs = []
...
@@ -330,10 +330,10 @@ class TFOPTGenerationTest(unittest.TestCase):
         model_id = "facebook/opt-125m"

         EXPECTED_OUTPUTS = [
-            "Today is a beautiful day and I want everyone",
-            "In the city of Rome Canaver Canaver Canaver Canaver",
-            "Paris is the capital of France and Parisdylib",
-            "Computers and mobile phones have taken precedence over",
+            "Today is a beautiful day and I want to",
+            "In the city of New York, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
        ]

         predicted_outputs = []
...
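
The updated expected strings above come from greedy generation with `facebook/opt-125m` once the final layer norm is applied; a hedged sketch of how such outputs can be reproduced is below (the prompts and generation arguments are assumptions for illustration, not copied from the tests):

```python
from transformers import AutoTokenizer, OPTForCausalLM

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OPTForCausalLM.from_pretrained(model_id)

# Illustrative prompts; the real tests define their own prompt list and lengths.
prompts = [
    "Today is a beautiful day and I want",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]
for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(inputs.input_ids, max_length=10, do_sample=False)  # greedy decoding
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```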