Unverified commit 6f257bb3 authored by Matt, committed by GitHub

Update esmfold conversion script (#20028)

* Update ESM conversion script for ESMfold

* Fix bug in ESMFold example

* make fixup and move restypes to one line
parent 2564f0c2
@@ -23,6 +23,7 @@ from tempfile import TemporaryDirectory
 import torch
 import esm as esm_module
+from esm.esmfold.v1.misc import batch_encode_sequences as esmfold_encode_sequences
 from esm.esmfold.v1.pretrained import esmfold_v1
 from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig
 from transformers.models.esm.modeling_esm import (
@@ -43,7 +44,10 @@ logging.set_verbosity_info()
 logger = logging.get_logger(__name__)
 
 SAMPLE_DATA = [
-    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
+    (
+        "protein1",
+        "MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA",
+    ),
     ("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"),
     ("protein3", "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG"),
     ("protein4", "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLA"),
@@ -65,6 +69,21 @@ MODEL_MAPPING = {
     "esmfold_v1": esmfold_v1,
 }
 
+restypes = list("ARNDCQEGHILKMFPSTWYV")
+restypes_with_x = restypes + ["X"]
+restypes_with_extras = restypes_with_x + ["<pad>", "<mask>", "<cls>", "<sep>", "<eos>"]
+
+
+def get_esmfold_tokenizer():
+    with TemporaryDirectory() as tempdir:
+        vocab = "\n".join(restypes_with_extras)
+        vocab_file = Path(tempdir) / "vocab.txt"
+        vocab_file.write_text(vocab)
+        hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
+    hf_tokenizer.pad_token_id = 0  # Overlaps with 'A' but that seems to be what they want
+    return hf_tokenizer
+
 
 def transfer_and_check_weights(original_module, our_module):
     status = our_module.load_state_dict(original_module.state_dict())
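A note on the tokenizer helper added above: instead of dumping the full fair-esm alphabet, it writes a fixed 26-entry vocabulary (the 20 standard residues in the AlphaFold `ARNDCQEGHILKMFPSTWYV` order, then `X` and a few special tokens) and loads it with `EsmTokenizer`. The sketch below is an illustration of the resulting id assignment, not part of the diff; it assumes only the `transformers` package.

```python
# Illustration only: rebuild the vocabulary get_esmfold_tokenizer() writes and
# inspect the ids it assigns. "A" lands on id 0, which is why forcing
# pad_token_id = 0 afterwards overlaps with "A", as the comment in the diff notes.
from pathlib import Path
from tempfile import TemporaryDirectory

from transformers import EsmTokenizer

restypes_with_extras = list("ARNDCQEGHILKMFPSTWYV") + ["X", "<pad>", "<mask>", "<cls>", "<sep>", "<eos>"]

with TemporaryDirectory() as tempdir:
    vocab_file = Path(tempdir) / "vocab.txt"
    vocab_file.write_text("\n".join(restypes_with_extras))
    tokenizer = EsmTokenizer(vocab_file=str(vocab_file))

print(tokenizer.convert_tokens_to_ids(["A", "V", "X", "<pad>"]))  # [0, 19, 20, 21]
```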
@@ -82,7 +101,6 @@ def convert_esm_checkpoint_to_pytorch(
     """
     if model.startswith("esmfold"):
         esm = MODEL_MAPPING[model]()
-        alphabet = esm.esm.alphabet
     else:
         esm, alphabet = MODEL_MAPPING[model]()
     esm.eval()  # disable dropout
@@ -129,7 +147,11 @@ def convert_esm_checkpoint_to_pytorch(
         is_folding_model = False
         esmfold_config = None
 
+    if is_folding_model:
+        alphabet = esm.esm.alphabet
     vocab_list = tuple(alphabet.all_toks)
+    mask_token_id = alphabet.mask_idx
+    pad_token_id = alphabet.padding_idx
 
     if is_folding_model:
         original_esm_model = esm.esm
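For context (not part of the diff): `mask_token_id` and `pad_token_id` are read from the fair-esm `Alphabet`, which for folding checkpoints hangs off the wrapped language model as `esm.esm.alphabet`, hence the new `if is_folding_model:` fetch above. A hedged sketch of where those indices come from, assuming the `fair-esm` package and a small ESM-2 checkpoint chosen only for illustration:

```python
# Illustration only (downloads a small ESM-2 checkpoint): the Alphabet carries the
# special-token indices that the conversion copies into EsmConfig below.
import esm

_, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
print(alphabet.mask_idx, alphabet.padding_idx)  # become mask_token_id / pad_token_id
print(alphabet.all_toks[:5])                    # start of the vocab written for non-folding models
```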
@@ -138,7 +160,7 @@ def convert_esm_checkpoint_to_pytorch(
 
     config = EsmConfig(
         vocab_size=original_esm_model.embed_tokens.num_embeddings,
-        mask_token_id=alphabet.mask_idx,
+        mask_token_id=mask_token_id,
         hidden_size=embed_dim,
         num_hidden_layers=num_layers,
         num_attention_heads=num_attention_heads,
@@ -147,7 +169,7 @@ def convert_esm_checkpoint_to_pytorch(
         layer_norm_eps=1e-5,  # PyTorch default used in fairseq
         attention_probs_dropout_prob=0.0,
         hidden_dropout_prob=0.0,
-        pad_token_id=alphabet.padding_idx,
+        pad_token_id=pad_token_id,
         emb_layer_norm_before=emb_layer_norm_before,
         token_dropout=token_dropout,
         position_embedding_type=position_embedding_type,
@@ -239,6 +261,7 @@ def convert_esm_checkpoint_to_pytorch(
 
     if is_folding_model:
         model.esm_s_combine.data = esm.esm_s_combine.data
+        model.af2_to_esm.data = esm.af2_to_esm.data
         transfer_and_check_weights(esm.embedding, model.embedding)
         transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp)
         transfer_and_check_weights(esm.trunk, model.trunk)
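`af2_to_esm` is a small index-lookup buffer that now gets copied as well. The sketch below is a guess at what such a table typically encodes (a padding id followed by the ESM token id of each AlphaFold residue type); the helper name is hypothetical and the code is not taken from either repository.

```python
# Hedged sketch, not library code: one plausible construction of an AF2 -> ESM
# index lookup. Index 0 is reserved for padding; the remaining entries map each
# AlphaFold residue type (in "ARNDCQEGHILKMFPSTWYV" + "X" order) to its id in the
# fair-esm alphabet.
import torch


def build_af2_to_esm(alphabet, restypes_with_x):
    esm_ids = [alphabet.padding_idx] + [alphabet.get_idx(aa) for aa in restypes_with_x]
    return torch.tensor(esm_ids, dtype=torch.long)
```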
@@ -261,9 +284,6 @@ def convert_esm_checkpoint_to_pytorch(
         model.lm_head.decoder.weight = esm.lm_head.weight
         model.lm_head.bias = esm.lm_head.bias
 
-    # Let's check that we get the same results.
-    batch_converter = alphabet.get_batch_converter()
-
     # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
     if is_folding_model:
         # Folding models aren't trained on masked inputs and don't like mask tokens.
@@ -271,16 +291,29 @@ def convert_esm_checkpoint_to_pytorch(
     else:
         sample_data = SAMPLE_DATA
 
-    batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)
-    # Prepare tokenizer and make sure it matches
-    with TemporaryDirectory() as tempdir:
-        vocab = "\n".join(alphabet.all_toks)
-        vocab_file = Path(tempdir) / "vocab.txt"
-        vocab_file.write_text(vocab)
-        hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
-
-    hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
-    success = torch.all(hf_tokens["input_ids"] == batch_tokens)
+    if is_folding_model:
+        hf_tokenizer = get_esmfold_tokenizer()
+        hf_tokens = hf_tokenizer(
+            [row[1] for row in sample_data], return_tensors="pt", padding=True, add_special_tokens=False
+        )
+        esmfold_aas, esmfold_mask, _, _, _ = esmfold_encode_sequences([row[1] for row in sample_data])
+        success = torch.all(hf_tokens["input_ids"] == esmfold_aas) and torch.all(
+            hf_tokens["attention_mask"] == esmfold_mask
+        )
+    else:
+        # Let's check that we get the same results.
+        batch_converter = alphabet.get_batch_converter()
+        batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)
+        # Prepare tokenizer and make sure it matches
+        with TemporaryDirectory() as tempdir:
+            vocab = "\n".join(alphabet.all_toks)
+            vocab_file = Path(tempdir) / "vocab.txt"
+            vocab_file.write_text(vocab)
+            hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
+        hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
+        success = torch.all(hf_tokens["input_ids"] == batch_tokens)
 
     print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩")
     if not success:
         raise Exception("Tokenization does not match!")
@@ -292,10 +325,10 @@ def convert_esm_checkpoint_to_pytorch(
         # that don't exist on CPU. Therefore, to test it we need to run it on GPU. However,
         # ESMFold is what we in the community call a "big boy" and so we desperately avoid putting both the
         # original and the converted model on the GPU at the same time.
+        their_output = esm.cuda().infer([row[1] for row in sample_data])
         our_output = model.cuda()(
             input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
         )
-        their_output = esm.cuda()(hf_tokens["input_ids"].cuda(), hf_tokens["attention_mask"].cuda())
     else:
         our_output = model(**hf_tokens, output_hidden_states=True)
         our_output = our_output["logits"]
@@ -322,23 +355,7 @@ def convert_esm_checkpoint_to_pytorch(
 
     print(f"Saving model to {pytorch_dump_folder_path}")
     model.save_pretrained(pytorch_dump_folder_path)
 
-    reloaded = model_class.from_pretrained(pytorch_dump_folder_path).cuda()
-    reloaded_output = reloaded(
-        input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
-    )
-
-    if is_folding_model:
-        max_absolute_diff = torch.max(torch.abs(our_output["positions"] - reloaded_output["positions"])).item()
-        success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-6)
-    else:
-        max_absolute_diff = torch.max(torch.abs(our_output - reloaded_output["logits"])).item()
-        success = torch.allclose(our_output, reloaded_output["logits"], atol=1e-6)
-
-    print(f"max_absolute_diff = {max_absolute_diff}")
-    print("Does the model output the same tensors after reloading?", "🔥" if success else "💩")
-
-    if not success:
-        raise Exception("Something went wRoNg")
+    del esm  # Free up some memory before continuing
 
     print(f"Saving tokenizer to {pytorch_dump_folder_path}")
     hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
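The reload-and-recompare block was dropped in favour of `del esm` to keep peak memory down. If you want an after-the-fact sanity check on the exported folder, something along these lines can be run separately; the directory path is a placeholder and `EsmForProteinFolding` applies only to the `esmfold_v1` conversion:

```python
# Hedged sketch of an optional post-conversion check, outside the script itself.
import torch
from transformers import AutoTokenizer, EsmForProteinFolding

output_dir = "path/to/pytorch_dump_folder"  # placeholder for --pytorch_dump_folder_path
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = EsmForProteinFolding.from_pretrained(output_dir).eval()

inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.positions.shape)
```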
@@ -2100,7 +2100,7 @@ class EsmForProteinFolding(EsmPreTrainedModel):
         >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
         >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
 
-        >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt")  # A tiny random peptide
+        >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False)  # A tiny random peptide
         >>> outputs = model(**inputs)
         >>> folded_positions = outputs.positions
         ```
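The only change to the modeling file is this docstring example: `add_special_tokens=False` keeps the tokenizer from wrapping the peptide in special tokens, matching how the conversion script tokenizes folding inputs above. A quick way to see the difference (illustration only, downloads the published tokenizer):

```python
# Illustration only: the folding model wants one id per residue, so the example
# disables special tokens; with them enabled the encoding comes out longer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
seq = "MLKNVQVQLV"
print(len(tokenizer(seq, add_special_tokens=False)["input_ids"]))  # one id per residue
print(len(tokenizer(seq)["input_ids"]))                            # longer: special tokens added
```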