Unverified Commit 6f257bb3 authored by Matt, committed by GitHub

Update esmfold conversion script (#20028)

* Update ESM conversion script for ESMfold

* Fix bug in ESMFold example

* make fixup and move restypes to one line
parent 2564f0c2
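The core of this change is a helper in the conversion script that builds an ESMFold-compatible EsmTokenizer from a bare residue vocabulary and then tokenizes without special tokens, so its output matches ESMFold's own `batch_encode_sequences`. Below is a minimal standalone sketch of that idea, mirroring the diff that follows; the function name `build_esmfold_tokenizer` and the sample sequence are illustrative, not part of the commit.

```python
from pathlib import Path
from tempfile import TemporaryDirectory

from transformers import EsmTokenizer

# Vocabulary written out by the script: 20 amino acids, "X", then the special tokens.
restypes = list("ARNDCQEGHILKMFPSTWYV")
restypes_with_extras = restypes + ["X", "<pad>", "<mask>", "<cls>", "<sep>", "<eos>"]


def build_esmfold_tokenizer() -> EsmTokenizer:  # illustrative name, not the script's helper
    with TemporaryDirectory() as tempdir:
        vocab_file = Path(tempdir) / "vocab.txt"
        vocab_file.write_text("\n".join(restypes_with_extras))
        tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
        # id 0 overlaps with 'A', but per the script this matches the original ESMFold setup.
        tokenizer.pad_token_id = 0
    return tokenizer


tokenizer = build_esmfold_tokenizer()
# ESMFold expects raw residue indices, so special tokens are disabled at encode time.
tokens = tokenizer(["MLKNVQVQLV"], return_tensors="pt", padding=True, add_special_tokens=False)
print(tokens["input_ids"])
```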
@@ -23,6 +23,7 @@ from tempfile import TemporaryDirectory
import torch
import esm as esm_module
from esm.esmfold.v1.misc import batch_encode_sequences as esmfold_encode_sequences
from esm.esmfold.v1.pretrained import esmfold_v1
from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig
from transformers.models.esm.modeling_esm import (
@@ -43,7 +44,10 @@ logging.set_verbosity_info()
logger = logging.get_logger(__name__)
SAMPLE_DATA = [
("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
(
"protein1",
"MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA",
),
("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"),
("protein3", "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG"),
("protein4", "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLA"),
@@ -65,6 +69,21 @@ MODEL_MAPPING = {
"esmfold_v1": esmfold_v1,
}
restypes = list("ARNDCQEGHILKMFPSTWYV")
restypes_with_x = restypes + ["X"]
restypes_with_extras = restypes_with_x + ["<pad>", "<mask>", "<cls>", "<sep>", "<eos>"]
def get_esmfold_tokenizer():
with TemporaryDirectory() as tempdir:
vocab = "\n".join(restypes_with_extras)
vocab_file = Path(tempdir) / "vocab.txt"
vocab_file.write_text(vocab)
hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
hf_tokenizer.pad_token_id = 0 # Overlaps with 'A' but that seems to be what they want
return hf_tokenizer
def transfer_and_check_weights(original_module, our_module):
status = our_module.load_state_dict(original_module.state_dict())
@@ -82,7 +101,6 @@ def convert_esm_checkpoint_to_pytorch(
"""
if model.startswith("esmfold"):
esm = MODEL_MAPPING[model]()
alphabet = esm.esm.alphabet
else:
esm, alphabet = MODEL_MAPPING[model]()
esm.eval() # disable dropout
@@ -129,7 +147,11 @@ def convert_esm_checkpoint_to_pytorch(
is_folding_model = False
esmfold_config = None
if is_folding_model:
alphabet = esm.esm.alphabet
vocab_list = tuple(alphabet.all_toks)
mask_token_id = alphabet.mask_idx
pad_token_id = alphabet.padding_idx
if is_folding_model:
original_esm_model = esm.esm
@@ -138,7 +160,7 @@ def convert_esm_checkpoint_to_pytorch(
config = EsmConfig(
vocab_size=original_esm_model.embed_tokens.num_embeddings,
mask_token_id=alphabet.mask_idx,
mask_token_id=mask_token_id,
hidden_size=embed_dim,
num_hidden_layers=num_layers,
num_attention_heads=num_attention_heads,
@@ -147,7 +169,7 @@ def convert_esm_checkpoint_to_pytorch(
layer_norm_eps=1e-5, # PyTorch default used in fairseq
attention_probs_dropout_prob=0.0,
hidden_dropout_prob=0.0,
pad_token_id=alphabet.padding_idx,
pad_token_id=pad_token_id,
emb_layer_norm_before=emb_layer_norm_before,
token_dropout=token_dropout,
position_embedding_type=position_embedding_type,
@@ -239,6 +261,7 @@ def convert_esm_checkpoint_to_pytorch(
if is_folding_model:
model.esm_s_combine.data = esm.esm_s_combine.data
model.af2_to_esm.data = esm.af2_to_esm.data
transfer_and_check_weights(esm.embedding, model.embedding)
transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp)
transfer_and_check_weights(esm.trunk, model.trunk)
@@ -261,9 +284,6 @@ def convert_esm_checkpoint_to_pytorch(
model.lm_head.decoder.weight = esm.lm_head.weight
model.lm_head.bias = esm.lm_head.bias
# Let's check that we get the same results.
batch_converter = alphabet.get_batch_converter()
# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
if is_folding_model:
# Folding models aren't trained on masked inputs and don't like mask tokens.
@@ -271,6 +291,18 @@ def convert_esm_checkpoint_to_pytorch(
else:
sample_data = SAMPLE_DATA
if is_folding_model:
hf_tokenizer = get_esmfold_tokenizer()
hf_tokens = hf_tokenizer(
[row[1] for row in sample_data], return_tensors="pt", padding=True, add_special_tokens=False
)
esmfold_aas, esmfold_mask, _, _, _ = esmfold_encode_sequences([row[1] for row in sample_data])
success = torch.all(hf_tokens["input_ids"] == esmfold_aas) and torch.all(
hf_tokens["attention_mask"] == esmfold_mask
)
else:
# Let's check that we get the same results.
batch_converter = alphabet.get_batch_converter()
batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)
# Prepare tokenizer and make sure it matches
with TemporaryDirectory() as tempdir:
@@ -281,6 +313,7 @@ def convert_esm_checkpoint_to_pytorch(
hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
success = torch.all(hf_tokens["input_ids"] == batch_tokens)
print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩")
if not success:
raise Exception("Tokenization does not match!")
@@ -292,10 +325,10 @@ def convert_esm_checkpoint_to_pytorch(
# that don't exist on CPU. Therefore, to test it we need to run it on GPU. However,
# ESMFold is what we in the community call a "big boy" and so we desperately avoid putting both the
# original and the converted model on the GPU at the same time.
their_output = esm.cuda().infer([row[1] for row in sample_data])
our_output = model.cuda()(
input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
)
their_output = esm.cuda()(hf_tokens["input_ids"].cuda(), hf_tokens["attention_mask"].cuda())
else:
our_output = model(**hf_tokens, output_hidden_states=True)
our_output = our_output["logits"]
@@ -322,23 +355,7 @@ def convert_esm_checkpoint_to_pytorch(
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
reloaded = model_class.from_pretrained(pytorch_dump_folder_path).cuda()
reloaded_output = reloaded(
input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
)
if is_folding_model:
max_absolute_diff = torch.max(torch.abs(our_output["positions"] - reloaded_output["positions"])).item()
success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-6)
else:
max_absolute_diff = torch.max(torch.abs(our_output - reloaded_output["logits"])).item()
success = torch.allclose(our_output, reloaded_output["logits"], atol=1e-6)
print(f"max_absolute_diff = {max_absolute_diff}")
print("Does the model output the same tensors after reloading?", "🔥" if success else "💩")
if not success:
raise Exception("Something went wRoNg")
del esm # Free up some memory before continuing
print(f"Saving tokenizer to {pytorch_dump_folder_path}")
hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
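After the checks pass, the script saves both the converted model and the tokenizer to `pytorch_dump_folder_path`. In the spirit of the reload check visible in the hunk above, here is a rough standalone sketch of loading the saved artifacts back and running them; the dump directory and the sequence are placeholders, and the model is run on CPU (slow for ESMFold, but it sidesteps the GPU-memory concerns noted in the script's comments).

```python
import torch
from transformers import EsmForProteinFolding, EsmTokenizer

dump_dir = "path/to/converted_esmfold"  # placeholder for pytorch_dump_folder_path

tokenizer = EsmTokenizer.from_pretrained(dump_dir)
model = EsmForProteinFolding.from_pretrained(dump_dir)
model.eval()

# As above, no special tokens: the folding model wants plain residue indices.
inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
    outputs = model(**inputs)

# Predicted atom positions for the folded structure, as in the docstring example below.
print(outputs.positions.shape)
```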
@@ -2100,7 +2100,7 @@ class EsmForProteinFolding(EsmPreTrainedModel):
>>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
>>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt") # A tiny random peptide
>>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False) # A tiny random peptide
>>> outputs = model(**inputs)
>>> folded_positions = outputs.positions
```
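The docstring fix is just the added `add_special_tokens=False`: the conversion script validates the tokenizer against ESMFold's `batch_encode_sequences`, which emits no special tokens, so the documented example has to do the same. A small sketch of the difference, assuming the published `facebook/esmfold_v1` tokenizer and the default EsmTokenizer behaviour of wrapping sequences in `<cls>`/`<eos>`:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
seq = "MLKNVQVQLV"

with_special = tokenizer([seq], return_tensors="pt")["input_ids"]
without_special = tokenizer([seq], return_tensors="pt", add_special_tokens=False)["input_ids"]

# The default call wraps the sequence in extra special-token ids that ESMFold's own
# encoding never produces; the fixed example feeds only the ten residue ids.
print(with_special.shape, without_special.shape)
```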