Unverified commit fd5eac5f authored by Matt, committed by GitHub

Small fixes for TF-ESM1b and ESM-1b weight conversions (#19683)

parent 90071fe4
@@ -149,7 +149,7 @@ def convert_esm_checkpoint_to_pytorch(
         self_attn.value.weight.data = esm_layer.self_attn.v_proj.weight
         self_attn.value.bias.data = esm_layer.self_attn.v_proj.bias
-        if hasattr(esm_layer.self_attn, "rot_emb"):
+        if getattr(esm_layer.self_attn, "rot_emb", None) is not None:
             # Matt: Although inv_freq is not a trainable weight, it is computed at model init and cached.
             # During the training of ESM-2 the model was converted to float16 precision, which also converts
             # the inv_freq tensor, and the loss of precision remains even if the model is loaded later as float32.
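Note on the `hasattr` to `getattr(..., None) is not None` change above: presumably the ESM attention layers can define `rot_emb` but leave it set to None when rotary embeddings are unused (as in ESM-1b), in which case `hasattr` is still True and the conversion would wrongly try to copy rotary-embedding state. A minimal sketch with toy classes (hypothetical, not the real fairseq ESM modules) of the difference between the two checks:

```python
# Toy stand-ins for an attention layer with and without rotary embeddings.
class AttnWithRotary:
    def __init__(self):
        self.rot_emb = object()  # stands in for a real rotary-embedding module

class AttnWithoutRotary:
    def __init__(self):
        self.rot_emb = None  # attribute exists, but there is no rotary state to copy

for attn in (AttnWithRotary(), AttnWithoutRotary()):
    old_check = hasattr(attn, "rot_emb")                    # True for both layers
    new_check = getattr(attn, "rot_emb", None) is not None  # True only when rot_emb is actually set
    print(type(attn).__name__, old_check, new_check)
```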
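On the `inv_freq` comment in the hunk: the point is that once a tensor has been rounded to float16, casting it back to float32 does not recover the original values. A quick sketch of that effect; the inverse-frequency formula here is the standard rotary-embedding one and an assumption for illustration, not copied from the ESM code:

```python
import torch

# Assumed (standard) rotary inverse-frequency formula; dim=64 is an arbitrary example size.
dim = 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))

# Round-trip through float16, as happened when ESM-2 was trained in half precision.
roundtrip = inv_freq.half().float()

print(torch.allclose(inv_freq, roundtrip))        # False: the float16 rounding is permanent
print((inv_freq - roundtrip).abs().max().item())  # small but non-zero difference
```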
@@ -136,7 +136,7 @@ class TFEsmEmbeddings(Layer):
         )
         if config.emb_layer_norm_before:
-            self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps)
+            self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
         else:
             self.layer_norm = None
         # Matt: I think this line was copied incorrectly from BERT, disabling for now
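On the added `name="layer_norm"`: Keras derives variable names from the layer's `name`, so without it the sublayer gets an auto-generated name such as `layer_normalization_3`, which would not line up with the `layer_norm` weight names expected when cross-loading the PyTorch ESM weights by name. This is my reading of the fix; the snippet below is a minimal sketch with toy shapes (names shown assume TF 2.x / Keras 2 behavior), not the real TFEsmEmbeddings:

```python
import tensorflow as tf

# Auto-named vs explicitly named LayerNormalization.
ln_auto = tf.keras.layers.LayerNormalization(epsilon=1e-12)
ln_named = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="layer_norm")

# Calling the layers builds their gamma/beta variables.
_ = ln_auto(tf.zeros((1, 8)))
_ = ln_named(tf.zeros((1, 8)))

print([w.name for w in ln_auto.weights])   # e.g. ['layer_normalization/gamma:0', ...] - depends on auto-numbering
print([w.name for w in ln_named.weights])  # ['layer_norm/gamma:0', 'layer_norm/beta:0'] - stable across runs
```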