Commit 39a6d0e6 authored by Christina Floristean

Merging in main branch

parents d8ee9c5f 84659c93
cff-version: 1.2.0
-message: "For now, cite OpenFold with its DOI."
+preferred-citation:
authors:
- family-names: "Ahdritz"
  given-names: "Gustaf"
  orcid: https://orcid.org/0000-0001-8283-5324
- family-names: "Bouatta"
  given-names: "Nazim"
  orcid: https://orcid.org/0000-0002-6524-874X
- family-names: "Kadyan"
  given-names: "Sachin"
  orcid: https://orcid.org/0000-0002-6079-7627
- family-names: "Xia"
  given-names: "Qinghui"
- family-names: "Gerecke"
  given-names: "William"
  orcid: https://orcid.org/0000-0002-9777-6192
- family-names: "O'Donnell"
  given-names: "Timothy J"
  orcid: https://orcid.org/0000-0002-9949-069X
- family-names: "Berenberg"
  given-names: "Daniel"
  orcid: https://orcid.org/0000-0003-4631-0947
- family-names: "Fisk"
  given-names: "Ian"
- family-names: "Zanichelli"
  given-names: "Niccolò"
  orcid: https://orcid.org/0000-0002-3093-3587
- family-names: "Zhang"
  given-names: "Bo"
  orcid: https://orcid.org/0000-0002-9714-2827
- family-names: "Nowaczynski"
  given-names: "Arkadiusz"
  orcid: https://orcid.org/0000-0002-3351-9584
- family-names: "Wang"
  given-names: "Bei"
  orcid: https://orcid.org/0000-0003-4942-9652
- family-names: "Stepniewska-Dziubinska"
  given-names: "Marta M"
  orcid: https://orcid.org/0000-0003-4942-9652
- family-names: "Zhang"
  given-names: "Shang"
  orcid: https://orcid.org/0000-0003-0759-2080
- family-names: "Ojewole"
  given-names: "Adegoke"
  orcid: https://orcid.org/0000-0003-2661-4388
- family-names: "Guney"
  given-names: "Murat Efe"
- family-names: "Biderman"
  given-names: "Stella"
  orcid: https://orcid.org/0000-0001-8228-1042
- family-names: "Watkins"
  given-names: "Andrew M"
  orcid: https://orcid.org/0000-0003-1617-1720
- family-names: "Ra"
  given-names: "Stephen"
  orcid: https://orcid.org/0000-0002-2820-0050
- family-names: "Lorenzo"
  given-names: "Pablo Ribalta"
  orcid: https://orcid.org/0000-0002-3657-8053
- family-names: "Nivon"
  given-names: "Lucas"
- family-names: "Weitzner"
  given-names: "Brian"
  orcid: https://orcid.org/0000-0002-1909-0961
- family-names: "Ban"
-  given-names: "Yih-En"
+  given-names: "Yih-En Andrew"
  orcid: https://orcid.org/0000-0003-3698-3574
- family-names: "Sorger"
  given-names: "Peter K"
  orcid: https://orcid.org/0000-0002-3364-1838
- family-names: "Mostaque"
  given-names: "Emad"
- family-names: "Zhang"
  given-names: "Zhao"
  orcid: https://orcid.org/0000-0001-5921-0035
- family-names: "Bonneau"
  given-names: "Richard"
  orcid: https://orcid.org/0000-0003-4354-7906
- family-names: "AlQuraishi"
  given-names: "Mohammed"
  orcid: https://orcid.org/0000-0001-6817-1322
-title: "OpenFold"
+title: "OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization"
-doi: 10.5281/zenodo.5709539
+type: article
+doi: 10.1101/2022.11.20.517210
date-released: 2021-11-12
-url: "https://github.com/aqlaboratory/openfold"
+url: "https://doi.org/10.1101/2022.11.20.517210"
-FROM nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04
+FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04
-RUN apt-get update && apt-get install -y wget cuda-minimal-build-10-2 git

# metainformation
LABEL org.opencontainers.image.version = "1.0.0"
LABEL org.opencontainers.image.authors = "Gustaf Ahdritz"
LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold"
LABEL org.opencontainers.image.licenses = "Apache License 2.0"
LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04"

RUN apt-key del 7fa2af80
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git

RUN wget -P /tmp \
    "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
    && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
...
This diff is collapsed.
@@ -4,9 +4,19 @@ channels:
  - bioconda
  - pytorch
dependencies:
  - conda-forge::python=3.7
  - conda-forge::setuptools=59.5.0
  - conda-forge::pip
  - conda-forge::openmm=7.5.1
  - conda-forge::pdbfixer
  - conda-forge::cudatoolkit==11.3.*
  - bioconda::hmmer==3.3.2
  - bioconda::hhsuite==3.3.0
  - bioconda::kalign2==2.04
  - pytorch::pytorch=1.12.*
  - pip:
      - biopython==1.79
-      - deepspeed==0.5.9
+      - deepspeed==0.5.10
      - dm-tree==0.1.6
      - ml-collections==0.1.0
      - numpy==1.21.2
@@ -16,15 +26,5 @@ dependencies:
      - tqdm==4.62.2
      - typing-extensions==3.10.0.2
      - pytorch_lightning==1.5.10
      - wandb==0.12.21
      - git+https://github.com/NVIDIA/dllogger.git
-  - pytorch::pytorch=1.10.*
-  - conda-forge::python=3.7
-  - conda-forge::setuptools=59.5.0
-  - conda-forge::pip
-  - conda-forge::openmm=7.5.1
-  - conda-forge::pdbfixer
-  - conda-forge::cudatoolkit==10.2.*
-  - conda-forge::cudatoolkit-dev==10.*
-  - bioconda::hmmer==3.3.2
-  - bioconda::hhsuite==3.3.0
-  - bioconda::kalign2==2.04
This diff is collapsed.
name: openfold_venv
channels:
  - conda-forge
  - bioconda
dependencies:
  - conda-forge::openmm=7.5.1
  - conda-forge::pdbfixer
  - bioconda::hmmer==3.3.2
  - bioconda::hhsuite==3.3.0
  - bioconda::kalign2==2.04
  - pip:
      - biopython==1.79
      - dm-tree==0.1.6
      - ml-collections==0.1.0
      - PyYAML==5.4.1
      - requests==2.26.0
      - typing-extensions==3.10.0.2
import copy
import importlib

import ml_collections as mlc

@@ -10,20 +11,92 @@ def set_inf(c, inf):
            c[k] = inf

-def model_config(name, train=False, low_prec=False):
+def enforce_config_constraints(config):
    def string_to_setting(s):
        path = s.split('.')
        setting = config
        for p in path:
            setting = setting.get(p)

        return setting

    mutually_exclusive_bools = [
        (
            "model.template.average_templates",
            "model.template.offload_templates"
        ),
        (
            "globals.use_lma",
            "globals.use_flash",
        ),
    ]

    for s1, s2 in mutually_exclusive_bools:
        s1_setting = string_to_setting(s1)
        s2_setting = string_to_setting(s2)
        if(s1_setting and s2_setting):
            raise ValueError(f"Only one of {s1} and {s2} may be set at a time")

    fa_is_installed = importlib.util.find_spec("flash_attn") is not None
    if(config.globals.use_flash and not fa_is_installed):
        raise ValueError("use_flash requires that FlashAttention is installed")

    if(
        config.globals.offload_inference and
        not config.model.template.average_templates
    ):
        config.model.template.offload_templates = True


def model_config(
    name,
    train=False,
    low_prec=False,
    long_sequence_inference=False
):
    c = copy.deepcopy(config)
    # TRAINING PRESETS
    if name == "initial_training":
        # AF2 Suppl. Table 4, "initial training" setting
        pass
    elif name == "finetuning":
        # AF2 Suppl. Table 4, "finetuning" setting
-        c.data.common.max_extra_msa = 5120
        c.data.train.crop_size = 384
        c.data.train.max_extra_msa = 5120
        c.data.train.max_msa_clusters = 512
        c.loss.violation.weight = 1.
        c.loss.experimentally_resolved.weight = 0.01
    elif name == "finetuning_ptm":
        c.data.train.max_extra_msa = 5120
        c.data.train.crop_size = 384
        c.data.train.max_msa_clusters = 512
        c.loss.violation.weight = 1.
        c.loss.experimentally_resolved.weight = 0.01
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "finetuning_no_templ":
        # AF2 Suppl. Table 4, "finetuning" setting
        c.data.train.crop_size = 384
        c.data.train.max_extra_msa = 5120
        c.data.train.max_msa_clusters = 512
        c.model.template.enabled = False
        c.loss.violation.weight = 1.
        c.loss.experimentally_resolved.weight = 0.01
    elif name == "finetuning_no_templ_ptm":
        # AF2 Suppl. Table 4, "finetuning" setting
        c.data.train.crop_size = 384
        c.data.train.max_extra_msa = 5120
        c.data.train.max_msa_clusters = 512
        c.model.template.enabled = False
        c.loss.violation.weight = 1.
        c.loss.experimentally_resolved.weight = 0.01
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    # INFERENCE PRESETS
    elif name == "model_1":
        # AF2 Suppl. Table 5, Model 1.1.1
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
@@ -36,17 +109,20 @@ def model_config(name, train=False, low_prec=False):
        c.model.template.enabled = True
    elif name == "model_3":
        # AF2 Suppl. Table 5, Model 1.2.1
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.model.template.enabled = False
    elif name == "model_4":
        # AF2 Suppl. Table 5, Model 1.2.2
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.model.template.enabled = False
    elif name == "model_5":
        # AF2 Suppl. Table 5, Model 1.2.3
        c.model.template.enabled = False
    elif name == "model_1_ptm":
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
@@ -61,12 +137,14 @@ def model_config(name, train=False, low_prec=False):
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_3_ptm":
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.model.template.enabled = False
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_4_ptm":
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.model.template.enabled = False
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
@@ -76,6 +154,7 @@ def model_config(name, train=False, low_prec=False):
        c.loss.tm.weight = 0.1
    elif "multimer" in name:
        c.globals.is_multimer = True
        c.loss.masked_msa.num_classes = 22
        for k,v in multimer_model_config_update.items():
            c.model[k] = v
@@ -89,9 +168,23 @@ def model_config(name, train=False, low_prec=False):
    else:
        raise ValueError("Invalid model name")

    if long_sequence_inference:
        assert(not train)
        c.globals.offload_inference = True
        c.globals.use_lma = True
        c.globals.use_flash = False
        c.model.template.offload_inference = True
        c.model.template.template_pair_stack.tune_chunk_size = False
        c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
        c.model.evoformer_stack.tune_chunk_size = False

    if train:
        c.globals.blocks_per_ckpt = 1
        c.globals.chunk_size = None
        c.globals.use_lma = False
        c.globals.offload_inference = False
        c.model.template.average_templates = False
        c.model.template.offload_templates = False
    if low_prec:
        c.globals.eps = 1e-4
@@ -99,6 +192,8 @@ def model_config(name, train=False, low_prec=False):
    # a global constant
    set_inf(c, 1e4)

    enforce_config_constraints(c)

    return c
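A minimal sketch of how the new presets and flags compose (hypothetical usage; the preset name and flag names are taken from the code above):

from openfold.config import model_config

# Long-sequence inference: enables CPU offloading and low-memory attention
# and disables chunk-size tuning, per the long_sequence_inference block above.
c = model_config("model_1_ptm", long_sequence_inference=True)
assert c.globals.use_lma and c.globals.offload_inference

# Mutually exclusive flags are rejected by enforce_config_constraints, which
# model_config calls before returning.
c = model_config("model_1_ptm")
c.globals.use_lma = True
c.globals.use_flash = True  # only one of the pair may be set at a time
# enforce_config_constraints(c) would raise ValueError here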
@@ -114,6 +209,7 @@ tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
tune_chunk_size = mlc.FieldReference(True, field_type=bool)

NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
@@ -195,7 +291,6 @@ config = mlc.ConfigDict(
                    "same_prob": 0.1,
                    "uniform_prob": 0.1,
                },
-                "max_extra_msa": 1024,
                "max_recycling_iters": 3,
                "msa_cluster_features": True,
                "reduce_msa_clusters_by_max_templates": False,
@@ -233,7 +328,8 @@ config = mlc.ConfigDict(
                "fixed_size": True,
                "subsample_templates": False,  # We want top templates.
                "masked_msa_replace_fraction": 0.15,
-                "max_msa_clusters": 128,
+                "max_msa_clusters": 512,
                "max_extra_msa": 1024,
                "max_template_hits": 4,
                "max_templates": 4,
                "crop": False,
@@ -246,6 +342,7 @@ config = mlc.ConfigDict(
                "subsample_templates": False,  # We want top templates.
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 128,
                "max_extra_msa": 1024,
                "max_template_hits": 4,
                "max_templates": 4,
                "crop": False,
@@ -258,6 +355,7 @@ config = mlc.ConfigDict(
                "subsample_templates": True,
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 128,
                "max_extra_msa": 1024,
                "max_template_hits": 4,
                "max_templates": 4,
                "shuffle_top_k_prefiltered": 20,
@@ -274,6 +372,7 @@ config = mlc.ConfigDict(
                "data_loaders": {
                    "batch_size": 1,
                    "num_workers": 16,
                    "pin_memory": True,
                },
            },
        },
@@ -281,6 +380,13 @@ config = mlc.ConfigDict(
        "globals": {
            "blocks_per_ckpt": blocks_per_ckpt,
            "chunk_size": chunk_size,
            # Use Staats & Rabe's low-memory attention algorithm. Mutually
            # exclusive with use_flash.
            "use_lma": False,
            # Use FlashAttention in selected modules. Mutually exclusive with
            # use_lma. Doesn't work that well on long sequences (>1000 residues).
            "use_flash": False,
            "offload_inference": False,
            "c_z": c_z,
            "c_m": c_m,
            "c_t": c_t,
@@ -333,6 +439,7 @@ config = mlc.ConfigDict(
                    "dropout_rate": 0.25,
                    "tri_mul_first": False,
                    "blocks_per_ckpt": blocks_per_ckpt,
                    "tune_chunk_size": tune_chunk_size,
                    "inf": 1e9,
                },
                "template_pointwise_attention": {
@@ -349,6 +456,17 @@ config = mlc.ConfigDict(
                "enabled": templates_enabled,
                "embed_angles": embed_template_torsion_angles,
                "use_unit_vector": False,
                # Approximate template computation, saving memory.
                # In our experiments, results are equivalent to or better than
                # the stock implementation. Should be enabled for all new
                # training runs.
                "average_templates": False,
                # Offload template embeddings to CPU memory. Vastly reduced
                # memory consumption at the cost of a modest increase in
                # runtime. Useful for inference on very long sequences.
                # Mutually exclusive with average_templates. Automatically
                # enabled if offload_inference is set.
                "offload_templates": False,
            },
            "extra_msa": {
                "extra_msa_embedder": {
@@ -369,7 +487,8 @@ config = mlc.ConfigDict(
                    "msa_dropout": 0.15,
                    "pair_dropout": 0.25,
                    "opm_first": False,
-                    "clear_cache_between_blocks": True,
+                    "clear_cache_between_blocks": False,
                    "tune_chunk_size": tune_chunk_size,
                    "inf": 1e9,
                    "eps": eps,  # 1e-10,
                    "ckpt": blocks_per_ckpt is not None,
@@ -393,6 +512,7 @@ config = mlc.ConfigDict(
                "opm_first": False,
                "blocks_per_ckpt": blocks_per_ckpt,
                "clear_cache_between_blocks": False,
                "tune_chunk_size": tune_chunk_size,
                "inf": 1e9,
                "eps": eps,  # 1e-10,
            },
@@ -473,7 +593,7 @@ config = mlc.ConfigDict(
                "eps": 1e-4,
                "weight": 1.0,
            },
-            "lddt": {
+            "plddt_loss": {
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "cutoff": 15.0,
@@ -482,6 +602,7 @@ config = mlc.ConfigDict(
                "weight": 0.01,
            },
            "masked_msa": {
                "num_classes": 23,
                "eps": eps,  # 1e-8,
                "weight": 2.0,
            },
@@ -503,7 +624,7 @@ config = mlc.ConfigDict(
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "eps": eps,  # 1e-8,
-                "weight": 0.0,
+                "weight": 0.,
                "enabled": tm_enabled,
            },
            "eps": eps,
@@ -607,6 +728,23 @@ multimer_model_config_update = {
        "inf": 1e9,
        "eps": eps,  # 1e-10,
    },
    "structure_module": {
        "c_s": c_s,
        "c_z": c_z,
        "c_ipa": 16,
        "c_resnet": 128,
        "no_heads_ipa": 12,
        "no_qk_points": 4,
        "no_v_points": 8,
        "dropout_rate": 0.1,
        "no_blocks": 8,
        "no_transition_layers": 1,
        "no_resnet_blocks": 2,
        "no_angles": 7,
        "trans_scale_factor": 20,
        "epsilon": eps,  # 1e-12,
        "inf": 1e5,
    },
    "heads": {
        "lddt": {
            "no_bins": 50,
...
@@ -28,16 +28,18 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
        template_mmcif_dir: str,
        max_template_date: str,
        config: mlc.ConfigDict,
        chain_data_cache_path: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        max_template_hits: int = 4,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        shuffle_top_k_prefiltered: Optional[int] = None,
        treat_pdb_as_distillation: bool = True,
-        mapping_path: Optional[str] = None,
+        filter_path: Optional[str] = None,
        mode: str = "train",
        alignment_index: Optional[Any] = None,
        _output_raw: bool = False,
-        _alignment_index: Optional[Any] = None
+        _structure_index: Optional[Any] = None,
    ):
        """
            Args:
@@ -55,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                    Path to a directory containing template mmCIF files.
                config:
                    A dataset config object. See openfold.config
                chain_data_cache_path:
                    Path to cache of data_dir generated by
                    scripts/generate_chain_data_cache.py
                kalign_binary_path:
                    Path to kalign binary.
                max_template_hits:
@@ -79,12 +84,22 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
        """
        super(OpenFoldSingleDataset, self).__init__()
        self.data_dir = data_dir

        self.chain_data_cache = None
        if chain_data_cache_path is not None:
            with open(chain_data_cache_path, "r") as fp:
                self.chain_data_cache = json.load(fp)
            assert isinstance(self.chain_data_cache, dict)

        self.alignment_dir = alignment_dir
        self.config = config
        self.treat_pdb_as_distillation = treat_pdb_as_distillation
        self.mode = mode
        self.alignment_index = alignment_index
        self._output_raw = _output_raw
-        self._alignment_index = _alignment_index
+        self._structure_index = _structure_index

        self.supported_exts = [".cif", ".core", ".pdb"]

        valid_modes = ["train", "eval", "predict"]
        if(mode not in valid_modes):
@@ -96,13 +111,41 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                "scripts/generate_mmcif_cache.py before running OpenFold"
            )

-        if(_alignment_index is not None):
-            self._chain_ids = list(_alignment_index.keys())
-        elif(mapping_path is None):
-            self._chain_ids = list(os.listdir(alignment_dir))
+        if(alignment_index is not None):
+            self._chain_ids = list(alignment_index.keys())
        else:
-            with open(mapping_path, "r") as f:
-                self._chain_ids = [l.strip() for l in f.readlines()]
+            self._chain_ids = list(os.listdir(alignment_dir))

        if(filter_path is not None):
            with open(filter_path, "r") as f:
                chains_to_include = set([l.strip() for l in f.readlines()])

            self._chain_ids = [
                c for c in self._chain_ids if c in chains_to_include
            ]

        if self.chain_data_cache is not None:
            # Filter to include only chains where we have structure data
            # (entries in chain_data_cache)
            original_chain_ids = self._chain_ids
            self._chain_ids = [
                c for c in self._chain_ids if c in self.chain_data_cache
            ]
            if len(self._chain_ids) < len(original_chain_ids):
                missing = [
                    c for c in original_chain_ids
                    if c not in self.chain_data_cache
                ]
                max_to_print = 10
                missing_examples = ", ".join(missing[:max_to_print])
                if len(missing) > max_to_print:
                    missing_examples += ", ..."
                logging.warning(
                    "Removing %d alignment entries (%s) with no corresponding "
                    "entries in chain_data_cache (%s).",
                    len(missing),
                    missing_examples,
                    chain_data_cache_path)

        self._chain_id_to_idx_dict = {
            chain: i for i, chain in enumerate(self._chain_ids)
@@ -125,7 +168,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
        if(not self._output_raw):
            self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

-    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
+    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
        with open(path, 'r') as f:
            mmcif_string = f.read()
@@ -144,7 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
            mmcif=mmcif_object,
            alignment_dir=alignment_dir,
            chain_id=chain_id,
-            _alignment_index=_alignment_index
+            alignment_index=alignment_index
        )

        return data
@@ -159,10 +202,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
        name = self.idx_to_chain_id(idx)
        alignment_dir = os.path.join(self.alignment_dir, name)

-        _alignment_index = None
-        if(self._alignment_index is not None):
+        alignment_index = None
+        if(self.alignment_index is not None):
            alignment_dir = self.alignment_dir
-            _alignment_index = self._alignment_index[name]
+            alignment_index = self.alignment_index[name]

        if(self.mode == 'train' or self.mode == 'eval'):
            spl = name.rsplit('_', 1)
@@ -173,30 +216,51 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                chain_id = None

            path = os.path.join(self.data_dir, file_id)

            structure_index_entry = None
            if(self._structure_index is not None):
                structure_index_entry = self._structure_index[name]
                assert(len(structure_index_entry["files"]) == 1)
                filename, _, _ = structure_index_entry["files"][0]
                ext = os.path.splitext(filename)[1]
            else:
                ext = None
                for e in self.supported_exts:
                    if(os.path.exists(path + e)):
                        ext = e
                        break

                if(ext is None):
                    raise ValueError("Invalid file type")

            path += ext
-            if(os.path.exists(path + ".cif")):
+            if(ext == ".cif"):
                data = self._parse_mmcif(
-                    path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
+                    path, file_id, chain_id, alignment_dir, alignment_index,
                )
-            elif(os.path.exists(path + ".core")):
+            elif(ext == ".core"):
                data = self.data_pipeline.process_core(
-                    path + ".core", alignment_dir, _alignment_index,
+                    path, alignment_dir, alignment_index,
                )
-            elif(os.path.exists(path + ".pdb")):
+            elif(ext == ".pdb"):
                structure_index = None
                if(self._structure_index is not None):
                    structure_index = self._structure_index[name]
                data = self.data_pipeline.process_pdb(
-                    pdb_path=path + ".pdb",
+                    pdb_path=path,
                    alignment_dir=alignment_dir,
                    is_distillation=self.treat_pdb_as_distillation,
                    chain_id=chain_id,
-                    _alignment_index=_alignment_index,
+                    alignment_index=alignment_index,
                    _structure_index=structure_index,
                )
            else:
-                raise ValueError("Invalid file type")
+                raise ValueError("Extension branch missing")
        else:
            path = os.path.join(name, name + ".fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
-                _alignment_index=_alignment_index,
+                alignment_index=alignment_index,
            )

        if(self._output_raw):
@@ -206,6 +270,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
            data, self.mode
        )

        feats["batch_idx"] = torch.tensor(
            [idx for _ in range(feats["aatype"].shape[-1])],
            dtype=torch.int64,
            device=feats["aatype"].device)

        return feats

    def __len__(self):
@@ -265,9 +334,8 @@ class OpenFoldDataset(torch.utils.data.Dataset):
    """
    def __init__(self,
        datasets: Sequence[OpenFoldSingleDataset],
-        probabilities: Sequence[int],
+        probabilities: Sequence[float],
        epoch_len: int,
-        chain_data_cache_paths: List[str],
        generator: torch.Generator = None,
        _roll_at_init: bool = True,
    ):
@@ -276,11 +344,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
        self.epoch_len = epoch_len
        self.generator = generator

-        self.chain_data_caches = []
-        for path in chain_data_cache_paths:
-            with open(path, "r") as fp:
-                self.chain_data_caches.append(json.load(fp))

        def looped_shuffled_dataset_idx(dataset_len):
            while True:
                # Uniformly shuffle each dataset's indices
@@ -298,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
            max_cache_len = int(epoch_len * probabilities[dataset_idx])
            dataset = self.datasets[dataset_idx]
            idx_iter = looped_shuffled_dataset_idx(len(dataset))
-            chain_data_cache = self.chain_data_caches[dataset_idx]
+            chain_data_cache = dataset.chain_data_cache

            while True:
                weights = []
                idx = []
@@ -355,20 +418,9 @@ class OpenFoldDataset(torch.utils.data.Dataset):

class OpenFoldBatchCollator:
-    def __init__(self, config, stage="train"):
-        self.stage = stage
-        self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
-
-    def __call__(self, raw_prots):
-        processed_prots = []
-        for prot in raw_prots:
-            features = self.feature_pipeline.process_features(
-                prot, self.stage
-            )
-            processed_prots.append(features)
-
+    def __call__(self, prots):
        stack_fn = partial(torch.stack, dim=0)
-        return dict_multimap(stack_fn, processed_prots)
+        return dict_multimap(stack_fn, prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader):
@@ -388,11 +440,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
        stage_cfg = self.config[self.stage]
        max_iters = self.config.common.max_recycling_iters

-        if(stage_cfg.supervised):
-            clamp_prob = self.config.supervised.clamp_prob
-            keyed_probs.append(
-                ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
-            )

        if(stage_cfg.uniform_recycling):
            recycling_probs = [
@@ -480,13 +527,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
        predict_data_dir: Optional[str] = None,
        predict_alignment_dir: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
-        train_mapping_path: Optional[str] = None,
-        distillation_mapping_path: Optional[str] = None,
+        train_filter_path: Optional[str] = None,
+        distillation_filter_path: Optional[str] = None,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        batch_seed: Optional[int] = None,
        train_epoch_len: int = 50000,
-        _alignment_index_path: Optional[str] = None,
+        _distillation_structure_index_path: Optional[str] = None,
        alignment_index_path: Optional[str] = None,
        distillation_alignment_index_path: Optional[str] = None,
        **kwargs
    ):
        super(OpenFoldDataModule, self).__init__()
@@ -507,8 +556,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_path = kalign_binary_path
-        self.train_mapping_path = train_mapping_path
-        self.distillation_mapping_path = distillation_mapping_path
+        self.train_filter_path = train_filter_path
+        self.distillation_filter_path = distillation_filter_path
        self.template_release_dates_cache_path = (
            template_release_dates_cache_path
        )
@@ -539,10 +588,20 @@ class OpenFoldDataModule(pl.LightningDataModule):
            )

        # An ad-hoc measure for our particular filesystem restrictions
-        self._alignment_index = None
-        if(_alignment_index_path is not None):
-            with open(_alignment_index_path, "r") as fp:
-                self._alignment_index = json.load(fp)
+        self._distillation_structure_index = None
+        if(_distillation_structure_index_path is not None):
+            with open(_distillation_structure_index_path, "r") as fp:
+                self._distillation_structure_index = json.load(fp)

        self.alignment_index = None
        if(alignment_index_path is not None):
            with open(alignment_index_path, "r") as fp:
                self.alignment_index = json.load(fp)

        self.distillation_alignment_index = None
        if(distillation_alignment_index_path is not None):
            with open(distillation_alignment_index_path, "r") as fp:
                self.distillation_alignment_index = json.load(fp)

    def setup(self):
        # Most of the arguments are the same for the three datasets
@@ -560,27 +619,29 @@ class OpenFoldDataModule(pl.LightningDataModule):
        if(self.training_mode):
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                chain_data_cache_path=self.train_chain_data_cache_path,
                alignment_dir=self.train_alignment_dir,
-                mapping_path=self.train_mapping_path,
+                filter_path=self.train_filter_path,
                max_template_hits=self.config.train.max_template_hits,
                shuffle_top_k_prefiltered=
                    self.config.train.shuffle_top_k_prefiltered,
                treat_pdb_as_distillation=False,
                mode="train",
-                _output_raw=True,
-                _alignment_index=self._alignment_index,
+                alignment_index=self.alignment_index,
            )

            distillation_dataset = None
            if(self.distillation_data_dir is not None):
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    chain_data_cache_path=self.distillation_chain_data_cache_path,
                    alignment_dir=self.distillation_alignment_dir,
-                    mapping_path=self.distillation_mapping_path,
+                    filter_path=self.distillation_filter_path,
-                    max_template_hits=self.train.max_template_hits,
+                    max_template_hits=self.config.train.max_template_hits,
                    treat_pdb_as_distillation=True,
                    mode="train",
-                    _output_raw=True,
+                    alignment_index=self.distillation_alignment_index,
                    _structure_index=self._distillation_structure_index,
                )

                d_prob = self.config.train.distillation_prob
@@ -588,23 +649,21 @@ class OpenFoldDataModule(pl.LightningDataModule):
            if(distillation_dataset is not None):
                datasets = [train_dataset, distillation_dataset]
                d_prob = self.config.train.distillation_prob
-                probabilities = [1 - d_prob, d_prob]
+                probabilities = [1. - d_prob, d_prob]
-                chain_data_cache_paths = [
-                    self.train_chain_data_cache_path,
-                    self.distillation_chain_data_cache_path,
-                ]
            else:
                datasets = [train_dataset]
                probabilities = [1.]
-                chain_data_cache_paths = [
-                    self.train_chain_data_cache_path,
-                ]

            generator = None
            if(self.batch_seed is not None):
                generator = torch.Generator()
                generator = generator.manual_seed(self.batch_seed + 1)

            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
-                chain_data_cache_paths=chain_data_cache_paths,
+                generator=generator,
                _roll_at_init=False,
            )
@@ -612,10 +671,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
            self.eval_dataset = dataset_gen(
                data_dir=self.val_data_dir,
                alignment_dir=self.val_alignment_dir,
-                mapping_path=None,
+                filter_path=None,
                max_template_hits=self.config.eval.max_template_hits,
                mode="eval",
-                _output_raw=True,
            )
        else:
            self.eval_dataset = None
@@ -623,7 +681,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
-                mapping_path=None,
+                filter_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )
@@ -636,7 +694,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
        dataset = None
        if(stage == "train"):
            dataset = self.train_dataset
-            # Filter the dataset, if necessary
            dataset.reroll()
        elif(stage == "eval"):
@@ -646,7 +703,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
        else:
            raise ValueError("Invalid stage")

-        batch_collator = OpenFoldBatchCollator(self.config, stage)
+        batch_collator = OpenFoldBatchCollator()

        dl = OpenFoldDataLoader(
            dataset,
...
@@ -14,26 +14,17 @@
# limitations under the License.

import os
-import copy
import collections
import contextlib
import dataclasses
-import datetime
-import json
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union

import numpy as np

-from openfold.data import (
-    templates,
-    parsers,
-    mmcif_parsing,
-    msa_identifiers,
-    msa_pairing,
-    feature_processing_multimer,
-)
-from openfold.data.parsers import Msa
+from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
+from openfold.data.templates import get_custom_template_features
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein
@@ -78,6 +69,51 @@ def make_template_features(
    return template_features


def unify_template_features(
    template_feature_list: Sequence[FeatureDict]
) -> FeatureDict:
    out_dicts = []
    seq_lens = [fd["template_aatype"].shape[1] for fd in template_feature_list]
    for i, fd in enumerate(template_feature_list):
        out_dict = {}
        n_templates, n_res = fd["template_aatype"].shape[:2]
        for k, v in fd.items():
            seq_keys = [
                "template_aatype",
                "template_all_atom_positions",
                "template_all_atom_mask",
            ]
            if(k in seq_keys):
                new_shape = list(v.shape)
                assert(new_shape[1] == n_res)
                new_shape[1] = sum(seq_lens)
                new_array = np.zeros(new_shape, dtype=v.dtype)

                if(k == "template_aatype"):
                    new_array[..., residue_constants.HHBLITS_AA_TO_ID['-']] = 1

                offset = sum(seq_lens[:i])
                new_array[:, offset:offset + seq_lens[i]] = v
                out_dict[k] = new_array
            else:
                out_dict[k] = v

        chain_indices = np.array(n_templates * [i])
        out_dict["template_chain_index"] = chain_indices

        if(n_templates != 0):
            out_dicts.append(out_dict)

    if(len(out_dicts) > 0):
        out_dict = {
            k: np.concatenate([od[k] for od in out_dicts]) for k in out_dicts[0]
        }
    else:
        out_dict = empty_template_feats(sum(seq_lens))

    return out_dict
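A toy illustration of the padding scheme above (illustrative only, using nothing beyond numpy): per-chain template features are zero-padded into disjoint residue blocks before concatenation, so each chain's templates occupy their own columns.

import numpy as np

# Two chains of lengths 3 and 2, one template each; one scalar per residue.
seq_lens = [3, 2]
chain_feats = [np.full((1, 3), 7.0), np.full((1, 2), 9.0)]

padded = []
for i, v in enumerate(chain_feats):
    new_array = np.zeros((v.shape[0], sum(seq_lens)), dtype=v.dtype)
    offset = sum(seq_lens[:i])
    new_array[:, offset:offset + seq_lens[i]] = v
    padded.append(new_array)

unified = np.concatenate(padded)  # shape (2, 5), block-diagonal layout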
def make_sequence_features(
    sequence: str, description: str, num_res: int
) -> FeatureDict:
@@ -249,6 +285,41 @@ def run_msa_tool(
    return result


def make_sequence_features_with_custom_template(
    sequence: str,
    mmcif_path: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str) -> FeatureDict:
    """
    Process a single FASTA file using features derived from a single
    template rather than an alignment.
    """
    num_res = len(sequence)

    sequence_features = make_sequence_features(
        sequence=sequence,
        description=pdb_id,
        num_res=num_res,
    )

    msa_data = [sequence]
    deletion_matrix = [[0 for _ in sequence]]
    msa_data_obj = parsers.Msa(sequences=msa_data, deletion_matrix=deletion_matrix, descriptions=None)

    msa_features = make_msa_features([msa_data_obj])
    template_features = get_custom_template_features(
        mmcif_path=mmcif_path,
        query_sequence=sequence,
        pdb_id=pdb_id,
        chain_id=chain_id,
        kalign_binary_path=kalign_binary_path
    )

    return {
        **sequence_features,
        **msa_features,
        **template_features.features
    }
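A hypothetical invocation of the function above (the sequence, path, and IDs are placeholders, not values from this commit):

feats = make_sequence_features_with_custom_template(
    sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
    mmcif_path="custom_templates/1abc.cif",
    pdb_id="1abc",
    chain_id="A",
    kalign_binary_path="/usr/bin/kalign",
)
# feats combines sequence features, a single-sequence "MSA", and template
# features taken directly from the supplied mmCIF structure.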
class AlignmentRunner: class AlignmentRunner:
"""Runs alignment tools and saves the results""" """Runs alignment tools and saves the results"""
...@@ -617,32 +688,30 @@ class DataPipeline: ...@@ -617,32 +688,30 @@ class DataPipeline:
def _parse_msa_data( def _parse_msa_data(
self, self,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[Any] = None, alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msas = {} msa_data = {}
if(_alignment_index is not None): if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb") fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size): def read_msa(start, size):
fp.seek(start) fp.seek(start)
msa = fp.read(size).decode("utf-8") msa = fp.read(size).decode("utf-8")
return msa return msa
for (name, start, size) in _alignment_index["files"]: for (name, start, size) in alignment_index["files"]:
filename, ext = os.path.splitext(name) filename, ext = os.path.splitext(name)
if(ext == ".a3m"): if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m( msa = parsers.parse_a3m(
read_msa(start, size) read_msa(start, size)
) )
data = {"msa": msa, "deletion_matrix": deletion_matrix} data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
# The "hmm_output" exception is a crude way to exclude # The "hmm_output" exception is a crude way to exclude
# multimer template hits. # multimer template hits.
elif(ext == ".sto" and not "hmm_output" == filename): elif(ext == ".sto" and not "hmm_output" == filename):
msa, deletion_matrix, _ = parsers.parse_stockholm( msa = parsers.parse_stockholm(read_msa(start, size))
read_msa(start, size) data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else: else:
continue continue
...@@ -657,33 +726,35 @@ class DataPipeline: ...@@ -657,33 +726,35 @@ class DataPipeline:
if(ext == ".a3m"): if(ext == ".a3m"):
with open(path, "r") as fp: with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read()) msa = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
elif(ext == ".sto" and not "hmm_output" == filename): elif(ext == ".sto" and not "hmm_output" == filename):
with open(path, "r") as fp: with open(path, "r") as fp:
msa = parsers.parse_stockholm( msa = parsers.parse_stockholm(
fp.read() fp.read()
) )
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
else: else:
continue continue
msas[f] = msa msa_data[f] = data
return msas return msa_data
def _parse_template_hit_files( def _parse_template_hit_files(
self, self,
alignment_dir: str, alignment_dir: str,
input_sequence: str, input_sequence: str,
_alignment_index: Optional[Any] = None alignment_index: Optional[Any] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
all_hits = {} all_hits = {}
if(_alignment_index is not None): if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb') fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size): def read_template(start, size):
fp.seek(start) fp.seek(start)
return fp.read(size).decode("utf-8") return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]: for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1] ext = os.path.splitext(name)[-1]
if(ext == ".hhr"): if(ext == ".hhr"):
...@@ -716,15 +787,46 @@ class DataPipeline: ...@@ -716,15 +787,46 @@ class DataPipeline:
return all_hits return all_hits
def _process_msa_feats( def _parse_template_hits(
self, self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, alignment_index: Optional[Any] = None
_alignment_index: Optional[str] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msas = self._parse_msa_data(alignment_dir, _alignment_index) all_hits = {}
if (alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if (ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(len(msas) == 0): if (ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
return
def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None,
):
msa_data = self._parse_msa_data(alignment_dir, alignment_index)
if(len(msa_data) == 0):
if(input_sequence is None): if(input_sequence is None):
raise ValueError( raise ValueError(
""" """
...@@ -732,13 +834,31 @@ class DataPipeline: ...@@ -732,13 +834,31 @@ class DataPipeline:
must be provided. must be provided.
""" """
) )
msa_data["dummy"] = Msa(
[input_sequence],
[[0 for _ in input_sequence]],
["dummy"]
)
msa_features = make_msa_features(list(msas.values())) deletion_matrix = [[0 for _ in input_sequence]]
msa_data["dummy"] = {
"msa": parsers.Msa(sequences=input_sequence, deletion_matrix=deletion_matrix, descriptions=None),
"deletion_matrix": deletion_matrix,
}
msas, deletion_matrices = zip(*[
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
return msas, deletion_matrices
def _process_msa_feats(
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msas, deletion_matrices = self._get_msas(
alignment_dir, input_sequence, alignment_index
)
msa_features = make_msa_features(
msas=msas
)
return msa_features return msa_features
...@@ -746,7 +866,7 @@ class DataPipeline: ...@@ -746,7 +866,7 @@ class DataPipeline:
self, self,
fasta_path: str, fasta_path: str,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file""" """Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f: with open(fasta_path) as f:
...@@ -763,7 +883,7 @@ class DataPipeline: ...@@ -763,7 +883,7 @@ class DataPipeline:
hits = self._parse_template_hit_files( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir,
input_sequence, input_sequence,
_alignment_index, alignment_index,
) )
template_features = make_template_features( template_features = make_template_features(
...@@ -778,7 +898,7 @@ class DataPipeline: ...@@ -778,7 +898,7 @@ class DataPipeline:
num_res=num_res, num_res=num_res,
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return { return {
**sequence_features, **sequence_features,
...@@ -791,7 +911,7 @@ class DataPipeline: ...@@ -791,7 +911,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str, alignment_dir: str,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a specific chain in an mmCIF object. Assembles features for a specific chain in an mmCIF object.
...@@ -812,7 +932,8 @@ class DataPipeline: ...@@ -812,7 +932,8 @@ class DataPipeline:
hits = self._parse_template_hits( hits = self._parse_template_hits(
alignment_dir, alignment_dir,
input_sequence, input_sequence,
_alignment_index) alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -820,7 +941,7 @@ class DataPipeline: ...@@ -820,7 +941,7 @@ class DataPipeline:
query_release_date=to_date(mmcif.header["release_date"]) query_release_date=to_date(mmcif.header["release_date"])
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**mmcif_feats, **template_features, **msa_features} return {**mmcif_feats, **template_features, **msa_features}
...@@ -831,7 +952,7 @@ class DataPipeline: ...@@ -831,7 +952,7 @@ class DataPipeline:
is_distillation: bool = True, is_distillation: bool = True,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_structure_index: Optional[str] = None, _structure_index: Optional[str] = None,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a protein in a PDB file. Assembles features for a protein in a PDB file.
...@@ -861,15 +982,16 @@ class DataPipeline: ...@@ -861,15 +982,16 @@ class DataPipeline:
hits = self._parse_template_hits( hits = self._parse_template_hits(
alignment_dir, alignment_dir,
input_sequence, input_sequence,
_alignment_index alignment_index
) )
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
self.template_featurizer, self.template_featurizer,
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**pdb_feats, **template_features, **msa_features} return {**pdb_feats, **template_features, **msa_features}
...@@ -877,7 +999,7 @@ class DataPipeline: ...@@ -877,7 +999,7 @@ class DataPipeline:
self, self,
core_path: str, core_path: str,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a protein in a ProteinNet .core file. Assembles features for a protein in a ProteinNet .core file.
...@@ -892,9 +1014,9 @@ class DataPipeline: ...@@ -892,9 +1014,9 @@ class DataPipeline:
hits = self._parse_template_hits( hits = self._parse_template_hits(
alignment_dir, alignment_dir,
input_sequence, alignment_index
_alignment_index
) )
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -905,6 +1027,98 @@ class DataPipeline: ...@@ -905,6 +1027,98 @@ class DataPipeline:
return {**core_feats, **template_features, **msa_features} return {**core_feats, **template_features, **msa_features}
    def process_multiseq_fasta(self,
        fasta_path: str,
        super_alignment_dir: str,
        ri_gap: int = 200,
    ) -> FeatureDict:
        """
        Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
        hack from Twitter (a.k.a. AlphaFold-Gap).
        """
        with open(fasta_path, 'r') as f:
            fasta_str = f.read()

        input_seqs, input_descs = parsers.parse_fasta(fasta_str)

        # No whitespace allowed
        input_descs = [i.split()[0] for i in input_descs]

        # Stitch all of the sequences together
        input_sequence = ''.join(input_seqs)
        input_description = '-'.join(input_descs)
        num_res = len(input_sequence)

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res,
        )

        # Shift the residue index at each chain boundary so that chains look
        # distant to the relative positional encoding (see sketch below)
        seq_lens = [len(s) for s in input_seqs]
        total_offset = 0
        for sl in seq_lens:
            total_offset += sl
            sequence_features["residue_index"][total_offset:] += ri_gap

        msa_list = []
        deletion_mat_list = []
        for seq, desc in zip(input_seqs, input_descs):
            alignment_dir = os.path.join(
                super_alignment_dir, desc
            )
            msas, deletion_mats = self._get_msas(
                alignment_dir, seq, None
            )
            msa_list.append(msas)
            deletion_mat_list.append(deletion_mats)

        # Pad each per-chain MSA with gaps so every row spans the stitched sequence
        final_msa = []
        final_deletion_mat = []
        final_msa_obj = []
        msa_it = enumerate(zip(msa_list, deletion_mat_list))
        for i, (msas, deletion_mats) in msa_it:
            prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
            msas = [
                [prec * '-' + seq + post * '-' for seq in msa] for msa in msas
            ]
            deletion_mats = [
                [prec * [0] + dml + post * [0] for dml in deletion_mat]
                for deletion_mat in deletion_mats
            ]

            assert len(msas[0][-1]) == len(input_sequence)

            final_msa.extend(msas)
            final_deletion_mat.extend(deletion_mats)
            final_msa_obj.extend([
                parsers.Msa(
                    sequences=msas[k],
                    deletion_matrix=deletion_mats[k],
                    descriptions=None,
                )
                for k in range(len(msas))
            ])

        msa_features = make_msa_features(
            msas=final_msa_obj
        )

        template_feature_list = []
        for seq, desc in zip(input_seqs, input_descs):
            alignment_dir = os.path.join(
                super_alignment_dir, desc
            )
            hits = self._parse_template_hits(alignment_dir, alignment_index=None)
            template_features = make_template_features(
                seq,
                hits,
                self.template_featurizer,
            )
            template_feature_list.append(template_features)

        template_features = unify_template_features(template_feature_list)

        return {
            **sequence_features,
            **msa_features,
            **template_features,
        }
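A minimal sketch of the residue-index trick above, with hypothetical two-chain input. The monomer relative positional encoding saturates at `|i - j| > relpos_k` (32 by default), so a gap of 200 makes every inter-chain residue pair look maximally distant:

```python
# Toy illustration of the AlphaFold-Gap offset (hypothetical sequences).
import numpy as np

input_seqs = ["MKTAYIAK", "GSHMKLVI"]   # two hypothetical chains
ri_gap = 200

residue_index = np.arange(sum(len(s) for s in input_seqs))
total_offset = 0
for sl in [len(s) for s in input_seqs]:
    total_offset += sl
    residue_index[total_offset:] += ri_gap

# Chain A keeps indices 0..7; chain B starts at 208, so |i - j| > 32
# for every inter-chain pair and relpos treats the chains as unlinked.
print(residue_index)
```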
class DataPipelineMultimer:
    """Runs the alignment tools and assembles the input features."""
@@ -913,7 +1127,6 @@ class DataPipelineMultimer:
        monomer_data_pipeline: DataPipeline,
    ):
        """Initializes the data pipeline.
        Args:
            monomer_data_pipeline: An instance of pipeline.DataPipeline that runs
                the data pipeline for the monomer AlphaFold system.
@@ -955,10 +1168,12 @@ class DataPipelineMultimer:
    def _all_seq_msa_features(self, fasta_path, alignment_dir):
        """Get MSA features for unclustered UniProt hits, for pairing."""
        # TODO: Quick fix; switch back to .sto once Stockholm parsing is fixed
        uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.a3m")
        with open(uniprot_msa_path, "r") as fp:
            uniprot_msa_string = fp.read()
        msa = parsers.parse_a3m(uniprot_msa_string)
        all_seq_features = make_msa_features([msa])
        valid_feats = msa_pairing.MSA_FEATURES + (
            'msa_species_identifiers',
......
@@ -23,6 +23,9 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
@@ -736,6 +739,7 @@ def make_atom14_positions(protein):
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix

    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )
@@ -781,10 +785,14 @@ def make_atom14_positions(protein):

def atom37_to_frames(protein, eps=1e-8):
    is_multimer = "asym_id" in protein

    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    if is_multimer:
        all_atom_positions = Vec3Array.from_array(all_atom_positions)

    batch_dims = len(aatype.shape[:-1])
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
@@ -831,6 +839,15 @@ def atom37_to_frames(protein, eps=1e-8):
        no_batch_dims=batch_dims,
    )

    if is_multimer:
        # Gather the x, y, z components separately, then rebuild the Vec3Array
        base_atom_pos = [batched_gather(
            pos,
            residx_rigidgroup_base_atom37_idx,
            dim=-1,
            no_batch_dims=len(all_atom_positions.shape[:-1]),
        ) for pos in all_atom_positions]
        base_atom_pos = Vec3Array.from_array(torch.stack(base_atom_pos, dim=-1))
    else:
        base_atom_pos = batched_gather(
            all_atom_positions,
            residx_rigidgroup_base_atom37_idx,
@@ -838,6 +855,15 @@ def atom37_to_frames(protein, eps=1e-8):
            no_batch_dims=len(all_atom_positions.shape[:-2]),
        )

    if is_multimer:
        point_on_neg_x_axis = base_atom_pos[:, :, 0]
        origin = base_atom_pos[:, :, 1]
        point_on_xy_plane = base_atom_pos[:, :, 2]
        gt_rotation = Rot3Array.from_two_vectors(
            origin - point_on_neg_x_axis, point_on_xy_plane - origin
        )
        gt_frames = Rigid3Array(gt_rotation, origin)
    else:
        gt_frames = Rigid.from_3_points(
            p_neg_x_axis=base_atom_pos[..., 0, :],
            origin=base_atom_pos[..., 1, :],
@@ -864,8 +890,12 @@ def atom37_to_frames(protein, eps=1e-8):
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1

    if is_multimer:
        gt_frames = gt_frames.compose_rotation(
            Rot3Array.from_array(rots)
        )
    else:
        rots = Rotation(rot_mats=rots)
        gt_frames = gt_frames.compose(Rigid(rots, None))

    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
@@ -900,6 +930,12 @@ def atom37_to_frames(protein, eps=1e-8):
        no_batch_dims=batch_dims,
    )

    if is_multimer:
        ambiguity_rot = Rot3Array.from_array(residx_rigidgroup_ambiguity_rot)
        # Create the alternative ground-truth frames
        alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
    else:
        residx_rigidgroup_ambiguity_rot = Rotation(
            rot_mats=residx_rigidgroup_ambiguity_rot
        )
......
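For reference, both branches of `atom37_to_frames` above derive the same rigid frame from three atom positions. A standalone sketch of the underlying Gram-Schmidt construction (toy tensors, not the OpenFold `Rigid`/`Rigid3Array` API):

```python
import torch

def frame_from_3_points(p_neg_x_axis, origin, p_xy_plane):
    # e0: unit vector from the neg-x-axis atom toward the origin atom
    e0 = origin - p_neg_x_axis
    e0 = e0 / torch.linalg.norm(e0)
    # e1: component of (p_xy_plane - origin) orthogonal to e0
    e1 = p_xy_plane - origin
    e1 = e1 - torch.dot(e1, e0) * e0
    e1 = e1 / torch.linalg.norm(e1)
    # e2 completes a right-handed orthonormal basis
    e2 = torch.linalg.cross(e0, e1)
    rotation = torch.stack([e0, e1, e2], dim=-1)  # columns are the axes
    return rotation, origin

# e.g. N, CA, C positions of one residue (arbitrary coordinates)
rot, trans = frame_from_3_points(
    torch.tensor([1.0, 0.0, 0.0]),
    torch.tensor([2.4, 0.2, 0.0]),
    torch.tensor([3.1, 1.3, 0.2]),
)
```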
@@ -103,6 +103,21 @@ def np_example_to_features(
        cfg[mode],
    )

    if mode == "train":
        p = torch.rand(1).item()
        use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
        features["use_clamped_fape"] = torch.full(
            size=[cfg.common.max_recycling_iters + 1],
            fill_value=use_clamped_fape_value,
            dtype=torch.float32,
        )
    else:
        features["use_clamped_fape"] = torch.full(
            size=[cfg.common.max_recycling_iters + 1],
            fill_value=0.0,
            dtype=torch.float32,
        )

    return {k: v for k, v in features.items()}
......
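To make the schedule above concrete, here is the same sampling logic in isolation (`clamp_prob` and the iteration count are hypothetical stand-ins for the config values):

```python
import torch

clamp_prob = 0.9            # stand-in for cfg.supervised.clamp_prob
max_recycling_iters = 3     # stand-in for cfg.common.max_recycling_iters

# One draw per training example; the flag is shared by all recycling iterations
use_clamped = float(torch.rand(1).item() < clamp_prob)
use_clamped_fape = torch.full(
    [max_recycling_iters + 1], use_clamped, dtype=torch.float32
)
# At eval time the flag is all zeros, i.e. unclamped FAPE throughout
```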
@@ -84,7 +84,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    pad_msa_clusters = mode_cfg.max_msa_clusters
    max_msa_clusters = pad_msa_clusters
    max_extra_msa = mode_cfg.max_extra_msa

    msa_seed = None
    if(not common_cfg.resample_msa_in_recycling):
@@ -137,7 +137,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
        data_transforms.make_fixed_size(
            crop_feats,
            pad_msa_clusters,
            mode_cfg.max_extra_msa,
            mode_cfg.crop_size,
            mode_cfg.max_templates,
        )
......
@@ -46,7 +46,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    pad_msa_clusters = mode_cfg.max_msa_clusters
    max_msa_clusters = pad_msa_clusters
    max_extra_msa = mode_cfg.max_extra_msa

    msa_seed = None
    if(not common_cfg.resample_msa_in_recycling):
......
@@ -434,7 +434,7 @@ def _is_set(data: str) -> bool:

def get_atom_coords(
    mmcif_object: MmcifObject,
    chain_id: str,
    _zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    # Locate the right chain
    chains = list(mmcif_object.structure.get_chains())
......
@@ -89,6 +89,8 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
            descriptions.append(line[1:])  # Remove the '>' at the beginning.
            sequences.append("")
            continue
        elif line.startswith("#"):
            continue  # Skip a3m-style comment lines.
        elif not line:
            continue  # Skip blank lines.
        sequences[index] += line
......
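The new branch lets headers like the `#`-prefixed first line of an .a3m file pass through the FASTA parser. A quick illustrative check (toy input):

```python
fasta_str = "#A3M-style comment\n>seq1 some description\nMKTA\nYIAK\n"
seqs, descs = parse_fasta(fasta_str)
assert seqs == ["MKTAYIAK"]
assert descs == ["seq1 some description"]
```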
@@ -128,6 +128,22 @@ def _is_after_cutoff(
    return False


def _replace_obsolete_references(obsolete_mapping) -> Mapping[str, str]:
    """Flattens the obsolete-ID mapping by tracing chains of cross-references,
    so that every key points directly to its latest replacement."""
    obsolete_new = {}
    obsolete_keys = obsolete_mapping.keys()

    def _new_target(k):
        v = obsolete_mapping[k]
        if v in obsolete_keys:
            # Follow the chain until we reach an ID that is not itself obsolete
            return _new_target(v)
        return v

    for k in obsolete_keys:
        obsolete_new[k] = _new_target(k)

    return obsolete_new
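A quick check of the flattening behavior, assuming the mapping contains no reference cycles (`_new_target` recurses without a visited set); the PDB IDs are hypothetical:

```python
obsolete = {"1abc": "2def", "2def": "3ghi"}
assert _replace_obsolete_references(obsolete) == {
    "1abc": "3ghi",
    "2def": "3ghi",
}
```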
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
    """Parses the data file from PDB that lists which PDB ids are obsolete."""
    with open(obsolete_file_path) as f:
@@ -141,7 +157,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
            from_id = line[20:24].lower()
            to_id = line[29:33].lower()
            result[from_id] = to_id
    return _replace_obsolete_references(result)
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
@@ -495,7 +511,7 @@ def _get_atom_positions(
    mmcif_object: mmcif_parsing.MmcifObject,
    auth_chain_id: str,
    max_ca_ca_distance: float,
    _zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Gets atom positions and mask from a list of Biopython Residues."""
    coords_with_mask = mmcif_parsing.get_atom_coords(
@@ -912,6 +928,56 @@ def _process_single_hit(
    return SingleHitResult(features=None, error=error, warning=None)
def get_custom_template_features(
    mmcif_path: str,
    query_sequence: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str):

    # Use a distinct handle name to avoid shadowing the mmcif_path argument
    with open(mmcif_path, "r") as f:
        cif_string = f.read()

    mmcif_parse_result = mmcif_parsing.parse(
        file_id=pdb_id, mmcif_string=cif_string
    )
    template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]

    # Identity mapping: the template is assumed to cover the full query
    mapping = {i: i for i in range(len(query_sequence))}

    features, warnings = _extract_template_features(
        mmcif_object=mmcif_parse_result.mmcif_object,
        pdb_id=pdb_id,
        mapping=mapping,
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=chain_id,
        kalign_binary_path=kalign_binary_path,
        _zero_center_positions=True
    )
    features["template_sum_probs"] = [1.0]

    # TODO: clean up this logic
    template_features = {}
    for template_feature_name in TEMPLATE_FEATURES:
        template_features[template_feature_name] = []

    for k in template_features:
        template_features[k].append(features[k])

    for name in template_features:
        template_features[name] = np.stack(
            template_features[name], axis=0
        ).astype(TEMPLATE_FEATURES[name])

    return TemplateSearchResult(
        features=template_features, errors=None, warnings=warnings
    )
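A hypothetical invocation of the helper above (all paths and IDs are placeholders):

```python
result = get_custom_template_features(
    mmcif_path="templates/1xyz.cif",
    query_sequence="MKTAYIAKQR",
    pdb_id="1xyz",
    chain_id="A",
    kalign_binary_path="/usr/bin/kalign",
)
template_feats = result.features  # stacked and cast per TEMPLATE_FEATURES
```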
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
    features: Mapping[str, Any]
@@ -1041,6 +1107,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
        filtered = list(
            sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
        )

        idx = list(range(len(filtered)))
        if(self._shuffle_top_k_prefiltered):
            stk = self._shuffle_top_k_prefiltered
......
import os
import glob
import importlib

# Import every sibling module in this package and expose it as an attribute
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
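With this `__init__.py` in place, every sibling module is imported eagerly and exposed as a package attribute, so code like the following works without explicit submodule imports (a sketch, assuming this file sits in `openfold/utils/geometry`):

```python
from openfold.utils import geometry

# Submodules such as geometry.vector are already loaded as attributes
v = geometry.vector  # no separate "import openfold.utils.geometry.vector" needed
```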
@@ -17,7 +17,7 @@ from functools import partial
import torch
import torch.nn as nn

from typing import Tuple, Optional

from openfold.utils import all_atom_multimer
from openfold.utils.feats import (
@@ -32,7 +32,7 @@ from openfold.model.template import (
    TemplatePointwiseAttention,
)
from openfold.utils import geometry
from openfold.utils.tensor_utils import add, one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module):
@@ -96,10 +96,21 @@ class InputEmbedder(nn.Module):
        boundaries = torch.arange(
            start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
        )
        # One-hot encode each pairwise offset by its nearest boundary bin;
        # offsets beyond +/- relpos_k clip to the outermost bins
        reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
        d = d[..., None] - reshaped_bins
        d = torch.abs(d)
        d = torch.argmin(d, dim=-1)
        d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
        d = d.to(ri.dtype)
        return self.linear_relpos(d)

    def forward(
        self,
        tf: torch.Tensor,
        ri: torch.Tensor,
        msa: torch.Tensor,
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            batch: Dict containing
@@ -116,17 +127,20 @@ class InputEmbedder(nn.Module):
                [*, N_res, N_res, C_z] pair embedding
        """
        # [*, N_res, c_z]
        tf_emb_i = self.linear_tf_z_i(tf)
        tf_emb_j = self.linear_tf_z_j(tf)

        # [*, N_res, N_res, c_z]
        pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
        pair_emb = add(pair_emb,
            tf_emb_i[..., None, :],
            inplace=inplace_safe
        )
        pair_emb = add(pair_emb,
            tf_emb_j[..., None, :, :],
            inplace=inplace_safe
        )

        # [*, N_clust, N_res, c_m]
        n_clust = msa.shape[-3]
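A standalone numeric check of the nearest-bin encoding in `relpos` above (toy `relpos_k=2` instead of the configured value):

```python
import torch

ri = torch.arange(4, dtype=torch.float32)      # residue indices 0..3
d = ri[..., None] - ri[..., None, :]           # [4, 4] pairwise offsets
boundaries = torch.arange(-2, 3)               # relpos_k = 2 -> 5 bins
bins = boundaries.view((1,) * len(d.shape) + (len(boundaries),))
nearest = torch.argmin(torch.abs(d[..., None] - bins), dim=-1)
oh = torch.nn.functional.one_hot(nearest, num_classes=len(boundaries)).float()
# oh[i, j] one-hot encodes clamp(i - j, -2, 2); offsets past the range
# collapse into the outermost bins
```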
@@ -302,7 +316,6 @@ class RecyclingEmbedder(nn.Module):
    Implements Algorithm 32.
    """
    def __init__(
        self,
        c_m: int,
@@ -344,6 +357,7 @@ class RecyclingEmbedder(nn.Module):
        m: torch.Tensor,
        z: torch.Tensor,
        x: torch.Tensor,
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
...@@ -359,6 +373,19 @@ class RecyclingEmbedder(nn.Module): ...@@ -359,6 +373,19 @@ class RecyclingEmbedder(nn.Module):
z: z:
[*, N_res, N_res, C_z] pair embedding update [*, N_res, N_res, C_z] pair embedding update
""" """
# [*, N, C_m]
m_update = self.layer_norm_m(m)
if(inplace_safe):
m.copy_(m_update)
m_update = m
# [*, N, N, C_z]
z_update = self.layer_norm_z(z)
if(inplace_safe):
z.copy_(z_update)
z_update = z
# This squared method might become problematic in FP16 mode.
bins = torch.linspace( bins = torch.linspace(
self.min_bin, self.min_bin,
self.max_bin, self.max_bin,
...@@ -367,13 +394,6 @@ class RecyclingEmbedder(nn.Module): ...@@ -367,13 +394,6 @@ class RecyclingEmbedder(nn.Module):
device=x.device, device=x.device,
requires_grad=False, requires_grad=False,
) )
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = bins ** 2 squared_bins = bins ** 2
upper = torch.cat( upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1 [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
...@@ -387,7 +407,7 @@ class RecyclingEmbedder(nn.Module): ...@@ -387,7 +407,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z] # [*, N, N, C_z]
d = self.linear(d) d = self.linear(d)
z_update = d + self.layer_norm_z(z) z_update = add(z_update, d, inplace_safe)
return m_update, z_update return m_update, z_update
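The `inplace_safe` pattern above trades the caller's activations for memory: normalized values are written back into the input tensor's storage rather than a fresh buffer. A minimal sketch of the idiom (assuming gradients are disabled, as at inference):

```python
import torch

def layer_norm_maybe_inplace(x, norm, inplace_safe=False):
    out = norm(x)
    if inplace_safe:
        x.copy_(out)   # reuse x's storage; x's old contents are gone
        out = x
    return out

with torch.no_grad():
    z = torch.randn(16, 16, 8)
    z_update = layer_norm_maybe_inplace(z, torch.nn.LayerNorm(8), inplace_safe=True)
    assert z_update.data_ptr() == z.data_ptr()  # no second buffer allocated
```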
@@ -485,7 +505,6 @@ class ExtraMSAEmbedder(nn.Module):
    Implements Algorithm 2, line 15
    """
    def __init__(
        self,
        c_in: int,
@@ -544,30 +563,31 @@ class TemplateEmbedder(nn.Module):
        pair_mask,
        templ_dim,
        chunk_size,
        _mask_trans=True,
        use_lma=False,
        inplace_safe=False
    ):
        # Embed the templates one at a time (with a poor man's vmap)
        pair_embeds = []
        n = z.shape[-2]
        n_templ = batch["template_aatype"].shape[templ_dim]

        if(inplace_safe):
            # We'll preallocate the full pair tensor now to avoid manifesting
            # a second copy during the stack later on
            t_pair = z.new_zeros(
                z.shape[:-3] +
                (n_templ, n, n, self.config.template_pair_embedder.c_t)
            )

        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
                batch,
            )

            # [*, N, N, C_t]
            t = build_template_pair_feat(
                single_template_feats,
                use_unit_vector=self.config.use_unit_vector,
@@ -577,38 +597,64 @@ class TemplateEmbedder(nn.Module):
            ).to(z.dtype)
            t = self.template_pair_embedder(t)

            if(inplace_safe):
                t_pair[..., i, :, :, :] = t
            else:
                pair_embeds.append(t)

            del t

        if(not inplace_safe):
            t_pair = torch.stack(pair_embeds, dim=templ_dim)

        del pair_embeds

        # [*, S_t, N, N, C_z]
        t = self.template_pair_stack(
            t_pair,
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=chunk_size,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
        )

        del t_pair

        # [*, N, N, C_z]
        t = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            use_lma=use_lma,
        )

        t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
        # Append singletons
        t_mask = t_mask.reshape(
            *t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
        )

        if(inplace_safe):
            t *= t_mask
        else:
            t = t * t_mask

        ret = {}
        ret.update({"template_pair_embedding": t})

        del t

        if self.config.embed_angles:
            template_angle_feat = build_template_angle_feat(
                batch
            )

            # [*, S_t, N, C_m]
            a = self.template_angle_embedder(template_angle_feat)

            ret["template_single_embedding"] = a

        return ret
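A sketch of the memory trade-off behind the `inplace_safe` preallocation above: `torch.stack` must hold every per-template slice plus the stacked result at once, while writing into a preallocated buffer peaks at roughly one full tensor (toy shapes, not the real embedder):

```python
import torch

n_templ, n, c_t = 4, 64, 16
z = torch.empty(n, n, 8)

def fake_pair_embedder(i):
    return torch.randn(n, n, c_t)  # stands in for build/embed per template

# inplace_safe=True: one [n_templ, n, n, c_t] buffer, filled slice by slice
t_pair = z.new_zeros((n_templ, n, n, c_t))
for i in range(n_templ):
    t_pair[i] = fake_pair_embedder(i)

# inplace_safe=False: all slices live in a list, then stack copies them again
pair_embeds = [fake_pair_embedder(i) for i in range(n_templ)]
t_pair_stacked = torch.stack(pair_embeds, dim=0)
```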
@@ -751,6 +797,8 @@ class TemplateEmbedderMultimer(nn.Module):
        templ_dim,
        chunk_size,
        multichain_mask_2d,
        use_lma=False,
        inplace_safe=False
    ):
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
......