"tests/vscode:/vscode.git/clone" did not exist on "a72f1c9f5b907f96cbb7de3bbb02a1d431d34071"
Unverified Commit 7f9b7b3f authored by Matt's avatar Matt Committed by GitHub
Browse files

Add ESMFold (#19977)



* initial commit

* First draft that gets outputs without crashing!

* Add all the ported openfold dependencies

* testing

* Restructure config files for ESMFold

* Debugging to find output discrepancies

* Mainly style

* Make model runnable without extra deps

* Remove utils and merge them to the modeling file

* Use correct gelu and remove some debug prints

* More cleanup

* Update esm docs

* Update conversion script to support ESMFold properly

* Port some top-level changes from ESMFold repo

* Expand EsmFold docstrings

* Make attention_mask optional (default to all 1s)

* Add inference test for ESMFold

* Use config and not n kwargs

* Add modeling output class

* Remove einops

* Remove chunking in ESM FFN

* Update tests for ESMFold

* Quality

* REpo consistency

* Remove tree dependency from ESMFold

* make fixup

* Add an error in case my structure map function breaks later

* Remove needless code

* Stop auto-casting the LM to float16 so CPU tests pass

* Stop auto-casting the LM to float16 so CPU tests pass

* Final test updates

* Split test file

* Copyright and quality

* Unpin PyTorch to see built doc

* Fix config file to_dict() method

* Add some docstrings to the output

* Skip TF checkpoint tests for ESM until we reupload those

* make fixup

* More docstrings

* Unpin to get even with main

* Flag example to write
Co-authored-by: default avatarSylvain Gugger <Sylvain.gugger@gmail.com>
parent 4c9e0f02
...@@ -14,8 +14,8 @@ specific language governing permissions and limitations under the License. ...@@ -14,8 +14,8 @@ specific language governing permissions and limitations under the License.
## Overview ## Overview
This page provides code and pre-trained weights for Transformer protein language models from Meta AI's Fundamental This page provides code and pre-trained weights for Transformer protein language models from Meta AI's Fundamental
AI Research Team, providing the state-of-the-art ESM-2, and the previously released ESM-1b and ESM-1v. Transformer AI Research Team, providing the state-of-the-art ESMFold and ESM-2, and the previously released ESM-1b and ESM-1v.
protein language models were introduced in the paper [Biological structure and function emerge from scaling Transformer protein language models were introduced in the paper [Biological structure and function emerge from scaling
unsupervised learning to 250 million protein sequences](https://www.pnas.org/content/118/15/e2016239118) by unsupervised learning to 250 million protein sequences](https://www.pnas.org/content/118/15/e2016239118) by
Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott, Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott,
C. Lawrence Zitnick, Jerry Ma, and Rob Fergus. C. Lawrence Zitnick, Jerry Ma, and Rob Fergus.
...@@ -27,6 +27,13 @@ It was released with the paper [Language models of protein sequences at the scal ...@@ -27,6 +27,13 @@ It was released with the paper [Language models of protein sequences at the scal
structure prediction](https://doi.org/10.1101/2022.07.20.500902) by Zeming Lin, Halil Akin, Roshan Rao, Brian Hie, structure prediction](https://doi.org/10.1101/2022.07.20.500902) by Zeming Lin, Halil Akin, Roshan Rao, Brian Hie,
Zhongkai Zhu, Wenting Lu, Allan dos Santos Costa, Maryam Fazel-Zarandi, Tom Sercu, Sal Candido and Alexander Rives. Zhongkai Zhu, Wenting Lu, Allan dos Santos Costa, Maryam Fazel-Zarandi, Tom Sercu, Sal Candido and Alexander Rives.
Also introduced in this paper was ESMFold. It uses an ESM-2 stem with a head that can predict folded protein
structures with state-of-the-art accuracy. Unlike [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2),
it relies on the token embeddings from the large pre-trained protein language model stem and does not perform a multiple
sequence alignment (MSA) step at inference time, which means that ESMFold checkpoints are fully "standalone" -
they do not require a database of known protein sequences and structures with associated external query tools
to make predictions, and are much faster as a result.
The abstract from The abstract from
"Biological structure and function emerge from scaling unsupervised learning to 250 "Biological structure and function emerge from scaling unsupervised learning to 250
...@@ -63,17 +70,22 @@ order of magnitude faster than AlphaFold2, enabling exploration of the structura ...@@ -63,17 +70,22 @@ order of magnitude faster than AlphaFold2, enabling exploration of the structura
proteins in practical timescales.* proteins in practical timescales.*
Tips: Tips:
- ESM models are trained with a masked language modeling (MLM) objective. - ESM models are trained with a masked language modeling (MLM) objective.
The original code can be found [here](https://github.com/facebookresearch/esm) and was The original code can be found [here](https://github.com/facebookresearch/esm) and was
was developed by the Fundamental AI Research team at Meta AI. was developed by the Fundamental AI Research team at Meta AI.
This model was contributed to huggingface by [jasonliu](https://huggingface.co/jasonliu) ESM-1b, ESM-1v and ESM-2 were contributed to huggingface by [jasonliu](https://huggingface.co/jasonliu)
and [Matt](https://huggingface.co/Rocketknight1). and [Matt](https://huggingface.co/Rocketknight1).
ESMFold was contributed to huggingface by [Matt](https://huggingface.co/Rocketknight1) and
[Sylvain](https://huggingface.co/sgugger), with a big thank you to Nikita Smetanin, Roshan Rao and Tom Sercu for their
help throughout the process!
The HuggingFace port of ESMFold uses portions of the [openfold](https://github.com/aqlaboratory/openfold) library.
The `openfold` library is licensed under the Apache License 2.0.
## EsmConfig ## EsmConfig
[[autodoc]] EsmConfig [[autodoc]] EsmConfig
...@@ -108,6 +120,11 @@ and [Matt](https://huggingface.co/Rocketknight1). ...@@ -108,6 +120,11 @@ and [Matt](https://huggingface.co/Rocketknight1).
[[autodoc]] EsmForTokenClassification [[autodoc]] EsmForTokenClassification
- forward - forward
## EsmForProteinFolding
[[autodoc]] EsmForProteinFolding
- forward
## TFEsmModel ## TFEsmModel
[[autodoc]] TFEsmModel [[autodoc]] TFEsmModel
......
...@@ -1265,7 +1265,9 @@ else: ...@@ -1265,7 +1265,9 @@ else:
_import_structure["models.esm"].extend( _import_structure["models.esm"].extend(
[ [
"ESM_PRETRAINED_MODEL_ARCHIVE_LIST", "ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
"EsmFoldPreTrainedModel",
"EsmForMaskedLM", "EsmForMaskedLM",
"EsmForProteinFolding",
"EsmForSequenceClassification", "EsmForSequenceClassification",
"EsmForTokenClassification", "EsmForTokenClassification",
"EsmModel", "EsmModel",
...@@ -4144,7 +4146,9 @@ if TYPE_CHECKING: ...@@ -4144,7 +4146,9 @@ if TYPE_CHECKING:
) )
from .models.esm import ( from .models.esm import (
ESM_PRETRAINED_MODEL_ARCHIVE_LIST, ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
EsmFoldPreTrainedModel,
EsmForMaskedLM, EsmForMaskedLM,
EsmForProteinFolding,
EsmForSequenceClassification, EsmForSequenceClassification,
EsmForTokenClassification, EsmForTokenClassification,
EsmModel, EsmModel,
......
...@@ -39,6 +39,7 @@ else: ...@@ -39,6 +39,7 @@ else:
"EsmModel", "EsmModel",
"EsmPreTrainedModel", "EsmPreTrainedModel",
] ]
_import_structure["modeling_esmfold"] = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]
try: try:
if not is_tf_available(): if not is_tf_available():
...@@ -55,7 +56,6 @@ else: ...@@ -55,7 +56,6 @@ else:
"TFEsmPreTrainedModel", "TFEsmPreTrainedModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig
from .tokenization_esm import EsmTokenizer from .tokenization_esm import EsmTokenizer
...@@ -74,6 +74,7 @@ if TYPE_CHECKING: ...@@ -74,6 +74,7 @@ if TYPE_CHECKING:
EsmModel, EsmModel,
EsmPreTrainedModel, EsmPreTrainedModel,
) )
from .modeling_esmfold import EsmFoldPreTrainedModel, EsmForProteinFolding
try: try:
if not is_tf_available(): if not is_tf_available():
......
# coding=utf-8 # coding=utf-8
# Copyright 2021 Facebook and The HuggingFace Inc. team. All rights reserved. # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,12 +14,16 @@ ...@@ -14,12 +14,16 @@
# limitations under the License. # limitations under the License.
""" ESM model configuration""" """ ESM model configuration"""
from dataclasses import asdict, dataclass
from typing import Optional
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
# TODO Update this
ESM_PRETRAINED_CONFIG_ARCHIVE_MAP = { ESM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/esm-1b": "https://huggingface.co/facebook/esm-1b/resolve/main/config.json", "facebook/esm-1b": "https://huggingface.co/facebook/esm-1b/resolve/main/config.json",
# See all ESM models at https://huggingface.co/models?filter=esm # See all ESM models at https://huggingface.co/models?filter=esm
...@@ -118,9 +122,12 @@ class EsmConfig(PretrainedConfig): ...@@ -118,9 +122,12 @@ class EsmConfig(PretrainedConfig):
classifier_dropout=None, classifier_dropout=None,
emb_layer_norm_before=None, emb_layer_norm_before=None,
token_dropout=False, token_dropout=False,
is_folding_model=False,
esmfold_config=None,
vocab_list=None,
**kwargs **kwargs
): ):
super().__init__(pad_token_id=pad_token_id, **kwargs) super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -138,5 +145,225 @@ class EsmConfig(PretrainedConfig): ...@@ -138,5 +145,225 @@ class EsmConfig(PretrainedConfig):
self.classifier_dropout = classifier_dropout self.classifier_dropout = classifier_dropout
self.emb_layer_norm_before = emb_layer_norm_before self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout self.token_dropout = token_dropout
self.mask_token_id = mask_token_id self.is_folding_model = is_folding_model
self.pad_token_id = pad_token_id if is_folding_model:
if esmfold_config is None:
logger.info("No esmfold_config supplied for folding model, using default values.")
esmfold_config = EsmFoldConfig()
elif isinstance(esmfold_config, dict):
esmfold_config = EsmFoldConfig(**esmfold_config)
self.esmfold_config = esmfold_config
if vocab_list is None:
logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
self.vocab_list = get_default_vocab_list()
else:
self.vocab_list = vocab_list
else:
self.esmfold_config = None
self.vocab_list = None
if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = super().to_dict()
if isinstance(self.esmfold_config, EsmFoldConfig):
output["esmfold_config"] = self.esmfold_config.to_dict()
return output
@dataclass
class EsmFoldConfig:
esm_type: str = None
fp16_esm: bool = True
use_esm_attn_map: bool = False
esm_ablate_pairwise: bool = False
esm_ablate_sequence: bool = False
esm_input_dropout: float = 0
embed_aa: bool = True
bypass_lm: bool = False
lddt_head_hid_dim: int = 128
trunk: "TrunkConfig" = None
def __post_init__(self):
if self.trunk is None:
self.trunk = TrunkConfig()
elif isinstance(self.trunk, dict):
self.trunk = TrunkConfig(**self.trunk)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = asdict(self)
output["trunk"] = self.trunk.to_dict()
return output
@dataclass
class TrunkConfig:
num_blocks: int = 48
sequence_state_dim: int = 1024
pairwise_state_dim: int = 128
sequence_head_width: int = 32
pairwise_head_width: int = 32
position_bins: int = 32
dropout: float = 0
layer_drop: float = 0
cpu_grad_checkpoint: bool = False
max_recycles: int = 4
chunk_size: Optional[int] = 128
structure_module: "StructureModuleConfig" = None
def __post_init__(self):
if self.structure_module is None:
self.structure_module = StructureModuleConfig()
elif isinstance(self.structure_module, dict):
self.structure_module = StructureModuleConfig(**self.structure_module)
if self.max_recycles <= 0:
raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
if self.sequence_state_dim % self.sequence_state_dim != 0:
raise ValueError(
"`sequence_state_dim` should be a round multiple of `sequence_state_dim`, got"
f" {self.sequence_state_dim} and {self.sequence_state_dim}."
)
if self.pairwise_state_dim % self.pairwise_state_dim != 0:
raise ValueError(
"`pairwise_state_dim` should be a round multiple of `pairwise_state_dim`, got"
f" {self.pairwise_state_dim} and {self.pairwise_state_dim}."
)
sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
raise ValueError(
"`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
)
if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
raise ValueError(
"`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
)
if self.pairwise_state_dim % 2 != 0:
raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")
if self.dropout >= 0.4:
raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = asdict(self)
output["structure_module"] = self.structure_module.to_dict()
return output
@dataclass
class StructureModuleConfig:
"""
Args:
sequence_dim:
Single representation channel dimension
pairwise_dim:
Pair representation channel dimension
ipa_dim:
IPA hidden channel dimension
resnet_dim:
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
num_heads_ipa:
Number of IPA heads
num_qk_points:
Number of query/key points to generate during IPA
num_v_points:
Number of value points to generate during IPA
dropout_rate:
Dropout rate used throughout the layer
num_blocks:
Number of structure module blocks
num_transition_layers:
Number of layers in the single representation transition (Alg. 23 lines 8-9)
num_resnet_blocks:
Number of blocks in the angle resnet
num_angles:
Number of angles to generate in the angle resnet
trans_scale_factor:
Scale of single representation transition hidden dimension
epsilon:
Small number used in angle resnet normalization
inf:
Large number used for attention masking
"""
sequence_dim: int = 384
pairwise_dim: int = 128
ipa_dim: int = 16
resnet_dim: int = 128
num_heads_ipa: int = 12
num_qk_points: int = 4
num_v_points: int = 8
dropout_rate: float = 0.1
num_blocks: int = 8
num_transition_layers: int = 1
num_resnet_blocks: int = 2
num_angles: int = 7
trans_scale_factor: int = 10
epsilon: float = 1e-8
inf: float = 1e5
def to_dict(self):
return asdict(self)
def get_default_vocab_list():
return (
"<cls>",
"<pad>",
"<eos>",
"<unk>",
"L",
"A",
"G",
"V",
"S",
"E",
"R",
"T",
"I",
"D",
"P",
"K",
"Q",
"N",
"F",
"Y",
"M",
"H",
"W",
"C",
"X",
"B",
"U",
"Z",
"O",
".",
"-",
"<null_1>",
"<mask>",
)
...@@ -23,7 +23,8 @@ from tempfile import TemporaryDirectory ...@@ -23,7 +23,8 @@ from tempfile import TemporaryDirectory
import torch import torch
import esm as esm_module import esm as esm_module
from transformers.models.esm.configuration_esm import EsmConfig from esm.esmfold.v1.pretrained import esmfold_v1
from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig
from transformers.models.esm.modeling_esm import ( from transformers.models.esm.modeling_esm import (
EsmForMaskedLM, EsmForMaskedLM,
EsmForSequenceClassification, EsmForSequenceClassification,
...@@ -33,6 +34,7 @@ from transformers.models.esm.modeling_esm import ( ...@@ -33,6 +34,7 @@ from transformers.models.esm.modeling_esm import (
EsmSelfAttention, EsmSelfAttention,
EsmSelfOutput, EsmSelfOutput,
) )
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
from transformers.models.esm.tokenization_esm import EsmTokenizer from transformers.models.esm.tokenization_esm import EsmTokenizer
from transformers.utils import logging from transformers.utils import logging
...@@ -60,19 +62,51 @@ MODEL_MAPPING = { ...@@ -60,19 +62,51 @@ MODEL_MAPPING = {
"esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D, "esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D,
"esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D, "esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D,
"esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D, "esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D,
"esmfold_v1": esmfold_v1,
} }
def transfer_and_check_weights(original_module, our_module):
status = our_module.load_state_dict(original_module.state_dict())
if status.missing_keys:
raise ValueError(f"Missing keys: {status.missing_keys}")
if status.unexpected_keys:
raise ValueError(f"Unexpected keys: {status.unexpected_keys}")
def convert_esm_checkpoint_to_pytorch( def convert_esm_checkpoint_to_pytorch(
model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str
): ):
""" """
Copy/paste/tweak esm's weights to our BERT structure. Copy/paste/tweak esm's weights to our BERT structure.
""" """
if model.startswith("esmfold"):
esm = MODEL_MAPPING[model]()
alphabet = esm.esm.alphabet
else:
esm, alphabet = MODEL_MAPPING[model]() esm, alphabet = MODEL_MAPPING[model]()
esm.eval() # disable dropout esm.eval() # disable dropout
esm_sent_encoder = esm
if hasattr(esm, "args"): if model.startswith("esmfold"):
embed_dim = esm.esm.embed_dim
num_layers = esm.esm.num_layers
num_attention_heads = esm.esm.attention_heads
intermediate_size = 4 * embed_dim
token_dropout = esm.esm.token_dropout
emb_layer_norm_before = False # This code path does not exist in ESM-2
position_embedding_type = "rotary"
is_folding_model = True
esmfold_config = EsmFoldConfig()
for key, val in esm.cfg.items():
if hasattr(esmfold_config, key) and key != "trunk":
setattr(esmfold_config, key, val)
for key, val in esm.cfg.trunk.items():
if hasattr(esmfold_config.trunk, key) and key != "structure_module":
setattr(esmfold_config.trunk, key, val)
for key, val in esm.cfg.trunk.structure_module.items():
if hasattr(esmfold_config.trunk.structure_module, key):
setattr(esmfold_config.trunk.structure_module, key, val)
elif hasattr(esm, "args"):
# Indicates an ESM-1b or ESM-1v model # Indicates an ESM-1b or ESM-1v model
embed_dim = esm.args.embed_dim embed_dim = esm.args.embed_dim
num_layers = esm.args.layers num_layers = esm.args.layers
...@@ -81,6 +115,8 @@ def convert_esm_checkpoint_to_pytorch( ...@@ -81,6 +115,8 @@ def convert_esm_checkpoint_to_pytorch(
token_dropout = esm.args.token_dropout token_dropout = esm.args.token_dropout
emb_layer_norm_before = True if esm.emb_layer_norm_before else False emb_layer_norm_before = True if esm.emb_layer_norm_before else False
position_embedding_type = "absolute" position_embedding_type = "absolute"
is_folding_model = False
esmfold_config = None
else: else:
# Indicates an ESM-2 model # Indicates an ESM-2 model
embed_dim = esm.embed_dim embed_dim = esm.embed_dim
...@@ -90,9 +126,18 @@ def convert_esm_checkpoint_to_pytorch( ...@@ -90,9 +126,18 @@ def convert_esm_checkpoint_to_pytorch(
token_dropout = esm.token_dropout token_dropout = esm.token_dropout
emb_layer_norm_before = False # This code path does not exist in ESM-2 emb_layer_norm_before = False # This code path does not exist in ESM-2
position_embedding_type = "rotary" position_embedding_type = "rotary"
is_folding_model = False
esmfold_config = None
vocab_list = tuple(alphabet.all_toks)
if is_folding_model:
original_esm_model = esm.esm
else:
original_esm_model = esm
config = EsmConfig( config = EsmConfig(
vocab_size=esm_sent_encoder.embed_tokens.num_embeddings, vocab_size=original_esm_model.embed_tokens.num_embeddings,
mask_token_id=alphabet.mask_idx, mask_token_id=alphabet.mask_idx,
hidden_size=embed_dim, hidden_size=embed_dim,
num_hidden_layers=num_layers, num_hidden_layers=num_layers,
...@@ -102,36 +147,45 @@ def convert_esm_checkpoint_to_pytorch( ...@@ -102,36 +147,45 @@ def convert_esm_checkpoint_to_pytorch(
layer_norm_eps=1e-5, # PyTorch default used in fairseq layer_norm_eps=1e-5, # PyTorch default used in fairseq
attention_probs_dropout_prob=0.0, attention_probs_dropout_prob=0.0,
hidden_dropout_prob=0.0, hidden_dropout_prob=0.0,
pad_token_id=esm.padding_idx, pad_token_id=alphabet.padding_idx,
emb_layer_norm_before=emb_layer_norm_before, emb_layer_norm_before=emb_layer_norm_before,
token_dropout=token_dropout, token_dropout=token_dropout,
position_embedding_type=position_embedding_type, position_embedding_type=position_embedding_type,
is_folding_model=is_folding_model,
esmfold_config=esmfold_config,
vocab_list=vocab_list,
) )
if classification_head: if classification_head:
config.num_labels = esm.classification_heads["mnli"].out_proj.weight.shape[0] config.num_labels = esm.classification_heads["mnli"].out_proj.weight.shape[0]
print("Our BERT config:", config) print("Our ESM config:", config)
model = EsmForSequenceClassification(config) if classification_head else EsmForMaskedLM(config) if model.startswith("esmfold"):
model_class = EsmForProteinFolding
elif classification_head:
model_class = EsmForSequenceClassification
else:
model_class = EsmForMaskedLM
model = model_class(config)
model.eval() model.eval()
# Now let's copy all the weights. # Now let's copy all the weights.
# Embeddings # Embeddings
model.esm.embeddings.word_embeddings.weight = esm_sent_encoder.embed_tokens.weight model.esm.embeddings.word_embeddings.weight = original_esm_model.embed_tokens.weight
if position_embedding_type == "absolute": if position_embedding_type == "absolute":
model.esm.embeddings.position_embeddings.weight = esm_sent_encoder.embed_positions.weight model.esm.embeddings.position_embeddings.weight = original_esm_model.embed_positions.weight
if config.emb_layer_norm_before: if config.emb_layer_norm_before:
model.esm.embeddings.layer_norm.weight = esm_sent_encoder.emb_layer_norm_before.weight model.esm.embeddings.layer_norm.weight = original_esm_model.emb_layer_norm_before.weight
model.esm.embeddings.layer_norm.bias = esm_sent_encoder.emb_layer_norm_before.bias model.esm.embeddings.layer_norm.bias = original_esm_model.emb_layer_norm_before.bias
model.esm.encoder.emb_layer_norm_after.weight = esm_sent_encoder.emb_layer_norm_after.weight model.esm.encoder.emb_layer_norm_after.weight = original_esm_model.emb_layer_norm_after.weight
model.esm.encoder.emb_layer_norm_after.bias = esm_sent_encoder.emb_layer_norm_after.bias model.esm.encoder.emb_layer_norm_after.bias = original_esm_model.emb_layer_norm_after.bias
for i in range(config.num_hidden_layers): for i in range(config.num_hidden_layers):
# Encoder: start of layer # Encoder: start of layer
layer: EsmLayer = model.esm.encoder.layer[i] layer: EsmLayer = model.esm.encoder.layer[i]
# esm_layer: TransformerSentenceEncoderLayer = esm_sent_encoder.layers[i] # esm_layer: TransformerSentenceEncoderLayer = original_esm_model.layers[i]
esm_layer = esm_sent_encoder.layers[i] esm_layer = original_esm_model.layers[i]
# self attention # self attention
self_attn: EsmSelfAttention = layer.attention.self self_attn: EsmSelfAttention = layer.attention.self
...@@ -183,7 +237,17 @@ def convert_esm_checkpoint_to_pytorch( ...@@ -183,7 +237,17 @@ def convert_esm_checkpoint_to_pytorch(
bert_output.dense.bias = esm_layer.fc2.bias bert_output.dense.bias = esm_layer.fc2.bias
# end of layer # end of layer
if classification_head: if is_folding_model:
model.esm_s_combine.data = esm.esm_s_combine.data
transfer_and_check_weights(esm.embedding, model.embedding)
transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp)
transfer_and_check_weights(esm.trunk, model.trunk)
transfer_and_check_weights(esm.distogram_head, model.distogram_head)
transfer_and_check_weights(esm.ptm_head, model.ptm_head)
transfer_and_check_weights(esm.lm_head, model.lm_head)
transfer_and_check_weights(esm.lddt_head, model.lddt_head)
elif classification_head:
model.classifier.dense.weight = esm.esm.classification_heads["mnli"].dense.weight model.classifier.dense.weight = esm.esm.classification_heads["mnli"].dense.weight
model.classifier.dense.bias = esm.classification_heads["mnli"].dense.bias model.classifier.dense.bias = esm.classification_heads["mnli"].dense.bias
model.classifier.out_proj.weight = esm.classification_heads["mnli"].out_proj.weight model.classifier.out_proj.weight = esm.classification_heads["mnli"].out_proj.weight
...@@ -195,15 +259,19 @@ def convert_esm_checkpoint_to_pytorch( ...@@ -195,15 +259,19 @@ def convert_esm_checkpoint_to_pytorch(
model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight
model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias
model.lm_head.decoder.weight = esm.lm_head.weight model.lm_head.decoder.weight = esm.lm_head.weight
model.lm_head.decoder.bias = esm.lm_head.bias model.lm_head.bias = esm.lm_head.bias
# Let's check that we get the same results. # Let's check that we get the same results.
batch_converter = alphabet.get_batch_converter() batch_converter = alphabet.get_batch_converter()
# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4) # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
if is_folding_model:
# Folding models aren't trained on masked inputs and don't like mask tokens.
sample_data = SAMPLE_DATA[:2]
else:
sample_data = SAMPLE_DATA
batch_labels, batch_strs, batch_tokens = batch_converter(SAMPLE_DATA) batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)
# Prepare tokenizer and make sure it matches # Prepare tokenizer and make sure it matches
with TemporaryDirectory() as tempdir: with TemporaryDirectory() as tempdir:
vocab = "\n".join(alphabet.all_toks) vocab = "\n".join(alphabet.all_toks)
...@@ -211,24 +279,40 @@ def convert_esm_checkpoint_to_pytorch( ...@@ -211,24 +279,40 @@ def convert_esm_checkpoint_to_pytorch(
vocab_file.write_text(vocab) vocab_file.write_text(vocab)
hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file)) hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
hf_tokens = hf_tokenizer([row[1] for row in SAMPLE_DATA], return_tensors="pt", padding=True) hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
success = torch.all(hf_tokens["input_ids"] == batch_tokens) success = torch.all(hf_tokens["input_ids"] == batch_tokens)
print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩") print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩")
if not success: if not success:
raise Exception("Tokenization does not match!") raise Exception("Tokenization does not match!")
with torch.no_grad(): with torch.no_grad():
if is_folding_model:
# Let's test the model in parts
# ESMFold always converts the ESM stem to float16, which requires float16 ops
# that don't exist on CPU. Therefore, to test it we need to run it on GPU. However,
# ESMFold is what we in the community call a "big boy" and so we desperately avoid putting both the
# original and the converted model on the GPU at the same time.
our_output = model.cuda()(
input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
)
their_output = esm.cuda()(hf_tokens["input_ids"].cuda(), hf_tokens["attention_mask"].cuda())
else:
our_output = model(**hf_tokens, output_hidden_states=True) our_output = model(**hf_tokens, output_hidden_states=True)
our_output = our_output["logits"] our_output = our_output["logits"]
if classification_head: if classification_head:
their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens)) their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
else: else:
their_output = esm(batch_tokens, repr_layers=list(range(999))) their_output = esm(hf_tokens["input_ids"], repr_layers=list(range(999)))
their_output = their_output["logits"] their_output = their_output["logits"]
print(our_output.shape, their_output.shape)
if is_folding_model:
max_absolute_diff = torch.max(torch.abs(our_output["positions"] - their_output["positions"])).item()
success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-5)
else:
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item() max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
success = torch.allclose(our_output, their_output, atol=1e-5)
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5 print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5
success = torch.allclose(our_output, their_output, atol=3e-4)
print("Do both models output the same tensors?", "🔥" if success else "💩") print("Do both models output the same tensors?", "🔥" if success else "💩")
if not success: if not success:
...@@ -238,6 +322,24 @@ def convert_esm_checkpoint_to_pytorch( ...@@ -238,6 +322,24 @@ def convert_esm_checkpoint_to_pytorch(
print(f"Saving model to {pytorch_dump_folder_path}") print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path) model.save_pretrained(pytorch_dump_folder_path)
reloaded = model_class.from_pretrained(pytorch_dump_folder_path).cuda()
reloaded_output = reloaded(
input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
)
if is_folding_model:
max_absolute_diff = torch.max(torch.abs(our_output["positions"] - reloaded_output["positions"])).item()
success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-6)
else:
max_absolute_diff = torch.max(torch.abs(our_output - reloaded_output["logits"])).item()
success = torch.allclose(our_output, reloaded_output["logits"], atol=1e-6)
print(f"max_absolute_diff = {max_absolute_diff}")
print("Does the model output the same tensors after reloading?", "🔥" if success else "💩")
if not success:
raise Exception("Something went wRoNg")
print(f"Saving tokenizer to {pytorch_dump_folder_path}") print(f"Saving tokenizer to {pytorch_dump_folder_path}")
hf_tokenizer.save_pretrained(pytorch_dump_folder_path) hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" PyTorch ESM model.""" """ PyTorch ESM model."""
import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
...@@ -21,7 +22,6 @@ import torch.utils.checkpoint ...@@ -21,7 +22,6 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, gelu
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
...@@ -30,12 +30,7 @@ from ...modeling_outputs import ( ...@@ -30,12 +30,7 @@ from ...modeling_outputs import (
SequenceClassifierOutput, SequenceClassifierOutput,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_utils import ( from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...utils import logging from ...utils import logging
from .configuration_esm import EsmConfig from .configuration_esm import EsmConfig
...@@ -66,6 +61,13 @@ def apply_rotary_pos_emb(x, cos, sin): ...@@ -66,6 +61,13 @@ def apply_rotary_pos_emb(x, cos, sin):
return (x * cos) + (rotate_half(x) * sin) return (x * cos) + (rotate_half(x) * sin)
def gelu(x):
"""
This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class RotaryEmbedding(torch.nn.Module): class RotaryEmbedding(torch.nn.Module):
""" """
Rotary position embeddings based on those in Rotary position embeddings based on those in
...@@ -163,7 +165,9 @@ class EsmEmbeddings(nn.Module): ...@@ -163,7 +165,9 @@ class EsmEmbeddings(nn.Module):
mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
src_lengths = attention_mask.sum(-1) src_lengths = attention_mask.sum(-1)
mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
embeddings.dtype
)
if self.position_embedding_type == "absolute": if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
...@@ -172,7 +176,7 @@ class EsmEmbeddings(nn.Module): ...@@ -172,7 +176,7 @@ class EsmEmbeddings(nn.Module):
if self.layer_norm is not None: if self.layer_norm is not None:
embeddings = self.layer_norm(embeddings) embeddings = self.layer_norm(embeddings)
if attention_mask is not None: if attention_mask is not None:
embeddings = embeddings * attention_mask.unsqueeze(-1) embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
# Matt: I think this line was copied incorrectly from BERT, disabling it for now. # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
# embeddings = self.dropout(embeddings) # embeddings = self.dropout(embeddings)
return embeddings return embeddings
...@@ -398,19 +402,14 @@ class EsmAttention(nn.Module): ...@@ -398,19 +402,14 @@ class EsmAttention(nn.Module):
return outputs return outputs
# Copied from transformers.models.bert.modeling_bert.BertIntermediate
class EsmIntermediate(nn.Module): class EsmIntermediate(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = gelu(hidden_states)
return hidden_states return hidden_states
...@@ -497,15 +496,13 @@ class EsmLayer(nn.Module): ...@@ -497,15 +496,13 @@ class EsmLayer(nn.Module):
cross_attn_present_key_value = cross_attention_outputs[-1] cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward( layer_output = self.feed_forward_chunk(attention_output)
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output # if decoder, return the attn key/values as the last output
if self.is_decoder: if self.is_decoder:
outputs = outputs + (present_key_value,) outputs = outputs + (present_key_value,)
return outputs return outputs
def feed_forward_chunk(self, attention_output): def feed_forward_chunk(self, attention_output):
......
# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch.nn import LayerNorm
from ...deepspeed import is_deepspeed_available
from ...modeling_outputs import ModelOutput
from ...utils import (
ContextManagers,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
logging,
replace_return_docstrings,
)
from .configuration_esm import EsmConfig
from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel
from .openfold_utils import (
OFProtein,
Rigid,
Rotation,
atom14_to_atom37,
chunk_layer,
compute_predicted_aligned_error,
compute_tm,
frames_and_literature_positions_to_atom14_pos,
make_atom14_masks,
residue_constants,
to_pdb,
torsion_angles_to_frames,
)
logger = logging.get_logger(__name__)
@dataclass
class EsmForProteinFoldingOutput(ModelOutput):
"""
Output type of [`EsmForProteinFoldingOutput`].
Args:
frames (`torch.FloatTensor`):
Output frames.
sidechain_frames (`torch.FloatTensor`):
Output sidechain frames.
unnormalized_angles (`torch.FloatTensor`):
Predicted unnormalized backbone and side chain torsion angles.
angles (`torch.FloatTensor`):
Predicted backbone and side chain torsion angles.
positions (`torch.FloatTensor`):
Predicted positions of the backbone and side chain atoms.
states (`torch.FloatTensor`):
Hidden states from the protein folding trunk.
s_s (`torch.FloatTensor`):
Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem.
s_z (`torch.FloatTensor`):
Pairwise residue embeddings.
distogram_logits (`torch.FloatTensor`):
Input logits to the distogram used to compute residue distances.
lm_logits (`torch.FloatTensor`):
Logits output by the ESM-2 protein language model stem.
aatype (`torch.FloatTensor`):
Input amino acids (AlphaFold2 indices).
atom14_atom_exists (`torch.FloatTensor`):
Whether each atom exists in the atom14 representation.
residx_atom14_to_atom37 (`torch.FloatTensor`):
Mapping between atoms in the atom14 and atom37 representations.
residx_atom37_to_atom14 (`torch.FloatTensor`):
Mapping between atoms in the atom37 and atom14 representations.
atom37_atom_exists (`torch.FloatTensor`):
Whether each atom exists in the atom37 representation.
residue_index (`torch.FloatTensor`):
The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be
a sequence of integers from 0 to `sequence_length`.
lddt_head (`torch.FloatTensor`):
Raw outputs from the lddt head used to compute plddt.
plddt (`torch.FloatTensor`):
Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is
uncertain, or where the protein structure is disordered.
ptm_logits (`torch.FloatTensor`):
Raw logits used for computing ptm.
ptm (`torch.FloatTensor`):
TM-score output representing the model's high-level confidence in the overall structure.
aligned_confidence_probs (`torch.FloatTensor`):
Per-residue confidence scores for the aligned structure.
predicted_aligned_error (`torch.FloatTensor`):
Predicted error between the model's prediction and the ground truth.
max_predicted_aligned_error (`torch.FloatTensor`):
Per-sample maximum predicted error.
"""
frames: torch.FloatTensor = None
sidechain_frames: torch.FloatTensor = None
unnormalized_angles: torch.FloatTensor = None
angles: torch.FloatTensor = None
positions: torch.FloatTensor = None
states: torch.FloatTensor = None
s_s: torch.FloatTensor = None
s_z: torch.FloatTensor = None
distogram_logits: torch.FloatTensor = None
lm_logits: torch.FloatTensor = None
aatype: torch.FloatTensor = None
atom14_atom_exists: torch.FloatTensor = None
residx_atom14_to_atom37: torch.FloatTensor = None
residx_atom37_to_atom14: torch.FloatTensor = None
atom37_atom_exists: torch.FloatTensor = None
residue_index: torch.FloatTensor = None
lddt_head: torch.FloatTensor = None
plddt: torch.FloatTensor = None
ptm_logits: torch.FloatTensor = None
ptm: torch.FloatTensor = None
aligned_confidence_probs: torch.FloatTensor = None
predicted_aligned_error: torch.FloatTensor = None
max_predicted_aligned_error: torch.FloatTensor = None
ESMFOLD_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`EsmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*):
Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
num_recycles (`int`, *optional*, defaults to `None`):
Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling"
consists of passing the output of the folding trunk back in as input to the trunk. During training, the
number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
after each recycle. During inference, num_recycles should be set to the highest value that the model was
trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
used.
"""
def is_fp16_enabled():
# Autocast world
fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
return fp16_enabled
def is_deepspeed_initialized():
if is_deepspeed_available():
return False
else:
try:
import deepspeed
# This is not available in all DeepSpeed versions.
return deepspeed.utils.is_initialized()
except Exception:
return False
def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
"""
Takes a list of tensors with the following dimensions:
[(d_11, ..., d_1K),
(d_21, ..., d_2K), ..., (d_N1, ..., d_NK)]
and stack + pads them into a single tensor of:
(N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
"""
if len(samples) == 0:
return torch.Tensor()
if len(set(x.dim() for x in samples)) != 1:
raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}")
(device,) = tuple(set(x.device for x in samples)) # assumes all on same device
max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device)
result.fill_(pad_v)
for i in range(len(samples)):
result_i = result[i]
t = samples[i]
result_i[tuple(slice(0, k) for k in t.shape)] = t
return result
def flatten_final_dims(t: torch.Tensor, no_dims: int):
return t.reshape(t.shape[:-no_dims] + (-1,))
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
def dict_multimap(fn, dicts):
first = dicts[0]
new_dict = {}
for k, v in first.items():
all_v = [d[k] for d in dicts]
if type(v) is dict:
new_dict[k] = dict_multimap(fn, all_v)
else:
new_dict[k] = fn(all_v)
return new_dict
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
shape = weights.shape
scale = scale / max(1, shape[1])
if not is_scipy_available():
logger.warning(
"This init requires scipy, but scipy was not found, default to an approximation that might not be"
" equivalent."
)
std = math.sqrt(scale)
torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std)
else:
from scipy.stats import truncnorm
std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel())
samples = np.reshape(samples, shape)
weights.copy_(torch.tensor(samples, device=weights.device))
def ipa_point_weights_init_(weights):
with torch.no_grad():
softplus_inverse_1 = 0.541324854612918
weights.fill_(softplus_inverse_1)
class EsmFoldLinear(nn.Linear):
"""
A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.
Implements the initializers in 1.11.4, plus some additional ones found in the code.
"""
def __init__(
self,
in_dim: int,
out_dim: int,
bias: bool = True,
init: str = "default",
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
):
"""
Args:
in_dim:
The final dimension of inputs to the layer
out_dim:
The final dimension of layer outputs
bias:
Whether to learn an additive bias. True by default
init:
The initializer to use. Choose from:
"default": LeCun fan-in truncated normal initialization "relu": He initialization w/ truncated normal
distribution "glorot": Fan-average Glorot uniform initialization "gating": Weights=0, Bias=1 "normal":
Normal initialization with std=1/sqrt(fan_in) "final": Weights=0, Bias=0
Overridden by init_fn if the latter is not None.
init_fn:
A custom initializer taking weight and bias as inputs. Overrides init if not None.
"""
super().__init__(in_dim, out_dim, bias=bias)
if bias:
with torch.no_grad():
self.bias.fill_(0)
self.init = init
self.init_fn = init_fn
if init not in ["default", "relu", "glorot", "gating", "normal", "final"]:
raise ValueError("Invalid init string.")
class EsmFoldLayerNorm(nn.Module):
def __init__(self, c_in, eps=1e-5):
super().__init__()
self.c_in = (c_in,)
self.eps = eps
self.weight = nn.Parameter(torch.ones(c_in))
self.bias = nn.Parameter(torch.zeros(c_in))
def forward(self, x):
d = x.dtype
if d is torch.bfloat16 and not is_deepspeed_initialized():
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps)
else:
out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)
return out
@torch.jit.ignore
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
Softmax, but without automatic casting to fp32 when the input is of type bfloat16
"""
d = t.dtype
if d is torch.bfloat16 and not is_deepspeed_initialized():
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
else:
s = torch.nn.functional.softmax(t, dim=dim)
return s
class EsmFoldAttention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors.
"""
def __init__(
self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
"""
super().__init__()
self.c_q = c_q
self.c_k = c_k
self.c_v = c_v
self.c_hidden = c_hidden
self.no_heads = no_heads
self.gating = gating
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
# stated in the supplement, but the overall channel dimension.
self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final")
self.linear_g = None
if self.gating:
self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating")
self.sigmoid = nn.Sigmoid()
def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(kv_x)
v = self.linear_v(kv_x)
# [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
# [*, H, Q/K, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
q /= math.sqrt(self.c_hidden)
return q, k, v
def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
if self.linear_g is not None:
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, Q, C_q]
o = self.linear_o(o)
return o
def forward(
self,
q_x: torch.Tensor,
kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
lma_q_chunk_size: int = 1024,
lma_kv_chunk_size: int = 4096,
use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
q_x:
[*, Q, C_q] query data
kv_x:
[*, K, C_k] key data
biases:
List of biases that broadcast to [*, H, Q, K]
use_memory_efficient_kernel:
Whether to use a custom memory-efficient attention kernel. This should be the default choice for most.
If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
use_lma:
Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
stock PyTorch implementation is used instead
lma_q_chunk_size:
Query chunk size (for LMA)
lma_kv_chunk_size:
Key/Value chunk size (for LMA)
Returns
[*, Q, C_q] attention update
"""
if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")
if use_flash and biases is not None:
raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")
attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
if sum(attn_options) > 1:
raise ValueError("Choose at most one alternative attention algorithm")
if biases is None:
biases = []
# [*, H, Q/K, C_hidden]
query, key, value = self._prep_qkv(q_x, kv_x)
key = permute_final_dims(key, (1, 0))
# [*, H, Q, K]
output = torch.matmul(query, key)
for b in biases:
output += b
output = softmax_no_cast(output, -1)
# [*, H, Q, C_hidden]
output = torch.matmul(output, value)
output = output.transpose(-2, -3)
output = self._wrap_up(output, q_x)
return output
class EsmFoldTriangleAttention(nn.Module):
def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
"""
Args:
c_in:
Input channel dimension
c_hidden:
Overall hidden channel dimension (not per-head)
no_heads:
Number of attention heads
"""
super().__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.starting = starting
self.inf = inf
self.layer_norm = LayerNorm(self.c_in)
self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")
self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
@torch.jit.ignore
def _chunk(
self,
x: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"triangle! triangle!"
mha_inputs = {
"q_x": x,
"kv_x": x,
"biases": biases,
}
return chunk_layer(
partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
_out=x if inplace_safe else None,
)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
x:
[*, I, J, C_in] input tensor (e.g. the pair representation)
Returns:
[*, I, J, C_in] output tensor
"""
if mask is None:
# [*, I, J]
mask = x.new_ones(
x.shape[:-1],
)
if not self.starting:
x = x.transpose(-2, -3)
mask = mask.transpose(-1, -2)
# [*, I, J, C_in]
x = self.layer_norm(x)
# [*, I, 1, 1, J]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# [*, H, I, J]
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
# [*, 1, H, I, J]
triangle_bias = triangle_bias.unsqueeze(-4)
biases = [mask_bias, triangle_bias]
if chunk_size is not None:
x = self._chunk(
x,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
else:
x = self.mha(
q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
)
if not self.starting:
x = x.transpose(-2, -3)
return x
class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, config, _outgoing=True):
super().__init__()
c_hidden = config.pairwise_state_dim
self._outgoing = _outgoing
self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")
self.layer_norm_in = LayerNorm(c_hidden)
self.layer_norm_out = LayerNorm(c_hidden)
self.sigmoid = nn.Sigmoid()
def _combine_projections(
self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None
) -> torch.Tensor:
if self._outgoing:
a = permute_final_dims(a, (2, 0, 1))
b = permute_final_dims(b, (2, 1, 0))
else:
a = permute_final_dims(a, (2, 1, 0))
b = permute_final_dims(b, (2, 0, 1))
if _inplace_chunk_size is not None:
# To be replaced by torch vmap
for i in range(0, a.shape[-3], _inplace_chunk_size):
a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
a_chunk,
b_chunk,
)
p = a
else:
p = torch.matmul(a, b)
return permute_final_dims(p, (1, 2, 0))
def _inference_forward(
self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_chunk_size: Optional[int] = None,
with_add: bool = True,
):
"""
Args:
z:
A [*, N, N, C_z] pair representation
mask:
A [*, N, N] pair mask
inplace_chunk_size:
Size of chunks used in the main computation. Increase to trade memory for speed.
with_add:
If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update).
Returns:
A reference to the overwritten z
More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the
addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten
values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size.
Useful for inference on extremely long sequences.
It works as follows. We will make reference to variables used in the default forward implementation below.
Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the
"square" input tensor, 2) a, the first projection of z, 3) b, the second projection of b, 4) g, a z-sized mask,
and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for
N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate
tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the
tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over
pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains
inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring
total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks
directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at
the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column
ahead of previously overwritten columns and can be recovered directly from z. After the first iteration,
however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache,
a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For
0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith
iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead.
Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the
z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache.
After the final iteration, z has been completely overwritten and contains the triangular multiplicative update.
If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case,
peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small
variables.
"""
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
def compute_projection_helper(pair, mask, a=True):
if a:
linear_g = self.linear_a_g
linear_p = self.linear_a_p
else:
linear_g = self.linear_b_g
linear_p = self.linear_b_p
pair = self.layer_norm_in(pair)
p = linear_g(pair)
p.sigmoid_()
p *= linear_p(pair)
p *= mask
p = permute_final_dims(p, (2, 0, 1))
return p
def compute_projection(pair, mask, a=True, chunked=True):
need_transpose = self._outgoing ^ a
if not chunked:
p = compute_projection_helper(pair, mask, a)
if need_transpose:
p = p.transpose(-1, -2)
else:
# This computation is chunked so as not to exceed our 2.5x
# budget with a large intermediate tensor
linear_g = self.linear_a_g if a else self.linear_b_g
c = linear_g.bias.shape[-1]
out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
p = pair.new_zeros(out_shape)
for i in range(0, pair.shape[-3], inplace_chunk_size):
pair_chunk = pair[..., i : i + inplace_chunk_size, :, :]
pair_chunk = compute_projection_helper(
pair[..., i : i + inplace_chunk_size, :, :],
mask[..., i : i + inplace_chunk_size, :, :],
a,
)
if need_transpose:
pair_chunk = pair_chunk.transpose(-1, -2)
p[..., i : i + inplace_chunk_size] = pair_chunk
else:
p[..., i : i + inplace_chunk_size, :] = pair_chunk
del pair_chunk
return p
# We start by fully manifesting a. In addition to the input, this
# brings total memory consumption to 2x z (disregarding size of chunks)
# [*, N, N, c]
a = compute_projection(z, mask, True, chunked=True)
if inplace_chunk_size is not None:
n = a.shape[-1]
half_n = n // 2 + n % 2
row_dim = -3
col_dim = -2
b_chunk_dim = row_dim if self._outgoing else col_dim
def empty_slicer(t):
return [slice(None) for _ in t.shape]
def slice_tensor(t, start, end, dim):
# Slices start:end from the dim dimension of t
s = empty_slicer(t)
s[dim] = slice(start, end)
return t[s]
def flip_z_cache_(z_cache, z):
# "Reorient" the z_cache (see below), filling it with quadrants
# 3---recovered from the z_cache---and 4---recovered from z---
# of the input tensor z.
quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim)
z_cache = z_cache.transpose(row_dim, col_dim)
# If n is odd, we need to shrink the z_cache by one row
z_cache = z_cache[..., : (n // 2), :, :]
# Move the 3rd quadrant of z into the
first_half_slicer = empty_slicer(z_cache)
first_half_slicer[col_dim] = slice(0, half_n)
z_cache[first_half_slicer] = quadrant_3
# Get the fourth quadrant of z
quadrant_4 = slice_tensor(z, half_n, None, row_dim)
quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim)
# Insert said quadrant into the rotated z-cache
quadrant_3_slicer = empty_slicer(z_cache)
quadrant_3_slicer[col_dim] = slice(half_n, None)
z_cache[quadrant_3_slicer] = quadrant_4
return z_cache
# Initialize the z cache to the left half of z.
z_cache_shape = list(z.shape)
z_cache_shape[col_dim] = half_n
z_cache = z.new_zeros(z_cache_shape)
z_cache_slicer = empty_slicer(z_cache)
z_cache_slicer[col_dim] = slice(0, half_n)
z_cache.copy_(z[z_cache_slicer])
z_cache_rotated = False
# We need to reorient the z-cache at the halfway point, and we
# don't want a single chunk to straddle that point. We contract one
# of the chunks in the middle to address that problem.
i_range = list(range(0, half_n, inplace_chunk_size))
initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])]
after_half = list(range(half_n, n, inplace_chunk_size))
after_half_offsets = [inplace_chunk_size for _ in after_half]
combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets)
for i, offset in combined_range_with_offsets:
if not z_cache_rotated and i >= half_n:
z_cache = flip_z_cache_(z_cache, z)
z_cache_rotated = True
z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim)
mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim)
z_chunk_b = z_chunk_b.clone()
if b_chunk_dim == col_dim:
z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
else: # b_chunk_dim == row_dim
# In this case, the b-dimension (b_chunk_dim) is partially
# overwritten at the end of each iteration. We need to
# restore the missing component from the z-cache.
if not z_cache_rotated:
z_chunk_slicer = empty_slicer(z_chunk_b)
z_chunk_slicer[col_dim] = slice(0, half_n)
z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim)
else:
z_cache_offset = i - half_n
z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim)
b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False)
del z_chunk_b
x_chunk = torch.matmul(a, b_chunk)
x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
x_chunk = self.layer_norm_out(x_chunk)
x_chunk = self.linear_z(x_chunk)
# The g dimension (col_dim) is parallel to and ahead of the
# overwrites in z. We can extract the g chunk normally.
z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
g_chunk.sigmoid_()
del z_chunk_g
x_chunk *= g_chunk
# Write the columns into z in-place
z_slicer = empty_slicer(z)
z_slicer[col_dim] = slice(i, i + offset)
if with_add:
z[z_slicer] += x_chunk
else:
z[z_slicer] = x_chunk
else:
b = compute_projection(z, mask, False, False)
x = torch.matmul(a, b)
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.linear_g(z)
g.sigmoid_()
x *= g
if with_add:
z += x
else:
z = x
return z
def forward(
self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256,
) -> torch.Tensor:
"""
Args:
x:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if inplace_safe:
x = self._inference_forward(
z,
mask,
inplace_chunk_size=_inplace_chunk_size,
with_add=_add_with_inplace,
)
return x
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
z = self.layer_norm_in(z)
a = mask
a = a * self.sigmoid(self.linear_a_g(z))
a = a * self.linear_a_p(z)
b = mask
b = b * self.sigmoid(self.linear_b_g(z))
b = b * self.linear_b_p(z)
if is_fp16_enabled():
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
x = self._combine_projections(a, b)
del a, b
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.sigmoid(self.linear_g(z))
x = x * g
return x
class EsmFoldPreTrainedModel(EsmPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
# Subclass `EsMPreTrainedModel` to deal with special init
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, EsmFoldLinear):
with torch.no_grad():
if module.init_fn is not None:
module.init_fn(module.weight, module.bias)
elif module.init == "default":
trunc_normal_init_(module.weight, scale=1.0)
elif module.init == "relu":
trunc_normal_init_(module.weight, scale=2.0)
elif module.init == "glorot":
nn.init.xavier_uniform_(module.weight, gain=1)
elif module.init == "gating":
module.weight.fill_(0.0)
if module.bias:
module.bias.fill_(1.0)
elif module.init == "normal":
torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
elif module.init == "final":
module.weight.fill_(0.0)
elif isinstance(module, EsmFoldInvariantPointAttention):
ipa_point_weights_init_(module.head_weights)
elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight)
torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias)
torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight)
torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias)
torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight)
torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias)
torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight)
torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias)
torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight)
torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias)
torch.nn.init.zeros_(module.pair_to_sequence.linear.weight)
torch.nn.init.zeros_(module.seq_attention.o_proj.weight)
torch.nn.init.zeros_(module.seq_attention.o_proj.bias)
torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight)
torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias)
torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight)
torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias)
else:
super()._init_weights(module)
class EsmFoldSelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, head_width, gated=False):
super().__init__()
assert embed_dim == num_heads * head_width
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_width = head_width
self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.gated = gated
if gated:
self.g_proj = nn.Linear(embed_dim, embed_dim)
torch.nn.init.zeros_(self.g_proj.weight)
torch.nn.init.ones_(self.g_proj.bias)
self.rescale_factor = self.head_width**-0.5
torch.nn.init.zeros_(self.o_proj.bias)
def forward(self, x, mask=None, bias=None, indices=None):
"""
Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths,
use mask.
Inputs:
x: batch of input sequneces (.. x L x C) mask: batch of boolean masks where 1=valid, 0=padding position (..
x L_k) bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)
Outputs:
sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
"""
t = self.proj(x).view(*x.shape[:2], self.num_heads, -1)
t = t.permute(0, 2, 1, 3)
q, k, v = t.chunk(3, dim=-1)
q = self.rescale_factor * q
a = torch.einsum("...qc,...kc->...qk", q, k)
# Add external attention bias.
if bias is not None:
a = a + bias.permute(0, 3, 1, 2)
# Do not attend to padding tokens.
if mask is not None:
mask = mask[:, None, None]
a = a.masked_fill(mask == False, -np.inf) # noqa: E712
a = nn.functional.softmax(a, dim=-1)
y = torch.einsum("...hqk,...hkc->...qhc", a, v)
y = y.reshape(*y.shape[:2], -1)
if self.gated:
y = self.g_proj(x).sigmoid() * y
y = self.o_proj(y)
return y, a.permute(0, 3, 1, 2)
class EsmFoldDropout(nn.Module):
"""
Implementation of dropout with the ability to share the dropout mask along a particular dimension.
"""
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
super().__init__()
self.r = r
if type(batch_dim) == int:
batch_dim = [batch_dim]
self.batch_dim = batch_dim
self.dropout = nn.Dropout(self.r)
def forward(self, x: torch.Tensor) -> torch.Tensor:
shape = list(x.shape)
if self.batch_dim is not None:
for bd in self.batch_dim:
shape[bd] = 1
return x * self.dropout(x.new_ones(shape))
class EsmFoldSequenceToPair(nn.Module):
def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
super().__init__()
self.layernorm = nn.LayerNorm(sequence_state_dim)
self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)
torch.nn.init.zeros_(self.proj.bias)
torch.nn.init.zeros_(self.o_proj.bias)
def forward(self, sequence_state):
"""
Inputs:
sequence_state: B x L x sequence_state_dim
Output:
pairwise_state: B x L x L x pairwise_state_dim
Intermediate state:
B x L x L x 2*inner_dim
"""
assert len(sequence_state.shape) == 3
s = self.layernorm(sequence_state)
s = self.proj(s)
q, k = s.chunk(2, dim=-1)
prod = q[:, None, :, :] * k[:, :, None, :]
diff = q[:, None, :, :] - k[:, :, None, :]
x = torch.cat([prod, diff], dim=-1)
x = self.o_proj(x)
return x
class EsmFoldPairToSequence(nn.Module):
def __init__(self, pairwise_state_dim, num_heads):
super().__init__()
self.layernorm = nn.LayerNorm(pairwise_state_dim)
self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)
def forward(self, pairwise_state):
"""
Inputs:
pairwise_state: B x L x L x pairwise_state_dim
Output:
pairwise_bias: B x L x L x num_heads
"""
assert len(pairwise_state.shape) == 4
z = self.layernorm(pairwise_state)
pairwise_bias = self.linear(z)
return pairwise_bias
class EsmFoldResidueMLP(nn.Module):
def __init__(self, embed_dim, inner_dim, dropout=0):
super().__init__()
self.mlp = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, inner_dim),
nn.ReLU(),
nn.Linear(inner_dim, embed_dim),
nn.Dropout(dropout),
)
def forward(self, x):
return x + self.mlp(x)
class EsmFoldTriangularSelfAttentionBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
sequence_state_dim = config.sequence_state_dim
pairwise_state_dim = config.pairwise_state_dim
sequence_num_heads = sequence_state_dim // config.sequence_head_width
pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width
self.layernorm_1 = nn.LayerNorm(sequence_state_dim)
self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim)
self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads)
self.seq_attention = EsmFoldSelfAttention(
sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True
)
self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)
self.tri_att_start = EsmFoldTriangleAttention(
pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True
)
self.tri_att_end = EsmFoldTriangleAttention(
pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False
)
self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout)
self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout)
self.drop = nn.Dropout(config.dropout)
self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
self.col_drop = EsmFoldDropout(config.dropout * 2, 1)
def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
"""
Inputs:
sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim mask: B x L boolean
tensor of valid positions
Output:
sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim
"""
if len(sequence_state.shape) != 3:
raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.")
if len(pairwise_state.shape) != 4:
raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.")
if mask is not None and len(mask.shape) != 2:
raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
pairwise_state_dim = pairwise_state.shape[3]
if sequence_state_dim != self.config.sequence_state_dim:
raise ValueError(
"`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got"
f"{sequence_state_dim} != {self.config.sequence_state_dim}."
)
if pairwise_state_dim != self.config.pairwise_state_dim:
raise ValueError(
"`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got "
f"{pairwise_state_dim} != {self.config.pairwise_state_dim}."
)
if batch_dim != pairwise_state.shape[0]:
raise ValueError(
f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != "
f"{pairwise_state.shape[0]}."
)
if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]:
raise ValueError(
f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != "
f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}."
)
# Update sequence state
bias = self.pair_to_sequence(pairwise_state)
# Self attention with bias + mlp.
y = self.layernorm_1(sequence_state)
y, _ = self.seq_attention(y, mask=mask, bias=bias)
sequence_state = sequence_state + self.drop(y)
sequence_state = self.mlp_seq(sequence_state)
# Update pairwise state
pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)
# Axial attention with triangular bias.
tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask))
pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask))
pairwise_state = pairwise_state + self.row_drop(
self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
)
pairwise_state = pairwise_state + self.col_drop(
self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
)
# MLP over pairs.
pairwise_state = self.mlp_pair(pairwise_state)
return sequence_state, pairwise_state
class EsmCategoricalMixture:
def __init__(self, param, bins=50, start=0, end=1):
# All tensors are of shape ..., bins.
self.logits = param
bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
self.v_bins = (bins[:-1] + bins[1:]) / 2
def log_prob(self, true):
# Shapes are:
# self.probs: ... x bins
# true : ...
true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
nll = self.logits.log_softmax(-1)
return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)
def mean(self):
return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)
def categorical_lddt(logits, bins=50):
# Logits are ..., 37, bins.
return EsmCategoricalMixture(logits, bins=bins).mean()
def get_axial_mask(mask):
"""
Helper to convert B x L mask of valid positions to axial mask used in row column attentions.
Input:
mask: B x L tensor of booleans
Output:
mask: B x L x L tensor of booleans
"""
if mask is None:
return None
if len(mask.shape) != 2:
raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
batch_dim, seq_dim = mask.shape
m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
m = m.reshape(batch_dim * seq_dim, seq_dim)
return m
class EsmFoldRelativePosition(nn.Module):
def __init__(self, config):
super().__init__()
self.bins = config.position_bins
# Note an additional offset is used so that the 0th position
# is reserved for masked pairs.
self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim)
def forward(self, residue_index, mask=None):
"""
Input:
residue_index: B x L tensor of indices (dytpe=torch.long) mask: B x L tensor of booleans
Output:
pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
"""
if residue_index.dtype != torch.long:
raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
if mask is not None and residue_index.shape != mask.shape:
raise ValueError(
f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
)
diff = residue_index[:, None, :] - residue_index[:, :, None]
diff = diff.clamp(-self.bins, self.bins)
diff = diff + self.bins + 1 # Add 1 to adjust for padding index.
if mask is not None:
mask = mask[:, None, :] * mask[:, :, None]
diff[mask == False] = 0 # noqa: E712
output = self.embedding(diff)
return output
class EsmFoldAngleResnetBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu")
self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final")
self.relu = nn.ReLU()
def forward(self, a: torch.Tensor) -> torch.Tensor:
s_initial = a
a = self.relu(a)
a = self.linear_1(a)
a = self.relu(a)
a = self.linear_2(a)
return a + s_initial
class EsmFoldAngleResnet(nn.Module):
"""
Implements Algorithm 20, lines 11-14
"""
def __init__(self, config):
super().__init__()
self.config = config
self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
self.layers = nn.ModuleList()
for _ in range(config.num_resnet_blocks):
layer = EsmFoldAngleResnetBlock(config)
self.layers.append(layer)
self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)
self.relu = nn.ReLU()
def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
s:
[*, C_hidden] single embedding
s_initial:
[*, C_hidden] single embedding as of the start of the StructureModule
Returns:
[*, no_angles, 2] predicted angles
"""
# NOTE: The ReLU's applied to the inputs are absent from the supplement
# pseudocode but present in the source. For maximal compatibility with
# the pretrained weights, I'm going with the source.
# [*, C_hidden]
s_initial = self.relu(s_initial)
s_initial = self.linear_initial(s_initial)
s = self.relu(s)
s = self.linear_in(s)
s = s + s_initial
for l in self.layers:
s = l(s)
s = self.relu(s)
# [*, no_angles * 2]
s = self.linear_out(s)
# [*, no_angles, 2]
s = s.view(s.shape[:-1] + (-1, 2))
unnormalized_s = s
norm_denom = torch.sqrt(
torch.clamp(
torch.sum(s**2, dim=-1, keepdim=True),
min=self.config.epsilon,
)
)
s = s / norm_denom
return unnormalized_s, s
class EsmFoldInvariantPointAttention(nn.Module):
"""
Implements Algorithm 22.
"""
def __init__(self, config):
super().__init__()
self.config = config
c_s = config.sequence_dim
c_z = config.pairwise_dim
self.hidden_dim = config.ipa_dim
self.num_heads = config.num_heads_ipa
self.num_qk_points = config.num_qk_points
self.num_v_points = config.num_v_points
# These linear layers differ from their specifications in the
# supplement. There, they lack bias and use Glorot initialization.
# Here as in the official source, they have bias and use the default
# Lecun initialization.
hc = config.ipa_dim * config.num_heads_ipa
self.linear_q = EsmFoldLinear(c_s, hc)
self.linear_kv = EsmFoldLinear(c_s, 2 * hc)
hpq = config.num_heads_ipa * config.num_qk_points * 3
self.linear_q_points = EsmFoldLinear(c_s, hpq)
hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
self.linear_kv_points = EsmFoldLinear(c_s, hpkv)
self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)
self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa)))
concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")
self.softmax = nn.Softmax(dim=-1)
self.softplus = nn.Softplus()
def forward(
self,
s: torch.Tensor,
z: Optional[torch.Tensor],
r: Rigid,
mask: torch.Tensor,
_offload_inference: bool = False,
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
r:
[*, N_res] transformation object
mask:
[*, N_res] mask
Returns:
[*, N_res, C_s] single representation update
"""
z = [z]
#######################################
# Generate scalar and point activations
#######################################
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
kv = self.linear_kv(s)
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.num_heads, -1))
# [*, N_res, H, 2 * C_hidden]
kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))
# [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.hidden_dim, dim=-1)
# [*, N_res, H * P_q * 3]
q_pts = self.linear_q_points(s)
# This is kind of clunky, but it's how the original does it
# [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1)
q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3))
# [*, N_res, H * (P_q + P_v) * 3]
kv_pts = self.linear_kv_points(s)
# [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))
# [*, N_res, H, P_q/P_v, 3]
k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)
##########################
# Compute attention scores
##########################
# [*, N_res, N_res, H]
b = self.linear_b(z[0])
if _offload_inference:
assert sys.getrefcount(z[0]) == 2
z[0] = z[0].cpu()
# [*, H, N_res, N_res]
if is_fp16_enabled():
with torch.cuda.amp.autocast(enabled=False):
a = torch.matmul(
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
)
else:
a = torch.matmul(
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
a *= math.sqrt(1.0 / (3 * self.hidden_dim))
a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att**2
# [*, N_res, N_res, H, P_q]
pt_att = sum(torch.unbind(pt_att, dim=-1))
head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1)))
head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
pt_att = pt_att * head_weights
# [*, N_res, N_res, H]
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
# [*, N_res, N_res]
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = self.config.inf * (square_mask - 1)
# [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, (2, 0, 1))
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
################
# Compute output
################
# [*, N_res, H, C_hidden]
o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
dim=-2,
)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
if _offload_inference:
z[0] = z[0].to(o_pt.device)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
# [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2)
# [*, N_res, C_s]
s = self.linear_out(
torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
)
return s
class EsmFoldBackboneUpdate(nn.Module):
"""
Implements part of Algorithm 23.
"""
def __init__(self, config):
super().__init__()
self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")
def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
[*, N_res, C_s] single representation
Returns:
[*, N_res, 6] update vector
"""
# [*, 6]
update = self.linear(s)
return update
class EsmFoldStructureModuleTransitionLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final")
self.relu = nn.ReLU()
def forward(self, s):
s_initial = s
s = self.linear_1(s)
s = self.relu(s)
s = self.linear_2(s)
s = self.relu(s)
s = self.linear_3(s)
s = s + s_initial
return s
class EsmFoldStructureModuleTransition(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layers = nn.ModuleList()
for _ in range(config.num_transition_layers):
l = EsmFoldStructureModuleTransitionLayer(config)
self.layers.append(l)
self.dropout = nn.Dropout(config.dropout_rate)
self.layer_norm = LayerNorm(config.sequence_dim)
def forward(self, s):
for l in self.layers:
s = l(s)
s = self.dropout(s)
s = self.layer_norm(s)
return s
class EsmFoldStructureModule(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# Buffers to be lazily initialized later
# self.default_frames
# self.group_idx
# self.atom_mask
# self.lit_positions
self.layer_norm_s = LayerNorm(config.sequence_dim)
self.layer_norm_z = LayerNorm(config.pairwise_dim)
self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)
self.ipa = EsmFoldInvariantPointAttention(config)
self.ipa_dropout = nn.Dropout(config.dropout_rate)
self.layer_norm_ipa = LayerNorm(config.sequence_dim)
self.transition = EsmFoldStructureModuleTransition(config)
self.bb_update = EsmFoldBackboneUpdate(config)
self.angle_resnet = EsmFoldAngleResnet(config)
def forward(
self,
evoformer_output_dict,
aatype,
mask=None,
_offload_inference=False,
):
"""
Args:
evoformer_output_dict:
Dictionary containing:
"single":
[*, N_res, C_s] single representation
"pair":
[*, N_res, N_res, C_z] pair representation
aatype:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
Returns:
A dictionary of outputs
"""
s = evoformer_output_dict["single"]
if mask is None:
# [*, N]
mask = s.new_ones(s.shape[:-1])
# [*, N, C_s]
s = self.layer_norm_s(s)
# [*, N, N, C_z]
z = self.layer_norm_z(evoformer_output_dict["pair"])
z_reference_list = None
if _offload_inference:
assert sys.getrefcount(evoformer_output_dict["pair"]) == 2
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
z_reference_list = [z]
z = None
# [*, N, C_s]
s_initial = s
s = self.linear_in(s)
# [*, N]
rigids = Rigid.identity(
s.shape[:-1],
s.dtype,
s.device,
self.training,
fmt="quat",
)
outputs = []
for i in range(self.config.num_blocks):
# [*, N, C_s]
s = s + self.ipa(
s,
z,
rigids,
mask,
_offload_inference=_offload_inference,
_z_reference_list=z_reference_list,
)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
# [*, N]
rigids = rigids.compose_q_update_vec(self.bb_update(s))
# To hew as closely as possible to AlphaFold, we convert our
# quaternion-based transformations to rotation-matrix ones
# here
backb_to_global = Rigid(
Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None),
rigids.get_trans(),
)
backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor)
# [*, N, 7, 2]
unnormalized_angles, angles = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype)
pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype)
scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor)
preds = {
"frames": scaled_rigids.to_tensor_7(),
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_angles,
"angles": angles,
"positions": pred_xyz,
"states": s,
}
outputs.append(preds)
rigids = rigids.stop_rot_gradient()
del z, z_reference_list
if _offload_inference:
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
return outputs
def _init_residue_constants(self, float_dtype, device):
if not hasattr(self, "default_frames"):
self.register_buffer(
"default_frames",
torch.tensor(
residue_constants.restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
if not hasattr(self, "group_idx"):
self.register_buffer(
"group_idx",
torch.tensor(
residue_constants.restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
),
persistent=False,
)
if not hasattr(self, "atom_mask"):
self.register_buffer(
"atom_mask",
torch.tensor(
residue_constants.restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
if not hasattr(self, "lit_positions"):
self.register_buffer(
"lit_positions",
torch.tensor(
residue_constants.restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
def torsion_angles_to_frames(self, r, alpha, f):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying
return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(self, r, f): # [*, N, 8] # [*, N]
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
return frames_and_literature_positions_to_atom14_pos(
r,
f,
self.default_frames,
self.group_idx,
self.atom_mask,
self.lit_positions,
)
class EsmFoldingTrunk(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
c_s = config.sequence_state_dim
c_z = config.pairwise_state_dim
self.pairwise_positional_embedding = EsmFoldRelativePosition(config)
self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])
self.recycle_bins = 15
self.recycle_s_norm = nn.LayerNorm(c_s)
self.recycle_z_norm = nn.LayerNorm(c_z)
self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
self.recycle_disto.weight[0].detach().zero_()
self.structure_module = EsmFoldStructureModule(config.structure_module)
self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)
self.chunk_size = config.chunk_size
def set_chunk_size(self, chunk_size):
# This parameter means the axial attention will be computed
# in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
# It's equivalent to running a for loop over chunks of the dimension we're iterative over,
# where the chunk_size is the size of the chunks, so 128 would mean to parse 128-lengthed chunks.
self.chunk_size = chunk_size
def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
"""
Inputs:
seq_feats: B x L x C tensor of sequence features pair_feats: B x L x L x C tensor of pair features residx: B
x L long tensor giving the position in the sequence mask: B x L boolean tensor indicating valid residues
Output:
predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object
"""
device = seq_feats.device
s_s_0 = seq_feats
s_z_0 = pair_feats
if no_recycles is None:
no_recycles = self.config.max_recycles
else:
if no_recycles < 0:
raise ValueError("Number of recycles must not be negative.")
no_recycles += 1 # First 'recycle' is just the standard forward pass through the model.
def trunk_iter(s, z, residx, mask):
z = z + self.pairwise_positional_embedding(residx, mask=mask)
for block in self.blocks:
s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
return s, z
s_s = s_s_0
s_z = s_z_0
recycle_s = torch.zeros_like(s_s)
recycle_z = torch.zeros_like(s_z)
recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)
for recycle_idx in range(no_recycles):
with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
# === Recycling ===
recycle_s = self.recycle_s_norm(recycle_s.detach())
recycle_z = self.recycle_z_norm(recycle_z.detach())
recycle_z += self.recycle_disto(recycle_bins.detach())
s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
# === Structure module ===
structure = self.structure_module(
{"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
true_aa,
mask.float(),
)
recycle_s = s_s
recycle_z = s_z
# Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
recycle_bins = EsmFoldingTrunk.distogram(
structure["positions"][-1][:, :, :3],
3.375,
21.375,
self.recycle_bins,
)
structure["s_s"] = s_s
structure["s_z"] = s_z
return structure
@staticmethod
def distogram(coords, min_bin, max_bin, num_bins):
# Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
boundaries = torch.linspace(
min_bin,
max_bin,
num_bins - 1,
device=coords.device,
)
boundaries = boundaries**2
N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
# Infer CB coordinates.
b = CA - N
c = C - CA
a = b.cross(c, dim=-1)
CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L]
return bins
# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare
# the outputs for downstream use.
@add_start_docstrings(
"""
ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
protein(s).
""",
ESM_START_DOCSTRING,
)
class EsmForProteinFolding(EsmPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.distogram_bins = 64
self.esm = EsmModel(config, add_pooling_layer=False)
self.esm.requires_grad_(False)
if self.config.esmfold_config.fp16_esm:
self.esm.half()
self.esm_feats = self.config.hidden_size
self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads
self.esm_layers = self.config.num_hidden_layers
self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))
self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))
trunk_config = self.config.esmfold_config.trunk
c_s = trunk_config.sequence_state_dim
c_z = trunk_config.pairwise_state_dim
self.esm_s_mlp = nn.Sequential(
LayerNorm(self.esm_feats),
nn.Linear(self.esm_feats, c_s),
nn.ReLU(),
nn.Linear(c_s, c_s),
)
# 0 is padding, N is unknown residues, N + 1 is mask.
self.n_tokens_embed = residue_constants.restype_num + 3
self.pad_idx = 0
self.unk_idx = self.n_tokens_embed - 2
self.mask_idx = self.n_tokens_embed - 1
self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>")
self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>")
self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>")
self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>")
if self.config.esmfold_config.embed_aa:
self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)
self.trunk = EsmFoldingTrunk(trunk_config)
self.distogram_head = nn.Linear(c_z, self.distogram_bins)
self.ptm_head = nn.Linear(c_z, self.distogram_bins)
self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
self.lddt_bins = 50
structure_module_config = trunk_config.structure_module
self.lddt_head = nn.Sequential(
nn.LayerNorm(structure_module_config.sequence_dim),
nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim),
nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
)
@staticmethod
def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:
# Remember that t is shifted from residue_constants by 1 (0 is padding).
esm_reorder = [vocab_list.index("<pad>")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x]
return torch.tensor(esm_reorder)
@add_start_docstrings_to_model_forward(ESMFOLD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=EsmForProteinFoldingOutput, config_class=EsmConfig)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor = None,
position_ids: Optional[torch.Tensor] = None,
masking_pattern: Optional[torch.Tensor] = None,
num_recycles: Optional[int] = None,
):
r"""
Returns:
Example:
TODO Matt
"""
cfg = self.config.esmfold_config
aa = input_ids # B x L
B = aa.shape[0]
L = aa.shape[1]
device = input_ids.device
if attention_mask is None:
attention_mask = torch.ones_like(aa, device=device)
if position_ids is None:
position_ids = torch.arange(L, device=device).expand_as(input_ids)
# === ESM ===
esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)
if masking_pattern is not None:
masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
else:
masked_aa = aa
mlm_targets = None
# We get sequence and pair representations from whatever version of ESM /
# configuration we are using. The sequence representation esm_s is always
# present. The pair embedding esm_z may be present depending on the
# configuration of the model. If esm_z is not used by the model then it
# is returned as None here.
esm_s = self.compute_language_model_representations(esmaa)
# Convert esm_s and esm_z, if present, to the precision used by the trunk and
# the structure module. These tensors may be a lower precision if, for example,
# we're running the language model in fp16 precision.
esm_s = esm_s.to(self.esm_s_combine.dtype)
if cfg.esm_ablate_sequence:
esm_s = esm_s * 0
esm_s = esm_s.detach()
# === preprocessing ===
esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
s_s_0 = self.esm_s_mlp(esm_s)
s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)
if self.config.esmfold_config.embed_aa:
s_s_0 += self.embedding(masked_aa)
structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles)
# Documenting what we expect:
structure = {
k: v
for k, v in structure.items()
if k
in [
"s_z",
"s_s",
"frames",
"sidechain_frames",
"unnormalized_angles",
"angles",
"positions",
"states",
]
}
# Add BERT mask for the loss to use, if available.
if mlm_targets:
structure["mlm_targets"] = mlm_targets
disto_logits = self.distogram_head(structure["s_z"])
disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
structure["distogram_logits"] = disto_logits
lm_logits = self.lm_head(structure["s_s"])
structure["lm_logits"] = lm_logits
structure["aatype"] = aa
make_atom14_masks(structure)
# Of course, this doesn't respect the true mask because it doesn't know about it...
# We're not going to properly mask change of index tensors:
# "residx_atom14_to_atom37",
# "residx_atom37_to_atom14",
for k in [
"atom14_atom_exists",
"atom37_atom_exists",
]:
structure[k] *= attention_mask.unsqueeze(-1)
structure["residue_index"] = position_ids
lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins)
structure["lddt_head"] = lddt_head
plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
structure["plddt"] = plddt
ptm_logits = self.ptm_head(structure["s_z"])
structure["ptm_logits"] = ptm_logits
structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins))
return EsmForProteinFoldingOutput(**structure)
def af2_idx_to_esm_idx(self, aa, mask):
aa = (aa + 1).masked_fill(mask != 1, 0)
return self.af2_to_esm[aa]
def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
device = next(self.parameters()).device
B, L = esmaa.shape # B = batch size, L = sequence length.
if self.config.esmfold_config.bypass_lm:
esm_s = torch.zeros(B, L, self.esm_s_combine.size[0], -1, self.esm_feats, device=device)
return esm_s
bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx
bos = esmaa.new_full((B, 1), bosi)
eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx)
esmaa = torch.cat([bos, esmaa, eos], dim=1)
# Use the first padding index as eos during inference.
esmaa[range(B), (esmaa != 1).sum(1)] = eosi
# _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map)
# Because we do not support use_esm_attn_map in the HF port as it is not used in any public models,
# esm_z is always None
esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
esm_s = torch.stack(esm_hidden_states, dim=2)
esm_s = esm_s[:, 1:-1] # B, L, nLayers, C
return esm_s
def bert_mask(self, aa, esmaa, mask, pattern):
new_aa = aa.clone()
target = aa.clone()
new_esmaa = esmaa.clone()
new_aa[pattern == 1] = self.mask_idx
target[pattern != 1] = 0
new_esmaa[pattern == 1] = self.esm_dict_mask_idx
return new_aa, new_esmaa, target
@torch.no_grad()
def infer(
self,
seqs: Union[str, List[str]],
residx=None,
with_mask: Optional[torch.Tensor] = None,
):
if type(seqs) is str:
lst = [seqs]
else:
lst = seqs
# Returns the raw outputs of the model given an input sequence.
device = next(self.parameters()).device
aatype = collate_dense_tensors(
[
torch.from_numpy(
residue_constants.sequence_to_onehot(
sequence=seq,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True,
)
)
.to(device)
.argmax(dim=1)
for seq in lst
]
) # B=1 x L
mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
residx = (
torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) if residx is None else residx.to(device)
)
if residx.ndim == 1:
residx = residx.unsqueeze(0)
return self.forward(
aatype,
mask,
mask_aa=with_mask is not None,
masking_pattern=with_mask,
residx=residx,
)
@staticmethod
def output_to_pdb(output: Dict) -> List[str]:
"""Returns the pbd (file) string from the model given the model output."""
output = {k: v.to("cpu").numpy() for k, v in output.items()}
pdbs = []
final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
final_atom_mask = output["atom37_atom_exists"]
for i in range(output["aatype"].shape[0]):
aa = output["aatype"][i]
pred_pos = final_atom_positions[i]
mask = final_atom_mask[i]
resid = output["residue_index"][i] + 1
pred = OFProtein(
aatype=aa,
atom_positions=pred_pos,
atom_mask=mask,
residue_index=resid,
b_factors=output["plddt"][i],
)
pdbs.append(to_pdb(pred))
return pdbs
def infer_pdb(self, seqs, *args, **kwargs) -> str:
"""Returns the pdb (file) string from the model given an input sequence."""
assert type(seqs) is str
output = self.infer(seqs, *args, **kwargs)
return self.output_to_pdb(output)[0]
def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]:
"""Returns the pdb (file) string from the model given an input sequence."""
output = self.infer(seqs, *args, **kwargs)
return self.output_to_pdb(output)
# flake8: noqa
from .chunk_utils import chunk_layer
from .data_transforms import make_atom14_masks
from .feats import atom14_to_atom37, frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames
from .loss import compute_predicted_aligned_error, compute_tm
from .protein import Protein as OFProtein
from .protein import to_pdb
from .rigid_utils import Rigid, Rotation
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
import torch
from .tensor_utils import tensor_tree_map, tree_map
def _fetch_dims(tree):
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
@torch.jit.ignore
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: int,
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
"""
Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields
tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of
slices, and perhaps even the shortest possible (I'm pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
tally = l[reversed_idx]
if start_edges is None:
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if end_edges is None:
end_edges = [e == (d - 1) for e, d in zip(end, dims)]
reduce_edge_list(end_edges)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if len(start) == 0:
return [tuple()]
elif len(start) == 1:
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
# Dimensions common to start and end can be selected directly
for s, e in zip(start, end):
if s == e:
path.append(slice(s, s + 1))
else:
break
path = tuple(path)
divergence_idx = len(path)
# start == end, and we're done
if divergence_idx == len(dims):
return [tuple(path)]
def upper():
sdi = start[divergence_idx]
return [
path + (slice(sdi, sdi + 1),) + s
for s in _get_minimal_slice_set(
start[divergence_idx + 1 :],
[d - 1 for d in dims[divergence_idx + 1 :]],
dims[divergence_idx + 1 :],
start_edges=start_edges[divergence_idx + 1 :],
end_edges=[1 for _ in end_edges[divergence_idx + 1 :]],
)
]
def lower():
edi = end[divergence_idx]
return [
path + (slice(edi, edi + 1),) + s
for s in _get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1 :]],
end[divergence_idx + 1 :],
dims[divergence_idx + 1 :],
start_edges=[1 for _ in start_edges[divergence_idx + 1 :]],
end_edges=end_edges[divergence_idx + 1 :],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if start_edges[divergence_idx] and end_edges[divergence_idx]:
slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),))
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif start_edges[divergence_idx]:
slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),))
slices.extend(lower())
# Analogous to the previous case, but the top is ragged this time
elif end_edges[divergence_idx]:
slices.extend(upper())
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),))
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if middle_ground > 1:
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
slices.extend(lower())
return [tuple(s) for s in slices]
@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the initial reshape call, which can be memory-intensive in certain situations. The only
reshape operations in this function are performed on sub-tensors that scale with (flat_end - flat_start), the chunk
size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
# _get_minimal_slice_set is inclusive
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
# Get an ordered list of slices to perform
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
sliced_tensors = [t[s] for s in slices]
return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors])
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
low_mem: bool = False,
_out: Any = None,
_add_into_out: bool = False,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees," consisting only of (arbitrarily nested) lists, tuples,
and dicts with torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must be tensors and must share the same batch
dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch dimensions are specified, a "sub-batch" is defined
as a single indexing of all batch dimensions simultaneously (s.t. the number of sub-batches is the product
of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary in most cases, and is ever so slightly
slower than the default setting.
Returns:
The reassembled output of the layer on the inputs.
"""
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t):
if not low_mem:
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
prepped_outputs = None
if _out is not None:
prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out)
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
def _select_chunk(t):
return t[i : i + chunk_size] if t.shape[0] != 1 else t
i = 0
out = prepped_outputs
for _ in range(no_chunks):
# Chunk the input
if not low_mem:
select_chunk = _select_chunk
else:
select_chunk = partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims),
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
# Allocate space for the output
if out is None:
out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
def assign(d1, d2):
for k, v in d1.items():
if type(v) is dict:
assign(v, d2[k])
else:
if _add_into_out:
v[i : i + chunk_size] += d2[k]
else:
v[i : i + chunk_size] = d2[k]
assign(out, output_chunk)
elif out_type is tuple:
for x1, x2 in zip(out, output_chunk):
if _add_into_out:
x1[i : i + chunk_size] += x2
else:
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
if _add_into_out:
out[i : i + chunk_size] += output_chunk
else:
out[i : i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
i += chunk_size
out = tensor_tree_map(lambda t: t.view(orig_batch_dims + t.shape[1:]), out)
return out
class ChunkSizeTuner:
def __init__(
self,
# Heuristically, runtimes for most of the modules in the network
# plateau earlier than this on all GPUs I've run the model on.
max_chunk_size=512,
):
self.max_chunk_size = max_chunk_size
self.cached_chunk_size = None
self.cached_arg_data = None
def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
logging.info("Tuning chunk size...")
if min_chunk_size >= self.max_chunk_size:
return min_chunk_size
candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates
candidates[-1] += 4
def test_chunk_size(chunk_size):
try:
with torch.no_grad():
fn(*args, chunk_size=chunk_size)
return True
except RuntimeError:
return False
min_viable_chunk_size_index = 0
i = len(candidates) - 1
while i > min_viable_chunk_size_index:
viable = test_chunk_size(candidates[i])
if not viable:
i = (min_viable_chunk_size_index + i) // 2
else:
min_viable_chunk_size_index = i
i = (i + len(candidates) - 1) // 2
return candidates[min_viable_chunk_size_index]
def _compare_arg_caches(self, ac1, ac2):
consistent = True
for a1, a2 in zip(ac1, ac2):
assert type(ac1) == type(ac2)
if type(ac1) is list or type(ac1) is tuple:
consistent &= self._compare_arg_caches(a1, a2)
elif type(ac1) is dict:
a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])]
consistent &= self._compare_arg_caches(a1_items, a2_items)
else:
consistent &= a1 == a2
return consistent
def tune_chunk_size(
self,
representative_fn: Callable,
args: Tuple[Any],
min_chunk_size: int,
) -> int:
consistent = True
arg_data = tree_map(lambda a: a.shape if type(a) is torch.Tensor else a, args, object)
if self.cached_arg_data is not None:
# If args have changed shape/value, we need to re-tune
assert len(self.cached_arg_data) == len(arg_data)
consistent = self._compare_arg_caches(self.cached_arg_data, arg_data)
else:
# Otherwise, we can reuse the precomputed value
consistent = False
if not consistent:
self.cached_chunk_size = self._determine_favorable_chunk_size(
representative_fn,
args,
min_chunk_size,
)
self.cached_arg_data = arg_data
return self.cached_chunk_size
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from . import residue_constants as rc
from .tensor_utils import tensor_tree_map, tree_map
def make_atom14_masks(protein):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = []
restype_atom37_to_atom14 = []
restype_atom14_mask = []
for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
restype_atom14_to_atom37.append([(rc.atom_order[name] if name else 0) for name in atom_names])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append(
[(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types]
)
restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask,
dtype=torch.float32,
device=protein["aatype"].device,
)
protein_aatype = protein["aatype"].to(torch.long)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
residx_atom14_mask = restype_atom14_mask[protein_aatype]
protein["atom14_atom_exists"] = residx_atom14_mask
protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
# create the gather indices for mapping back
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
# create the corresponding mask
restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32, device=protein["aatype"].device)
for restype, restype_letter in enumerate(rc.restypes):
restype_name = rc.restype_1to3[restype_letter]
atom_names = rc.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[protein_aatype]
protein["atom37_atom_exists"] = residx_atom37_mask
return protein
def make_atom14_masks_np(batch):
batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray)
out = make_atom14_masks(batch)
out = tensor_tree_map(lambda t: np.array(t), out)
return out
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from . import residue_constants as rc
from .rigid_utils import Rigid, Rotation
from .tensor_utils import batched_gather
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = aatype == rc.restype_order["G"]
ca_idx = rc.atom_order["CA"]
cb_idx = rc.atom_order["CB"]
pseudo_beta = torch.where(
is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :],
)
if all_atom_masks is not None:
pseudo_beta_mask = torch.where(
is_gly,
all_atom_masks[..., ca_idx],
all_atom_masks[..., cb_idx],
)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
def atom14_to_atom37(atom14, batch):
atom37_data = batched_gather(
atom14,
batch["residx_atom37_to_atom14"],
dim=-2,
no_batch_dims=len(atom14.shape[:-2]),
)
atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
return atom37_data
def build_template_angle_feat(template_feats):
template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
torsion_angles_mask = template_feats["template_torsion_angles_mask"]
template_angle_feat = torch.cat(
[
nn.functional.one_hot(template_aatype, 22),
torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14),
torsion_angles_mask,
],
dim=-1,
)
return template_angle_feat
def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8):
template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
# Compute distogram (this seems to differ slightly from Alg. 5)
tpb = batch["template_pseudo_beta"]
dgram = torch.sum((tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True)
lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
to_concat = [dgram, template_mask_2d[..., None]]
aatype_one_hot = nn.functional.one_hot(
batch["template_aatype"],
rc.restype_num + 2,
)
n_res = batch["template_aatype"].shape[-1]
to_concat.append(aatype_one_hot[..., None, :, :].expand(*aatype_one_hot.shape[:-2], n_res, -1, -1))
to_concat.append(aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1))
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
rigids = Rigid.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :],
ca_xyz=batch["template_all_atom_positions"][..., ca, :],
c_xyz=batch["template_all_atom_positions"][..., c, :],
eps=eps,
)
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"]
template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = rigid_vec * inv_distance_scalar[..., None]
if not use_unit_vector:
unit_vector = unit_vector * 0.0
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None])
act = torch.cat(to_concat, dim=-1)
act = act * template_mask_2d[..., None]
return act
def build_extra_msa_feat(batch):
msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
msa_feat = [
msa_1hot,
batch["extra_has_deletion"].unsqueeze(-1),
batch["extra_deletion_value"].unsqueeze(-1),
]
return torch.cat(msa_feat, dim=-1)
def torsion_angles_to_frames(
r: Rigid,
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
):
# [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...]
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_r = r.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1
# [*, N, 8, 2]
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
# [
# [1, 0 , 0 ],
# [0, a_2,-a_1],
# [0, a_1, a_2]
# ]
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
chi4_frame_to_frame = all_frames[..., 7]
chi1_frame_to_bb = all_frames[..., 4]
chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = Rigid.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
chi3_frame_to_bb.unsqueeze(-1),
chi4_frame_to_bb.unsqueeze(-1),
],
dim=-1,
)
all_frames_to_global = r[..., None].compose(all_frames_to_bb)
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
r: Rigid,
aatype: torch.Tensor,
default_frames,
group_idx,
atom_mask,
lit_positions,
):
# [*, N, 14]
group_mask = group_idx[aatype, ...]
# [*, N, 14, 8]
group_mask = nn.functional.one_hot(
group_mask,
num_classes=default_frames.shape[-3],
)
# [*, N, 14, 8]
t_atoms_to_global = r[..., None, :] * group_mask
# [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
# [*, N, 14, 1]
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
# [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...]
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions = pred_positions * atom_mask
return pred_positions
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple
import torch
def _calculate_bin_centers(boundaries: torch.Tensor):
step = boundaries[1] - boundaries[0]
bin_centers = boundaries + step / 2
bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0)
return bin_centers
def _calculate_expected_aligned_error(
alignment_confidence_breaks: torch.Tensor,
aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
return (
torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
bin_centers[-1],
)
def compute_predicted_aligned_error(
logits: torch.Tensor,
max_bin: int = 31,
no_bins: int = 64,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes aligned confidence metrics from logits.
Args:
logits: [*, num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
max_bin: Maximum bin value
no_bins: Number of bins
Returns:
aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
aligned error probabilities over bins for each residue pair.
predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
error for each pair of residues.
max_predicted_aligned_error: [*] the maximum predicted error possible.
"""
boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)
aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
(predicted_aligned_error, max_predicted_aligned_error,) = _calculate_expected_aligned_error(
alignment_confidence_breaks=boundaries,
aligned_distance_error_probs=aligned_confidence_probs,
)
return {
"aligned_confidence_probs": aligned_confidence_probs,
"predicted_aligned_error": predicted_aligned_error,
"max_predicted_aligned_error": max_predicted_aligned_error,
}
def compute_tm(
logits: torch.Tensor,
residue_weights: Optional[torch.Tensor] = None,
max_bin: int = 31,
no_bins: int = 64,
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
if residue_weights is None:
residue_weights = logits.new_ones(logits.shape[-2])
boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)
bin_centers = _calculate_bin_centers(boundaries)
torch.sum(residue_weights)
n = logits.shape[-2]
clipped_n = max(n, 19)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
probs = torch.nn.functional.softmax(logits, dim=-1)
tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2))
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
normed_residue_mask = residue_weights / (eps + residue_weights.sum())
per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
weighted = per_alignment * residue_weights
argmax = (weighted == torch.max(weighted)).nonzero()[0]
return per_alignment[tuple(argmax)]
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import re
import string
from typing import Any, Mapping, Optional, Sequence
import numpy as np
from . import residue_constants
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
PICO_TO_ANGSTROM = 0.01
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
# Chain indices for multi-chain predictions
chain_index: Optional[np.ndarray] = None
# Optional remark about the protein. Included as a comment in output PDB
# files
remark: Optional[str] = None
# Templates used to generate this protein (prediction-only)
parents: Optional[Sequence[str]] = None
# Chain corresponding to each parent
parents_chain_index: Optional[Sequence[int]] = None
def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag_re = r"(\[[A-Z]+\]\n)"
tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
groups = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
atoms = ["N", "CA", "C"]
aatype = None
atom_positions = None
atom_mask = None
for g in groups:
if "[PRIMARY]" == g[0]:
seq = g[1][0].strip()
for i in range(len(seq)):
if seq[i] not in residue_constants.restypes:
seq[i] = "X"
aatype = np.array(
[residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq]
)
elif "[TERTIARY]" == g[0]:
tertiary = []
for axis in range(3):
tertiary.append(list(map(float, g[1][axis].split())))
tertiary_np = np.array(tertiary)
atom_positions = np.zeros((len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32)
for i, atom in enumerate(atoms):
atom_positions[:, residue_constants.atom_order[atom], :] = np.transpose(tertiary_np[:, i::3])
atom_positions *= PICO_TO_ANGSTROM
elif "[MASK]" == g[0]:
mask = np.array(list(map({"-": 0, "+": 1}.get, g[1][0].strip())))
atom_mask = np.zeros(
(
len(mask),
residue_constants.atom_type_num,
)
).astype(np.float32)
for i, atom in enumerate(atoms):
atom_mask[:, residue_constants.atom_order[atom]] = 1
atom_mask *= mask[..., None]
return Protein(
atom_positions=atom_positions,
atom_mask=atom_mask,
aatype=aatype,
residue_index=np.arange(len(aatype)),
b_factors=None,
)
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
pdb_headers = []
remark = prot.remark
if remark is not None:
pdb_headers.append(f"REMARK {remark}")
parents = prot.parents
parents_chain_index = prot.parents_chain_index
if parents_chain_index is not None:
parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]
if parents is None or len(parents) == 0:
parents = ["N/A"]
pdb_headers.append(f"PARENT {' '.join(parents)}")
return pdb_headers
def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
"""Add pdb headers to an existing PDB string. Useful during multi-chain
recycling
"""
out_pdb_lines = []
lines = pdb_str.split("\n")
remark = prot.remark
if remark is not None:
out_pdb_lines.append(f"REMARK {remark}")
parents_per_chain = None
if prot.parents is not None and len(prot.parents) > 0:
parents_per_chain = []
if prot.parents_chain_index is not None:
parent_dict = {}
for p, i in zip(prot.parents, prot.parents_chain_index):
parent_dict.setdefault(str(i), [])
parent_dict[str(i)].append(p)
max_idx = max([int(chain_idx) for chain_idx in parent_dict])
for i in range(max_idx + 1):
chain_parents = parent_dict.get(str(i), ["N/A"])
parents_per_chain.append(chain_parents)
else:
parents_per_chain.append(prot.parents)
else:
parents_per_chain = [["N/A"]]
def make_parent_line(p):
return f"PARENT {' '.join(p)}"
out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
chain_counter = 0
for i, l in enumerate(lines):
if "PARENT" not in l and "REMARK" not in l:
out_pdb_lines.append(l)
if "TER" in l and "END" not in lines[i + 1]:
chain_counter += 1
if not chain_counter >= len(parents_per_chain):
chain_parents = parents_per_chain[chain_counter]
else:
chain_parents = ["N/A"]
out_pdb_lines.append(make_parent_line(chain_parents))
return "\n".join(out_pdb_lines)
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ["X"]
def res_1to3(r):
return residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
chain_index = prot.chain_index
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
headers = get_pdb_headers(prot)
if len(headers) > 0:
pdb_lines.extend(headers)
n = aatype.shape[0]
atom_index = 1
prev_chain_index = 0
chain_tags = string.ascii_uppercase
# Add all atom sites.
for i in range(n):
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
if mask < 0.5:
continue
record_type = "ATOM"
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
alt_loc = ""
insertion_code = ""
occupancy = 1.00
element = atom_name[0] # Protein supports only C, N, O, S, this works.
charge = ""
chain_tag = "A"
if chain_index is not None:
chain_tag = chain_tags[chain_index[i]]
# PDB is a columnar format, every space matters here!
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_tag:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
f"{element:>2}{charge:>2}"
)
pdb_lines.append(atom_line)
atom_index += 1
should_terminate = i == n - 1
if chain_index is not None:
if i != n - 1 and chain_index[i + 1] != prev_chain_index:
should_terminate = True
prev_chain_index = chain_index[i + 1]
if should_terminate:
# Close the chain.
chain_end = "TER"
chain_termination_line = (
f"{chain_end:<6}{atom_index:>5} {res_1to3(aatype[i]):>3} {chain_tag:>1}{residue_index[i]:>4}"
)
pdb_lines.append(chain_termination_line)
atom_index += 1
if i != n - 1:
# "prev" is a misnomer here. This happens at the beginning of
# each new chain.
pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
pdb_lines.append("END")
pdb_lines.append("")
return "\n".join(pdb_lines)
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are reported in the PDB. This function
computes a mask according to heavy atoms that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
chain_index: Optional[np.ndarray] = None,
remark: Optional[str] = None,
parents: Optional[Sequence[str]] = None,
parents_chain_index: Optional[Sequence[int]] = None,
) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
chain_index: (Optional) Chain indices for multi-chain predictions
remark: (Optional) Remark about the prediction
parents: (Optional) List of template names
Returns:
A protein instance.
"""
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein(
aatype=features["aatype"],
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1,
b_factors=b_factors,
chain_index=chain_index,
remark=remark,
parents=parents,
parents_chain_index=parents_chain_index,
)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants used in AlphaFold."""
import collections
import copy
import functools
from importlib import resources
from typing import List, Mapping, Tuple
import numpy as np
# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
"ALA": [],
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
"ARG": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "NE"],
["CG", "CD", "NE", "CZ"],
],
"ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"CYS": [["N", "CA", "CB", "SG"]],
"GLN": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLU": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLY": [],
"HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
"ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
"LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"LYS": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "CE"],
["CG", "CD", "CE", "NZ"],
],
"MET": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "SD"],
["CB", "CG", "SD", "CE"],
],
"PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
"SER": [["N", "CA", "CB", "OG"]],
"THR": [["N", "CA", "CB", "OG1"]],
"TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"VAL": [["N", "CA", "CB", "CG1"]],
}
# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
chi_angles_mask = [
[0.0, 0.0, 0.0, 0.0], # ALA
[1.0, 1.0, 1.0, 1.0], # ARG
[1.0, 1.0, 0.0, 0.0], # ASN
[1.0, 1.0, 0.0, 0.0], # ASP
[1.0, 0.0, 0.0, 0.0], # CYS
[1.0, 1.0, 1.0, 0.0], # GLN
[1.0, 1.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[1.0, 1.0, 0.0, 0.0], # HIS
[1.0, 1.0, 0.0, 0.0], # ILE
[1.0, 1.0, 0.0, 0.0], # LEU
[1.0, 1.0, 1.0, 1.0], # LYS
[1.0, 1.0, 1.0, 0.0], # MET
[1.0, 1.0, 0.0, 0.0], # PHE
[1.0, 1.0, 0.0, 0.0], # PRO
[1.0, 0.0, 0.0, 0.0], # SER
[1.0, 0.0, 0.0, 0.0], # THR
[1.0, 1.0, 0.0, 0.0], # TRP
[1.0, 1.0, 0.0, 0.0], # TYR
[1.0, 0.0, 0.0, 0.0], # VAL
]
# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
chi_pi_periodic = [
[0.0, 0.0, 0.0, 0.0], # ALA
[0.0, 0.0, 0.0, 0.0], # ARG
[0.0, 0.0, 0.0, 0.0], # ASN
[0.0, 1.0, 0.0, 0.0], # ASP
[0.0, 0.0, 0.0, 0.0], # CYS
[0.0, 0.0, 0.0, 0.0], # GLN
[0.0, 0.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[0.0, 0.0, 0.0, 0.0], # HIS
[0.0, 0.0, 0.0, 0.0], # ILE
[0.0, 0.0, 0.0, 0.0], # LEU
[0.0, 0.0, 0.0, 0.0], # LYS
[0.0, 0.0, 0.0, 0.0], # MET
[0.0, 1.0, 0.0, 0.0], # PHE
[0.0, 0.0, 0.0, 0.0], # PRO
[0.0, 0.0, 0.0, 0.0], # SER
[0.0, 0.0, 0.0, 0.0], # THR
[0.0, 0.0, 0.0, 0.0], # TRP
[0.0, 1.0, 0.0, 0.0], # TYR
[0.0, 0.0, 0.0, 0.0], # VAL
[0.0, 0.0, 0.0, 0.0], # UNK
]
# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-definiting atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
"ALA": [
["N", 0, (-0.525, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.529, -0.774, -1.205)],
["O", 3, (0.627, 1.062, 0.000)],
],
"ARG": [
["N", 0, (-0.524, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.524, -0.778, -1.209)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.616, 1.390, -0.000)],
["CD", 5, (0.564, 1.414, 0.000)],
["NE", 6, (0.539, 1.357, -0.000)],
["NH1", 7, (0.206, 2.301, 0.000)],
["NH2", 7, (2.078, 0.978, -0.000)],
["CZ", 7, (0.758, 1.093, -0.000)],
],
"ASN": [
["N", 0, (-0.536, 1.357, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.531, -0.787, -1.200)],
["O", 3, (0.625, 1.062, 0.000)],
["CG", 4, (0.584, 1.399, 0.000)],
["ND2", 5, (0.593, -1.188, 0.001)],
["OD1", 5, (0.633, 1.059, 0.000)],
],
"ASP": [
["N", 0, (-0.525, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, 0.000, -0.000)],
["CB", 0, (-0.526, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.593, 1.398, -0.000)],
["OD1", 5, (0.610, 1.091, 0.000)],
["OD2", 5, (0.592, -1.101, -0.003)],
],
"CYS": [
["N", 0, (-0.522, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, 0.000)],
["CB", 0, (-0.519, -0.773, -1.212)],
["O", 3, (0.625, 1.062, -0.000)],
["SG", 4, (0.728, 1.653, 0.000)],
],
"GLN": [
["N", 0, (-0.526, 1.361, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.525, -0.779, -1.207)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.615, 1.393, 0.000)],
["CD", 5, (0.587, 1.399, -0.000)],
["NE2", 6, (0.593, -1.189, -0.001)],
["OE1", 6, (0.634, 1.060, 0.000)],
],
"GLU": [
["N", 0, (-0.528, 1.361, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.526, -0.781, -1.207)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.615, 1.392, 0.000)],
["CD", 5, (0.600, 1.397, 0.000)],
["OE1", 6, (0.607, 1.095, -0.000)],
["OE2", 6, (0.589, -1.104, -0.001)],
],
"GLY": [
["N", 0, (-0.572, 1.337, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.517, -0.000, -0.000)],
["O", 3, (0.626, 1.062, -0.000)],
],
"HIS": [
["N", 0, (-0.527, 1.360, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.525, -0.778, -1.208)],
["O", 3, (0.625, 1.063, 0.000)],
["CG", 4, (0.600, 1.370, -0.000)],
["CD2", 5, (0.889, -1.021, 0.003)],
["ND1", 5, (0.744, 1.160, -0.000)],
["CE1", 5, (2.030, 0.851, 0.002)],
["NE2", 5, (2.145, -0.466, 0.004)],
],
"ILE": [
["N", 0, (-0.493, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.536, -0.793, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.534, 1.437, -0.000)],
["CG2", 4, (0.540, -0.785, -1.199)],
["CD1", 5, (0.619, 1.391, 0.000)],
],
"LEU": [
["N", 0, (-0.520, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.522, -0.773, -1.214)],
["O", 3, (0.625, 1.063, -0.000)],
["CG", 4, (0.678, 1.371, 0.000)],
["CD1", 5, (0.530, 1.430, -0.000)],
["CD2", 5, (0.535, -0.774, 1.200)],
],
"LYS": [
["N", 0, (-0.526, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.524, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.619, 1.390, 0.000)],
["CD", 5, (0.559, 1.417, 0.000)],
["CE", 6, (0.560, 1.416, 0.000)],
["NZ", 7, (0.554, 1.387, 0.000)],
],
"MET": [
["N", 0, (-0.521, 1.364, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.210)],
["O", 3, (0.625, 1.062, -0.000)],
["CG", 4, (0.613, 1.391, -0.000)],
["SD", 5, (0.703, 1.695, 0.000)],
["CE", 6, (0.320, 1.786, -0.000)],
],
"PHE": [
["N", 0, (-0.518, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, -0.000)],
["CB", 0, (-0.525, -0.776, -1.212)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.607, 1.377, 0.000)],
["CD1", 5, (0.709, 1.195, -0.000)],
["CD2", 5, (0.706, -1.196, 0.000)],
["CE1", 5, (2.102, 1.198, -0.000)],
["CE2", 5, (2.098, -1.201, -0.000)],
["CZ", 5, (2.794, -0.003, -0.001)],
],
"PRO": [
["N", 0, (-0.566, 1.351, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, 0.000)],
["CB", 0, (-0.546, -0.611, -1.293)],
["O", 3, (0.621, 1.066, 0.000)],
["CG", 4, (0.382, 1.445, 0.0)],
# ['CD', 5, (0.427, 1.440, 0.0)],
["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
],
"SER": [
["N", 0, (-0.529, 1.360, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.518, -0.777, -1.211)],
["O", 3, (0.626, 1.062, -0.000)],
["OG", 4, (0.503, 1.325, 0.000)],
],
"THR": [
["N", 0, (-0.517, 1.364, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, -0.000)],
["CB", 0, (-0.516, -0.793, -1.215)],
["O", 3, (0.626, 1.062, 0.000)],
["CG2", 4, (0.550, -0.718, -1.228)],
["OG1", 4, (0.472, 1.353, 0.000)],
],
"TRP": [
["N", 0, (-0.521, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.212)],
["O", 3, (0.627, 1.062, 0.000)],
["CG", 4, (0.609, 1.370, -0.000)],
["CD1", 5, (0.824, 1.091, 0.000)],
["CD2", 5, (0.854, -1.148, -0.005)],
["CE2", 5, (2.186, -0.678, -0.007)],
["CE3", 5, (0.622, -2.530, -0.007)],
["NE1", 5, (2.140, 0.690, -0.004)],
["CH2", 5, (3.028, -2.890, -0.013)],
["CZ2", 5, (3.283, -1.543, -0.011)],
["CZ3", 5, (1.715, -3.389, -0.011)],
],
"TYR": [
["N", 0, (-0.522, 1.362, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, -0.000, -0.000)],
["CB", 0, (-0.522, -0.776, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG", 4, (0.607, 1.382, -0.000)],
["CD1", 5, (0.716, 1.195, -0.000)],
["CD2", 5, (0.713, -1.194, -0.001)],
["CE1", 5, (2.107, 1.200, -0.002)],
["CE2", 5, (2.104, -1.201, -0.003)],
["OH", 5, (4.168, -0.002, -0.005)],
["CZ", 5, (2.791, -0.001, -0.003)],
],
"VAL": [
["N", 0, (-0.494, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.533, -0.795, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.540, 1.429, -0.000)],
["CG2", 4, (0.533, -0.776, 1.203)],
],
}
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms = {
"ALA": ["C", "CA", "CB", "N", "O"],
"ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
"ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
"ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
"CYS": ["C", "CA", "CB", "N", "O", "SG"],
"GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
"GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
"GLY": ["C", "CA", "N", "O"],
"HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
"ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
"LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
"LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
"MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
"PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
"PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
"SER": ["C", "CA", "CB", "N", "O", "OG"],
"THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
"TRP": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
"N",
"NE1",
"O",
],
"TYR": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"N",
"O",
"OH",
],
"VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}
# Naming swaps for ambiguous atom names.
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
# 4 of the 20 amino acids.
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
# TODO: ^ interpret this
residue_atom_renaming_swaps = {
"ASP": {"OD1": "OD2"},
"GLU": {"OE1": "OE2"},
"PHE": {"CD1": "CD2", "CE1": "CE2"},
"TYR": {"CD1": "CD2", "CE1": "CE2"},
}
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
"C": 1.7,
"N": 1.55,
"O": 1.52,
"S": 1.8,
}
Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"])
BondAngle = collections.namedtuple(
"BondAngle",
["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
)
def map_structure_with_atom_order(in_list: List, first_call: bool = True):
# Maps strings in a nested list structure to their corresponding index in atom_order
if first_call:
in_list = copy.deepcopy(in_list)
for i in range(len(in_list)):
if isinstance(in_list[i], list):
in_list[i] = map_structure_with_atom_order(in_list[i], first_call=False)
elif isinstance(in_list[i], str):
in_list[i] = atom_order[in_list[i]]
else:
raise ValueError("Unexpected type when mapping nested lists!")
return in_list
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[
Mapping[str, List[Bond]],
Mapping[str, List[Bond]],
Mapping[str, List[BondAngle]],
]:
"""Load stereo_chemical_props.txt into a nice structure.
Load literature values for bond lengths and bond angles and translate bond angles into the length of the opposite
edge of the triangle ("residue_virtual_bonds").
Returns:
residue_bonds: dict that maps resname --> list of Bond tuples residue_virtual_bonds: dict that maps resname -->
list of Bond tuples residue_bond_angles: dict that maps resname --> list of BondAngle tuples
"""
# TODO: this file should be downloaded in a setup script
stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt")
lines_iter = iter(stereo_chemical_props.splitlines())
# Load bond lengths.
residue_bonds = {}
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, length, stddev = line.split()
atom1, atom2 = bond.split("-")
if resname not in residue_bonds:
residue_bonds[resname] = []
residue_bonds[resname].append(Bond(atom1, atom2, float(length), float(stddev)))
residue_bonds["UNK"] = []
# Load bond angles.
residue_bond_angles = {}
next(lines_iter) # Skip empty line.
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, angle_degree, stddev_degree = line.split()
atom1, atom2, atom3 = bond.split("-")
if resname not in residue_bond_angles:
residue_bond_angles[resname] = []
residue_bond_angles[resname].append(
BondAngle(
atom1,
atom2,
atom3,
float(angle_degree) / 180.0 * np.pi,
float(stddev_degree) / 180.0 * np.pi,
)
)
residue_bond_angles["UNK"] = []
def make_bond_key(atom1_name, atom2_name):
"""Unique key to lookup bonds."""
return "-".join(sorted([atom1_name, atom2_name]))
# Translate bond angles into distances ("virtual bonds").
residue_virtual_bonds = {}
for resname, bond_angles in residue_bond_angles.items():
# Create a fast lookup dict for bond lengths.
bond_cache = {}
for b in residue_bonds[resname]:
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
residue_virtual_bonds[resname] = []
for ba in bond_angles:
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
# Compute distance between atom1 and atom3 using the law of cosines
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
gamma = ba.angle_rad
length = np.sqrt(bond1.length**2 + bond2.length**2 - 2 * bond1.length * bond2.length * np.cos(gamma))
# Propagation of uncertainty assuming uncorrelated errors.
dl_outer = 0.5 / length
dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
stddev = np.sqrt(
(dl_dgamma * ba.stddev) ** 2 + (dl_db1 * bond1.stddev) ** 2 + (dl_db2 * bond2.stddev) ** 2
)
residue_virtual_bonds[resname].append(Bond(ba.atom1_name, ba.atom3name, length, stddev))
return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
# Between-residue bond lengths for general bonds (first element) and for Proline
# (second element).
between_res_bond_length_c_n = [1.329, 1.341]
between_res_bond_length_stddev_c_n = [0.014, 0.016]
# Between-residue cos_angles.
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
"N",
"CA",
"C",
"CB",
"O",
"CG",
"CG1",
"CG2",
"OG",
"OG1",
"SG",
"CD",
"CD1",
"CD2",
"ND1",
"ND2",
"OD1",
"OD2",
"SD",
"CE",
"CE1",
"CE2",
"CE3",
"NE",
"NE1",
"NE2",
"OE1",
"OE2",
"CH2",
"NH1",
"NH2",
"OH",
"CZ",
"CZ2",
"CZ3",
"NZ",
"OXT",
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types) # := 37.
# A compact atom encoding with 14 columns
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
"ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
"ARG": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"NE",
"CZ",
"NH1",
"NH2",
"",
"",
"",
],
"ASN": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"OD1",
"ND2",
"",
"",
"",
"",
"",
"",
],
"ASP": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"OD1",
"OD2",
"",
"",
"",
"",
"",
"",
],
"CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
"GLN": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"OE1",
"NE2",
"",
"",
"",
"",
"",
],
"GLU": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"OE1",
"OE2",
"",
"",
"",
"",
"",
],
"GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
"HIS": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"ND1",
"CD2",
"CE1",
"NE2",
"",
"",
"",
"",
],
"ILE": [
"N",
"CA",
"C",
"O",
"CB",
"CG1",
"CG2",
"CD1",
"",
"",
"",
"",
"",
"",
],
"LEU": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"",
"",
"",
"",
"",
"",
],
"LYS": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"CE",
"NZ",
"",
"",
"",
"",
"",
],
"MET": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"SD",
"CE",
"",
"",
"",
"",
"",
"",
],
"PHE": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"",
"",
"",
],
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
"SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
"THR": [
"N",
"CA",
"C",
"O",
"CB",
"OG1",
"CG2",
"",
"",
"",
"",
"",
"",
"",
],
"TRP": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"NE1",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
],
"TYR": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"OH",
"",
"",
],
"VAL": [
"N",
"CA",
"C",
"O",
"CB",
"CG1",
"CG2",
"",
"",
"",
"",
"",
"",
"",
],
"UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = [
"A",
"R",
"N",
"D",
"C",
"Q",
"E",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes) # := 20.
unk_restype_index = restype_num # Catch-all index for unknown restypes.
restypes_with_x = restypes + ["X"]
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False) -> np.ndarray:
"""Maps the given sequence into a one-hot encoded matrix.
Args:
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
mapped to the unknown amino acid 'X'. If the mapping doesn't contain amino acid 'X', an error will be thrown.
If False, any amino acid not in the mapping will throw an error.
Returns:
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of the sequence.
Raises:
ValueError: If the mapping doesn't contain values from 0 to
num_unique_aas - 1 without any gaps.
"""
num_entries = max(mapping.values()) + 1
if sorted(set(mapping.values())) != list(range(num_entries)):
raise ValueError(
"The mapping must have values from 0 to num_unique_aas-1 without any gaps. Got: %s"
% sorted(mapping.values())
)
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
for aa_index, aa_type in enumerate(sequence):
if map_unknown_to_x:
if aa_type.isalpha() and aa_type.isupper():
aa_id = mapping.get(aa_type, mapping["X"])
else:
raise ValueError(f"Invalid character in the sequence: {aa_type}")
else:
aa_id = mapping[aa_type]
one_hot_arr[aa_index, aa_id] = 1
return one_hot_arr
restype_1to3 = {
"A": "ALA",
"R": "ARG",
"N": "ASN",
"D": "ASP",
"C": "CYS",
"Q": "GLN",
"E": "GLU",
"G": "GLY",
"H": "HIS",
"I": "ILE",
"L": "LEU",
"K": "LYS",
"M": "MET",
"F": "PHE",
"P": "PRO",
"S": "SER",
"T": "THR",
"W": "TRP",
"Y": "TYR",
"V": "VAL",
}
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}
# Define a restype name for all unknown residues.
unk_restype = "UNK"
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
"A": 0,
"B": 2,
"C": 1,
"D": 2,
"E": 3,
"F": 4,
"G": 5,
"H": 6,
"I": 7,
"J": 20,
"K": 8,
"L": 9,
"M": 10,
"N": 11,
"O": 20,
"P": 12,
"Q": 13,
"R": 14,
"S": 15,
"T": 16,
"U": 1,
"V": 17,
"W": 18,
"X": 20,
"Y": 19,
"Z": 3,
"-": 21,
}
# Partial inversion of HHBLITS_AA_TO_ID.
ID_TO_HHBLITS_AA = {
0: "A",
1: "C", # Also U.
2: "D", # Also B.
3: "E", # Also Z.
4: "F",
5: "G",
6: "H",
7: "I",
8: "K",
9: "L",
10: "M",
11: "N",
12: "P",
13: "Q",
14: "R",
15: "S",
16: "T",
17: "V",
18: "W",
19: "Y",
20: "X", # Includes J and O.
21: "-",
}
restypes_with_x_and_gap = restypes + ["X", "-"]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) for i in range(len(restypes_with_x_and_gap))
)
def _make_standard_atom_mask() -> np.ndarray:
"""Returns [num_res_types, num_atom_types] mask array."""
# +1 to account for unknown (all 0s).
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
mask[restype, atom_type] = 1
return mask
STANDARD_ATOM_MASK = _make_standard_atom_mask()
# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
"""Define chi-angle rigid groups via one-hot representations."""
chi_angles_index = {}
one_hots = []
for k, v in chi_angles_atoms.items():
indices = [atom_types.index(s[atom_index]) for s in v]
indices.extend([-1] * (4 - len(indices)))
chi_angles_index[k] = indices
for r in restypes:
res3 = restype_1to3[r]
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
one_hots.append(one_hot)
one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
one_hot = np.stack(one_hots, axis=0)
one_hot = np.transpose(one_hot, [0, 2, 1])
return one_hot
chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
chi_angles_atom_indices_ours = map_structure_with_atom_order(chi_angles_atom_indices)
chi_angles_atom_indices = np.array(
[chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) for chi_atoms in chi_angles_atom_indices]
)
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
chi_groups_for_atom = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
for atom_i, atom in enumerate(chi_group):
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
chi_groups_for_atom = dict(chi_groups_for_atom)
def _make_rigid_transformation_4x4(ex, ey, translation):
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
# Normalize ex.
ex_normalized = ex / np.linalg.norm(ex)
# make ey perpendicular to ex
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
ey_normalized /= np.linalg.norm(ey_normalized)
# compute ez as cross product
eznorm = np.cross(ex_normalized, ey_normalized)
m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
return m
# create an array with (restype, atomtype) --> rigid_group_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
def _make_rigid_group_constants():
"""Fill the arrays above."""
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]:
atomtype = atom_order[atomname]
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
restype_atom37_mask[restype, atomtype] = 1
restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
atom14idx = restype_name_to_atom14_names[resname].index(atomname)
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
restype_atom14_mask[restype, atom14idx] = 1
restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_positions = {name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]}
# backbone to backbone is the identity transform
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
# pre-omega-frame to backbone (currently dummy identity matrix)
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
# phi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions["N"] - atom_positions["CA"],
ey=np.array([1.0, 0.0, 0.0]),
translation=atom_positions["N"],
)
restype_rigid_group_default_frame[restype, 2, :, :] = mat
# psi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions["C"] - atom_positions["CA"],
ey=atom_positions["CA"] - atom_positions["N"],
translation=atom_positions["C"],
)
restype_rigid_group_default_frame[restype, 3, :, :] = mat
# chi1-frame to backbone
if chi_angles_mask[restype][0]:
base_atom_names = chi_angles_atoms[resname][0]
base_atom_positions = [atom_positions[name] for name in base_atom_names]
mat = _make_rigid_transformation_4x4(
ex=base_atom_positions[2] - base_atom_positions[1],
ey=base_atom_positions[0] - base_atom_positions[1],
translation=base_atom_positions[2],
)
restype_rigid_group_default_frame[restype, 4, :, :] = mat
# chi2-frame to chi1-frame
# chi3-frame to chi2-frame
# chi4-frame to chi3-frame
# luckily all rotation axes for the next frame start at (0,0,0) of the
# previous frame
for chi_idx in range(1, 4):
if chi_angles_mask[restype][chi_idx]:
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
axis_end_atom_position = atom_positions[axis_end_atom_name]
mat = _make_rigid_transformation_4x4(
ex=axis_end_atom_position,
ey=np.array([-1.0, 0.0, 0.0]),
translation=axis_end_atom_position,
)
restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
_make_rigid_group_constants()
def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15):
"""compute upper and lower bounds for bonds to assess violations."""
restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_list = restype_name_to_atom14_names[resname]
# create lower and upper bounds for clashes
for atom1_idx, atom1_name in enumerate(atom_list):
if not atom1_name:
continue
atom1_radius = van_der_waals_radius[atom1_name[0]]
for atom2_idx, atom2_name in enumerate(atom_list):
if (not atom2_name) or atom1_idx == atom2_idx:
continue
atom2_radius = van_der_waals_radius[atom2_name[0]]
lower = atom1_radius + atom2_radius - overlap_tolerance
upper = 1e10
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
# overwrite lower and upper bounds for bonds and angles
for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
atom1_idx = atom_list.index(b.atom1_name)
atom2_idx = atom_list.index(b.atom2_name)
lower = b.length - bond_length_tolerance_factor * b.stddev
upper = b.length + bond_length_tolerance_factor * b.stddev
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
return {
"lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14)
"upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14)
"stddev": restype_atom14_bond_stddev, # shape (21,14,14)
}
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = np.tile(np.arange(14, dtype=int), (21, 1))
def _make_atom14_ambiguity_feats():
for res, pairs in residue_atom_renaming_swaps.items():
res_idx = restype_order[restype_3to1[res]]
for atom1, atom2 in pairs.items():
atom1_idx = restype_name_to_atom14_names[res].index(atom1)
atom2_idx = restype_name_to_atom14_names[res].index(atom2)
restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = atom2_idx
restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = atom1_idx
_make_atom14_ambiguity_feats()
def aatype_to_str_sequence(aatype):
return "".join([restypes_with_x[aatype[i]] for i in range(len(aatype))])
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import lru_cache
from typing import Any, Callable, Optional, Sequence, Tuple
import numpy as np
import torch
def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Performs matrix multiplication of two rotation matrix tensors. Written out by hand to avoid AMP downcasting.
Args:
a: [*, 3, 3] left multiplicand
b: [*, 3, 3] right multiplicand
Returns:
The product ab
"""
def row_mul(i):
return torch.stack(
[
a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0],
a[..., i, 0] * b[..., 0, 1] + a[..., i, 1] * b[..., 1, 1] + a[..., i, 2] * b[..., 2, 1],
a[..., i, 0] * b[..., 0, 2] + a[..., i, 1] * b[..., 1, 2] + a[..., i, 2] * b[..., 2, 2],
],
dim=-1,
)
return torch.stack(
[
row_mul(0),
row_mul(1),
row_mul(2),
],
dim=-2,
)
def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
Applies a rotation to a vector. Written out by hand to avoid transfer to avoid AMP downcasting.
Args:
r: [*, 3, 3] rotation matrices
t: [*, 3] coordinate tensors
Returns:
[*, 3] rotated coordinates
"""
x, y, z = torch.unbind(t, dim=-1)
return torch.stack(
[
r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
],
dim=-1,
)
@lru_cache(maxsize=None)
def identity_rot_mats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
rots = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad)
rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
rots = rots.expand(*batch_dims, -1, -1)
rots = rots.contiguous()
return rots
@lru_cache(maxsize=None)
def identity_trans(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
trans = torch.zeros((*batch_dims, 3), dtype=dtype, device=device, requires_grad=requires_grad)
return trans
@lru_cache(maxsize=None)
def identity_quats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
quat = torch.zeros((*batch_dims, 4), dtype=dtype, device=device, requires_grad=requires_grad)
with torch.no_grad():
quat[..., 0] = 1
return quat
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs):
mat = np.zeros((4, 4))
for pair in pairs:
key, value = pair
ind = _qtr_ind_dict[key]
mat[ind // 4][ind % 4] = value
return mat
_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
# [*, 4, 4]
quat = quat[..., None] * quat[..., None, :]
# [4, 4, 3, 3]
mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
quat = quat[..., None, None] * shaped_qtr_mat
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
def rot_to_quat(
rot: torch.Tensor,
):
if rot.shape[-2:] != (3, 3):
raise ValueError("Input rotation is incorrectly shaped")
rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
k = [
[
xx + yy + zz,
zy - yz,
xz - zx,
yx - xy,
],
[
zy - yz,
xx - yy - zz,
xy + yx,
xz + zx,
],
[
xz - zx,
xy + yx,
yy - xx - zz,
yz + zy,
],
[
yx - xy,
xz + zx,
yz + zy,
zz - xx - yy,
],
]
k = (1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
_, vectors = torch.linalg.eigh(k)
return vectors[..., -1]
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]]
_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]]
_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]]
_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]]
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
_CACHED_QUATS = {
"_QTR_MAT": _QTR_MAT,
"_QUAT_MULTIPLY": _QUAT_MULTIPLY,
"_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC,
}
@lru_cache(maxsize=None)
def _get_quat(quat_key, dtype, device):
return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
def quat_multiply(quat1, quat2):
"""Multiply a quaternion by another quaternion."""
mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2))
def quat_multiply_by_vec(quat, vec):
"""Multiply a quaternion by a pure-vector quaternion."""
mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2))
def invert_rot_mat(rot_mat: torch.Tensor):
return rot_mat.transpose(-1, -2)
def invert_quat(quat: torch.Tensor):
quat_prime = quat.clone()
quat_prime[..., 1:] *= -1
inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True)
return inv
class Rotation:
"""
A 3D rotation. Depending on how the object is initialized, the rotation is represented by either a rotation matrix
or a quaternion, though both formats are made available by helper functions. To simplify gradient computation, the
underlying format of the rotation cannot be changed in-place. Like Rigid, the class is designed to mimic the
behavior of a torch Tensor, almost as if each Rotation object were a tensor of rotations, in one format or another.
"""
def __init__(
self,
rot_mats: Optional[torch.Tensor] = None,
quats: Optional[torch.Tensor] = None,
normalize_quats: bool = True,
):
"""
Args:
rot_mats:
A [*, 3, 3] rotation matrix tensor. Mutually exclusive with quats
quats:
A [*, 4] quaternion. Mutually exclusive with rot_mats. If normalize_quats is not True, must be a unit
quaternion
normalize_quats:
If quats is specified, whether to normalize quats
"""
if (rot_mats is None and quats is None) or (rot_mats is not None and quats is not None):
raise ValueError("Exactly one input argument must be specified")
if (rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or (quats is not None and quats.shape[-1] != 4):
raise ValueError("Incorrectly shaped rotation matrix or quaternion")
# Force full-precision
if quats is not None:
quats = quats.to(dtype=torch.float32)
if rot_mats is not None:
rot_mats = rot_mats.to(dtype=torch.float32)
if quats is not None and normalize_quats:
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
self._rot_mats = rot_mats
self._quats = quats
@staticmethod
def identity(
shape,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rotation:
"""
Returns an identity Rotation.
Args:
shape:
The "shape" of the resulting Rotation object. See documentation for the shape property
dtype:
The torch dtype for the rotation
device:
The torch device for the new rotation
requires_grad:
Whether the underlying tensors in the new rotation object should require gradient computation
fmt:
One of "quat" or "rot_mat". Determines the underlying format of the new object's rotation
Returns:
A new identity rotation
"""
if fmt == "rot_mat":
rot_mats = identity_rot_mats(
shape,
dtype,
device,
requires_grad,
)
return Rotation(rot_mats=rot_mats, quats=None)
elif fmt == "quat":
quats = identity_quats(shape, dtype, device, requires_grad)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError(f"Invalid format: f{fmt}")
# Magic methods
def __getitem__(self, index: Any) -> Rotation:
"""
Allows torch-style indexing over the virtual shape of the rotation object. See documentation for the shape
property.
Args:
index:
A torch index. E.g. (1, 3, 2), or (slice(None,))
Returns:
The indexed rotation
"""
if type(index) != tuple:
index = (index,)
if self._rot_mats is not None:
rot_mats = self._rot_mats[index + (slice(None), slice(None))]
return Rotation(rot_mats=rot_mats)
elif self._quats is not None:
quats = self._quats[index + (slice(None),)]
return Rotation(quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def __mul__(
self,
right: torch.Tensor,
) -> Rotation:
"""
Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if not (isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
if self._rot_mats is not None:
rot_mats = self._rot_mats * right[..., None, None]
return Rotation(rot_mats=rot_mats, quats=None)
elif self._quats is not None:
quats = self._quats * right[..., None]
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def __rmul__(
self,
left: torch.Tensor,
) -> Rotation:
"""
Reverse pointwise multiplication of the rotation with a tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return self.__mul__(left)
# Properties
@property
def shape(self) -> torch.Size:
"""
Returns the virtual shape of the rotation object. This shape is defined as the batch dimensions of the
underlying rotation matrix or quaternion. If the Rotation was initialized with a [10, 3, 3] rotation matrix
tensor, for example, the resulting shape would be [10].
Returns:
The virtual shape of the rotation object
"""
s = None
if self._quats is not None:
s = self._quats.shape[:-1]
else:
s = self._rot_mats.shape[:-2]
return s
@property
def dtype(self) -> torch.dtype:
"""
Returns the dtype of the underlying rotation.
Returns:
The dtype of the underlying rotation
"""
if self._rot_mats is not None:
return self._rot_mats.dtype
elif self._quats is not None:
return self._quats.dtype
else:
raise ValueError("Both rotations are None")
@property
def device(self) -> torch.device:
"""
The device of the underlying rotation
Returns:
The device of the underlying rotation
"""
if self._rot_mats is not None:
return self._rot_mats.device
elif self._quats is not None:
return self._quats.device
else:
raise ValueError("Both rotations are None")
@property
def requires_grad(self) -> bool:
"""
Returns the requires_grad property of the underlying rotation
Returns:
The requires_grad property of the underlying tensor
"""
if self._rot_mats is not None:
return self._rot_mats.requires_grad
elif self._quats is not None:
return self._quats.requires_grad
else:
raise ValueError("Both rotations are None")
def get_rot_mats(self) -> torch.Tensor:
"""
Returns the underlying rotation as a rotation matrix tensor.
Returns:
The rotation as a rotation matrix tensor
"""
rot_mats = self._rot_mats
if rot_mats is None:
if self._quats is None:
raise ValueError("Both rotations are None")
else:
rot_mats = quat_to_rot(self._quats)
return rot_mats
def get_quats(self) -> torch.Tensor:
"""
Returns the underlying rotation as a quaternion tensor.
Depending on whether the Rotation was initialized with a quaternion, this function may call torch.linalg.eigh.
Returns:
The rotation as a quaternion tensor.
"""
quats = self._quats
if quats is None:
if self._rot_mats is None:
raise ValueError("Both rotations are None")
else:
quats = rot_to_quat(self._rot_mats)
return quats
def get_cur_rot(self) -> torch.Tensor:
"""
Return the underlying rotation in its current form
Returns:
The stored rotation
"""
if self._rot_mats is not None:
return self._rot_mats
elif self._quats is not None:
return self._quats
else:
raise ValueError("Both rotations are None")
# Rotation functions
def compose_q_update_vec(self, q_update_vec: torch.Tensor, normalize_quats: bool = True) -> Rotation:
"""
Returns a new quaternion Rotation after updating the current object's underlying rotation with a quaternion
update, formatted as a [*, 3] tensor whose final three columns represent x, y, z such that (1, x, y, z) is the
desired (not necessarily unit) quaternion update.
Args:
q_update_vec:
A [*, 3] quaternion update tensor
normalize_quats:
Whether to normalize the output quaternion
Returns:
An updated Rotation
"""
quats = self.get_quats()
new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
return Rotation(
rot_mats=None,
quats=new_quats,
normalize_quats=normalize_quats,
)
def compose_r(self, r: Rotation) -> Rotation:
"""
Compose the rotation matrices of the current Rotation object with those of another.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
r1 = self.get_rot_mats()
r2 = r.get_rot_mats()
new_rot_mats = rot_matmul(r1, r2)
return Rotation(rot_mats=new_rot_mats, quats=None)
def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
"""
Compose the quaternions of the current Rotation object with those of another.
Depending on whether either Rotation was initialized with quaternions, this function may call
torch.linalg.eigh.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
q1 = self.get_quats()
q2 = r.get_quats()
new_quats = quat_multiply(q1, q2)
return Rotation(rot_mats=None, quats=new_quats, normalize_quats=normalize_quats)
def apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
Apply the current Rotation as a rotation matrix to a set of 3D coordinates.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] rotated points
"""
rot_mats = self.get_rot_mats()
return rot_vec_mul(rot_mats, pts)
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
The inverse of the apply() method.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] inverse-rotated points
"""
rot_mats = self.get_rot_mats()
inv_rot_mats = invert_rot_mat(rot_mats)
return rot_vec_mul(inv_rot_mats, pts)
def invert(self) -> Rotation:
"""
Returns the inverse of the current Rotation.
Returns:
The inverse of the current Rotation
"""
if self._rot_mats is not None:
return Rotation(rot_mats=invert_rot_mat(self._rot_mats), quats=None)
elif self._quats is not None:
return Rotation(
rot_mats=None,
quats=invert_quat(self._quats),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
# "Tensor" stuff
def unsqueeze(
self,
dim: int,
) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed Rotation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
if self._rot_mats is not None:
rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
return Rotation(rot_mats=rot_mats, quats=None)
elif self._quats is not None:
quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
@staticmethod
def cat(
rs: Sequence[Rotation],
dim: int,
) -> Rigid:
"""
Concatenates rotations along one of the batch dimensions. Analogous to torch.cat().
Note that the output of this operation is always a rotation matrix, regardless of the format of input
rotations.
Args:
rs:
A list of rotation objects
dim:
The dimension along which the rotations should be concatenated
Returns:
A concatenated Rotation object in rotation matrix format
"""
rot_mats = [r.get_rot_mats() for r in rs]
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
return Rotation(rot_mats=rot_mats, quats=None)
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation:
"""
Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can
be used e.g. to sum out a one-hot batch dimension.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rotation
Returns:
The transformed Rotation object
"""
if self._rot_mats is not None:
rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1)
rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
return Rotation(rot_mats=rot_mats, quats=None)
elif self._quats is not None:
quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def cuda(self) -> Rotation:
"""
Analogous to the cuda() method of torch Tensors
Returns:
A copy of the Rotation in CUDA memory
"""
if self._rot_mats is not None:
return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
elif self._quats is not None:
return Rotation(rot_mats=None, quats=self._quats.cuda(), normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]) -> Rotation:
"""
Analogous to the to() method of torch Tensors
Args:
device:
A torch device
dtype:
A torch dtype
Returns:
A copy of the Rotation using the new device and dtype
"""
if self._rot_mats is not None:
return Rotation(
rot_mats=self._rot_mats.to(device=device, dtype=dtype),
quats=None,
)
elif self._quats is not None:
return Rotation(
rot_mats=None,
quats=self._quats.to(device=device, dtype=dtype),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
def detach(self) -> Rotation:
"""
Returns a copy of the Rotation whose underlying Tensor has been detached from its torch graph.
Returns:
A copy of the Rotation whose underlying Tensor has been detached from its torch graph
"""
if self._rot_mats is not None:
return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
elif self._quats is not None:
return Rotation(
rot_mats=None,
quats=self._quats.detach(),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
class Rigid:
"""
A class representing a rigid transformation. Little more than a wrapper around two objects: a Rotation object and a
[*, 3] translation Designed to behave approximately like a single torch tensor with the shape of the shared batch
dimensions of its component parts.
"""
def __init__(
self,
rots: Optional[Rotation],
trans: Optional[torch.Tensor],
):
"""
Args:
rots: A [*, 3, 3] rotation tensor
trans: A corresponding [*, 3] translation tensor
"""
# (we need device, dtype, etc. from at least one input)
batch_dims, dtype, device, requires_grad = None, None, None, None
if trans is not None:
batch_dims = trans.shape[:-1]
dtype = trans.dtype
device = trans.device
requires_grad = trans.requires_grad
elif rots is not None:
batch_dims = rots.shape
dtype = rots.dtype
device = rots.device
requires_grad = rots.requires_grad
else:
raise ValueError("At least one input argument must be specified")
if rots is None:
rots = Rotation.identity(
batch_dims,
dtype,
device,
requires_grad,
)
elif trans is None:
trans = identity_trans(
batch_dims,
dtype,
device,
requires_grad,
)
if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device):
raise ValueError("Rots and trans incompatible")
# Force full precision. Happens to the rotations automatically.
trans = trans.to(dtype=torch.float32)
self._rots = rots
self._trans = trans
@staticmethod
def identity(
shape: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rigid:
"""
Constructs an identity transformation.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
"""
return Rigid(
Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
identity_trans(shape, dtype, device, requires_grad),
)
def __getitem__(
self,
index: Any,
) -> Rigid:
"""
Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of
both the rotation and the translation.
E.g.::
r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None) t = Rigid(r, torch.rand(10, 10, 3)) indexed =
t[3, 4:6] assert(indexed.shape == (2,)) assert(indexed.get_rots().shape == (2,))
assert(indexed.get_trans().shape == (2, 3))
Args:
index: A standard torch tensor index. E.g. 8, (10, None, 3),
or (3, slice(0, 1, None))
Returns:
The indexed tensor
"""
if type(index) != tuple:
index = (index,)
return Rigid(
self._rots[index],
self._trans[index + (slice(None),)],
)
def __mul__(
self,
right: torch.Tensor,
) -> Rigid:
"""
Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if not (isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
new_rots = self._rots * right
new_trans = self._trans * right[..., None]
return Rigid(new_rots, new_trans)
def __rmul__(
self,
left: torch.Tensor,
) -> Rigid:
"""
Reverse pointwise multiplication of the transformation with a tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return self.__mul__(left)
@property
def shape(self) -> torch.Size:
"""
Returns the shape of the shared dimensions of the rotation and the translation.
Returns:
The shape of the transformation
"""
s = self._trans.shape[:-1]
return s
@property
def device(self) -> torch.device:
"""
Returns the device on which the Rigid's tensors are located.
Returns:
The device on which the Rigid's tensors are located
"""
return self._trans.device
def get_rots(self) -> Rotation:
"""
Getter for the rotation.
Returns:
The rotation object
"""
return self._rots
def get_trans(self) -> torch.Tensor:
"""
Getter for the translation.
Returns:
The stored translation
"""
return self._trans
def compose_q_update_vec(
self,
q_update_vec: torch.Tensor,
) -> Rigid:
"""
Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns
represent the x, y, and z values of a quaternion of form (1, x, y, z) followed by a 3D translation.
Args:
q_vec: The quaternion update vector.
Returns:
The composed transformation.
"""
q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
new_rots = self._rots.compose_q_update_vec(q_vec)
trans_update = self._rots.apply(t_vec)
new_translation = self._trans + trans_update
return Rigid(new_rots, new_translation)
def compose(
self,
r: Rigid,
) -> Rigid:
"""
Composes the current rigid object with another.
Args:
r:
Another Rigid object
Returns:
The composition of the two transformations
"""
new_rot = self._rots.compose_r(r._rots)
new_trans = self._rots.apply(r._trans) + self._trans
return Rigid(new_rot, new_trans)
def apply(
self,
pts: torch.Tensor,
) -> torch.Tensor:
"""
Applies the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor.
Returns:
The transformed points.
"""
rotated = self._rots.apply(pts)
return rotated + self._trans
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
Applies the inverse of the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor
Returns:
The transformed points.
"""
pts = pts - self._trans
return self._rots.invert_apply(pts)
def invert(self) -> Rigid:
"""
Inverts the transformation.
Returns:
The inverse transformation.
"""
rot_inv = self._rots.invert()
trn_inv = rot_inv.apply(self._trans)
return Rigid(rot_inv, -1 * trn_inv)
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
"""
Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the
translation/rotation dimensions respectively.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rigid
Returns:
The transformed Rigid object
"""
new_rots = self._rots.map_tensor_fn(fn)
new_trans = torch.stack(list(map(fn, torch.unbind(self._trans, dim=-1))), dim=-1)
return Rigid(new_rots, new_trans)
def to_tensor_4x4(self) -> torch.Tensor:
"""
Converts a transformation to a homogenous transformation tensor.
Returns:
A [*, 4, 4] homogenous transformation tensor
"""
tensor = self._trans.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self._rots.get_rot_mats()
tensor[..., :3, 3] = self._trans
tensor[..., 3, 3] = 1
return tensor
@staticmethod
def from_tensor_4x4(t: torch.Tensor) -> Rigid:
"""
Constructs a transformation from a homogenous transformation tensor.
Args:
t: [*, 4, 4] homogenous transformation tensor
Returns:
T object with shape [*]
"""
if t.shape[-2:] != (4, 4):
raise ValueError("Incorrectly shaped input tensor")
rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
trans = t[..., :3, 3]
return Rigid(rots, trans)
def to_tensor_7(self) -> torch.Tensor:
"""
Converts a transformation to a tensor with 7 final columns, four for the quaternion followed by three for the
translation.
Returns:
A [*, 7] tensor representation of the transformation
"""
tensor = self._trans.new_zeros((*self.shape, 7))
tensor[..., :4] = self._rots.get_quats()
tensor[..., 4:] = self._trans
return tensor
@staticmethod
def from_tensor_7(
t: torch.Tensor,
normalize_quats: bool = False,
) -> Rigid:
if t.shape[-1] != 7:
raise ValueError("Incorrectly shaped input tensor")
quats, trans = t[..., :4], t[..., 4:]
rots = Rotation(rot_mats=None, quats=quats, normalize_quats=normalize_quats)
return Rigid(rots, trans)
@staticmethod
def from_3_points(
p_neg_x_axis: torch.Tensor, origin: torch.Tensor, p_xy_plane: torch.Tensor, eps: float = 1e-8
) -> Rigid:
"""
Implements algorithm 21. Constructs transformations from sets of 3 points using the Gram-Schmidt algorithm.
Args:
p_neg_x_axis: [*, 3] coordinates
origin: [*, 3] coordinates used as frame origins
p_xy_plane: [*, 3] coordinates
eps: Small epsilon value
Returns:
A transformation object of shape [*]
"""
p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
origin = torch.unbind(origin, dim=-1)
p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
denom = torch.sqrt(sum((c * c for c in e0)) + eps)
e0 = [c / denom for c in e0]
dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
denom = torch.sqrt(sum((c * c for c in e1)) + eps)
e1 = [c / denom for c in e1]
e2 = [
e0[1] * e1[2] - e0[2] * e1[1],
e0[2] * e1[0] - e0[0] * e1[2],
e0[0] * e1[1] - e0[1] * e1[0],
]
rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
rots = rots.reshape(rots.shape[:-1] + (3, 3))
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, torch.stack(origin, dim=-1))
def unsqueeze(
self,
dim: int,
) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self._rots.unsqueeze(dim)
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
return Rigid(rots, trans)
@staticmethod
def cat(
ts: Sequence[Rigid],
dim: int,
) -> Rigid:
"""
Concatenates transformations along a new dimension.
Args:
ts:
A list of T objects
dim:
The dimension along which the transformations should be concatenated
Returns:
A concatenated transformation object
"""
rots = Rotation.cat([t._rots for t in ts], dim)
trans = torch.cat([t._trans for t in ts], dim=dim if dim >= 0 else dim - 1)
return Rigid(rots, trans)
def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid:
"""
Applies a Rotation -> Rotation function to the stored rotation object.
Args:
fn: A function of type Rotation -> Rotation
Returns:
A transformation object with a transformed rotation.
"""
return Rigid(fn(self._rots), self._trans)
def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
"""
Applies a Tensor -> Tensor function to the stored translation.
Args:
fn:
A function of type Tensor -> Tensor to be applied to the translation
Returns:
A transformation object with a transformed translation.
"""
return Rigid(self._rots, fn(self._trans))
def scale_translation(self, trans_scale_factor: float) -> Rigid:
"""
Scales the translation by a constant factor.
Args:
trans_scale_factor:
The constant factor
Returns:
A transformation object with a scaled translation.
"""
return self.apply_trans_fn(lambda t: t * trans_scale_factor)
def stop_rot_gradient(self) -> Rigid:
"""
Detaches the underlying rotation object
Returns:
A transformation object with detached rotations
"""
return self.apply_rot_fn(lambda r: r.detach())
@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
"""
Returns a transformation object from reference coordinates.
Note that this method does not take care of symmetries. If you provide the atom positions in the non-standard
way, the N atom will end up not at [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
need to take care of such cases in your code.
Args:
n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
c_xyz: A [*, 3] tensor of carbon xyz coordinates.
Returns:
A transformation object. After applying the translation and rotation to the reference backbone, the
coordinates will approximately equal to the input coordinates.
"""
translation = -1 * ca_xyz
n_xyz = n_xyz + translation
c_xyz = c_xyz + translation
c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + c_x**2 + c_y**2)
sin_c1 = -c_y / norm
cos_c1 = c_x / norm
c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
c1_rots[..., 0, 0] = cos_c1
c1_rots[..., 0, 1] = -1 * sin_c1
c1_rots[..., 1, 0] = sin_c1
c1_rots[..., 1, 1] = cos_c1
c1_rots[..., 2, 2] = 1
norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
sin_c2 = c_z / norm
cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm
c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
c2_rots[..., 0, 0] = cos_c2
c2_rots[..., 0, 2] = sin_c2
c2_rots[..., 1, 1] = 1
c2_rots[..., 2, 0] = -1 * sin_c2
c2_rots[..., 2, 2] = cos_c2
c_rots = rot_matmul(c2_rots, c1_rots)
n_xyz = rot_vec_mul(c_rots, n_xyz)
_, n_y, n_z = [n_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + n_y**2 + n_z**2)
sin_n = -n_z / norm
cos_n = n_y / norm
n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
n_rots[..., 0, 0] = 1
n_rots[..., 1, 1] = cos_n
n_rots[..., 1, 2] = -1 * sin_n
n_rots[..., 2, 1] = sin_n
n_rots[..., 2, 2] = cos_n
rots = rot_matmul(n_rots, c_rots)
rots = rots.transpose(-1, -2)
translation = -1 * translation
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, translation)
def cuda(self) -> Rigid:
"""
Moves the transformation object to GPU memory
Returns:
A version of the transformation on GPU
"""
return Rigid(self._rots.cuda(), self._trans.cuda())
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import List
import torch
import torch.nn as nn
def add(m1, m2, inplace):
# The first operation in a checkpoint can't be in-place, but it's
# nice to have in-place addition during inference. Thus...
if not inplace:
m1 = m1 + m2
else:
m1 += m2
return m1
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(t: torch.Tensor, no_dims: int):
return t.reshape(t.shape[:-no_dims] + (-1,))
def masked_mean(mask, value, dim, eps=1e-4):
mask = mask.expand(*value.shape)
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
return torch.bucketize(dists, boundaries)
def dict_multimap(fn, dicts):
first = dicts[0]
new_dict = {}
for k, v in first.items():
all_v = [d[k] for d in dicts]
if type(v) is dict:
new_dict[k] = dict_multimap(fn, all_v)
else:
new_dict[k] = fn(all_v)
return new_dict
def one_hot(x, v_bins):
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
def batched_gather(data, inds, dim=0, no_batch_dims=0):
ranges = []
for i, s in enumerate(data.shape[:no_batch_dims]):
r = torch.arange(s)
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
ranges.append(r)
remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
ranges.extend(remaining_dims)
# Matt note: Editing this to get around the behaviour of using a list as an array index changing
# in recent Numpy versions
return data[tuple(ranges)]
# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
new_dict = {}
for k, v in dic.items():
if type(v) is dict:
new_dict[k] = dict_map(fn, v, leaf_type)
else:
new_dict[k] = tree_map(fn, v, leaf_type)
return new_dict
def tree_map(fn, tree, leaf_type):
if isinstance(tree, dict):
return dict_map(fn, tree, leaf_type)
elif isinstance(tree, list):
return [tree_map(fn, x, leaf_type) for x in tree]
elif isinstance(tree, tuple):
return tuple([tree_map(fn, x, leaf_type) for x in tree])
elif isinstance(tree, leaf_type):
return fn(tree)
else:
print(type(tree))
raise ValueError("Not supported")
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
...@@ -1985,6 +1985,13 @@ class ErniePreTrainedModel(metaclass=DummyObject): ...@@ -1985,6 +1985,13 @@ class ErniePreTrainedModel(metaclass=DummyObject):
ESM_PRETRAINED_MODEL_ARCHIVE_LIST = None ESM_PRETRAINED_MODEL_ARCHIVE_LIST = None
class EsmFoldPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EsmForMaskedLM(metaclass=DummyObject): class EsmForMaskedLM(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -1992,6 +1999,13 @@ class EsmForMaskedLM(metaclass=DummyObject): ...@@ -1992,6 +1999,13 @@ class EsmForMaskedLM(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class EsmForProteinFolding(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EsmForSequenceClassification(metaclass=DummyObject): class EsmForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -20,7 +20,6 @@ import unittest ...@@ -20,7 +20,6 @@ import unittest
from transformers import EsmConfig, is_torch_available from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
...@@ -49,7 +48,7 @@ class EsmModelTester: ...@@ -49,7 +48,7 @@ class EsmModelTester:
self.use_input_mask = True self.use_input_mask = True
self.use_token_type_ids = False self.use_token_type_ids = False
self.use_labels = True self.use_labels = True
self.vocab_size = 99 self.vocab_size = 33
self.hidden_size = 32 self.hidden_size = 32
self.num_hidden_layers = 5 self.num_hidden_layers = 5
self.num_attention_heads = 4 self.num_attention_heads = 4
...@@ -145,7 +144,7 @@ class EsmModelTester: ...@@ -145,7 +144,7 @@ class EsmModelTester:
@require_torch @require_torch
class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class EsmModelTest(ModelTesterMixin, unittest.TestCase):
test_mismatched_shapes = False test_mismatched_shapes = False
...@@ -253,7 +252,9 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -253,7 +252,9 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class EsmModelIntegrationTest(TestCasePlus): class EsmModelIntegrationTest(TestCasePlus):
@slow @slow
def test_inference_masked_lm(self): def test_inference_masked_lm(self):
with torch.no_grad():
model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
model.eval()
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0] output = model(input_ids)[0]
...@@ -269,7 +270,9 @@ class EsmModelIntegrationTest(TestCasePlus): ...@@ -269,7 +270,9 @@ class EsmModelIntegrationTest(TestCasePlus):
@slow @slow
def test_inference_no_head(self): def test_inference_no_head(self):
with torch.no_grad():
model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
model.eval()
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
output = model(input_ids)[0] output = model(input_ids)[0]
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch ESM model. """
import unittest
from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
import torch
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
class EsmFoldModelTester:
def __init__(
self,
parent,
):
self.parent = parent
self.batch_size = 13
self.seq_length = 7
self.is_training = False
self.use_input_mask = True
self.use_token_type_ids = False
self.use_labels = False
self.vocab_size = 19
self.hidden_size = 32
self.num_hidden_layers = 5
self.num_attention_heads = 4
self.intermediate_size = 37
self.hidden_act = "gelu"
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 512
self.type_vocab_size = 16
self.type_sequence_label_size = 2
self.initializer_range = 0.02
self.num_labels = 3
self.num_choices = 4
self.scope = None
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
config = EsmConfig(
vocab_size=33,
hidden_size=self.hidden_size,
pad_token_id=1,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
is_folding_model=True,
esmfold_config={"trunk": {"num_blocks": 2}, "fp16_esm": False},
)
return config
def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = EsmForProteinFolding(config=config).float()
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
result = model(input_ids)
self.parent.assertEqual(result.positions.shape, (8, self.batch_size, self.seq_length, 14, 3))
self.parent.assertEqual(result.angles.shape, (8, self.batch_size, self.seq_length, 7, 2))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class EsmFoldModelTest(ModelTesterMixin, unittest.TestCase):
test_mismatched_shapes = False
all_model_classes = (EsmForProteinFolding,) if is_torch_available() else ()
all_generative_model_classes = ()
test_sequence_classification_problem_types = False
def setUp(self):
self.model_tester = EsmFoldModelTester(self)
self.config_tester = ConfigTester(self, config_class=EsmConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip("Does not support attention outputs")
def test_attention_outputs(self):
pass
@unittest.skip
def test_correct_missing_keys(self):
pass
@unittest.skip("Esm does not support embedding resizing")
def test_resize_embeddings_untied(self):
pass
@unittest.skip("Esm does not support embedding resizing")
def test_resize_tokens_embeddings(self):
pass
@unittest.skip("ESMFold does not support passing input embeds!")
def test_inputs_embeds(self):
pass
@unittest.skip("ESMFold does not support head pruning.")
def test_head_pruning(self):
pass
@unittest.skip("ESMFold does not support head pruning.")
def test_head_pruning_integration(self):
pass
@unittest.skip("ESMFold does not support head pruning.")
def test_head_pruning_save_load_from_config_init(self):
pass
@unittest.skip("ESMFold does not support head pruning.")
def test_head_pruning_save_load_from_pretrained(self):
pass
@unittest.skip("ESMFold does not support head pruning.")
def test_headmasking(self):
pass
@unittest.skip("ESMFold does not output hidden states in the normal way.")
def test_hidden_states_output(self):
pass
@unittest.skip("ESMfold does not output hidden states in the normal way.")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip("ESMFold only has one output format.")
def test_model_outputs_equivalence(self):
pass
@unittest.skip("This test doesn't work for ESMFold and doesn't test core functionality")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip("ESMFold does not support input chunking.")
def test_feed_forward_chunking(self):
pass
@unittest.skip("ESMFold doesn't respect you and it certainly doesn't respect your initialization arguments.")
def test_initialization(self):
pass
@unittest.skip("ESMFold doesn't support torchscript compilation.")
def test_torchscript_output_attentions(self):
pass
@unittest.skip("ESMFold doesn't support torchscript compilation.")
def test_torchscript_output_hidden_state(self):
pass
@unittest.skip("ESMFold doesn't support torchscript compilation.")
def test_torchscript_simple(self):
pass
@unittest.skip("ESMFold doesn't support data parallel.")
def test_multi_gpu_data_parallel_forward(self):
pass
@require_torch
class EsmModelIntegrationTest(TestCasePlus):
@slow
def test_inference_protein_folding(self):
model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float()
model.eval()
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
position_outputs = model(input_ids)["positions"]
expected_slice = torch.tensor([2.5828, 0.7993, -10.9334], dtype=torch.float32)
self.assertTrue(torch.allclose(position_outputs[0, 0, 0, 0], expected_slice, atol=1e-4))
...@@ -254,7 +254,7 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -254,7 +254,7 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase):
@require_tf @require_tf
class TFEsmModelIntegrationTest(unittest.TestCase): class TFEsmModelIntegrationTest(unittest.TestCase):
@slow @unittest.skip("Temporarily disabled as we update ESM model checkpoints")
def test_inference_masked_lm(self): def test_inference_masked_lm(self):
model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
...@@ -268,7 +268,7 @@ class TFEsmModelIntegrationTest(unittest.TestCase): ...@@ -268,7 +268,7 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
) )
self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4)) self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
@slow @unittest.skip("Temporarily disabled as we update ESM model checkpoints")
def test_inference_no_head(self): def test_inference_no_head(self):
model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment