Commit 39a6d0e6 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merging in main branch

parents d8ee9c5f 84659c93
cff-version: 1.2.0 cff-version: 1.2.0
message: "For now, cite OpenFold with its DOI." preferred-citation:
authors: authors:
- family-names: "Ahdritz" - family-names: "Ahdritz"
given-names: "Gustaf" given-names: "Gustaf"
orcid: https://orcid.org/0000-0001-8283-5324 orcid: https://orcid.org/0000-0001-8283-5324
- family-names: "Bouatta" - family-names: "Bouatta"
given-names: "Nazim" given-names: "Nazim"
orcid: https://orcid.org/0000-0002-6524-874X orcid: https://orcid.org/0000-0002-6524-874X
- family-names: "Kadyan" - family-names: "Kadyan"
given-names: "Sachin" given-names: "Sachin"
- family-names: "Xia" orcid: https://orcid.org/0000-0002-6079-7627
given-names: "Qinghui" - family-names: "Xia"
- family-names: "Gerecke" given-names: "Qinghui"
given-names: "William" - family-names: "Gerecke"
- family-names: "AlQuraishi" given-names: "William"
given-names: "Mohammed" orcid: https://orcid.org/0000-0002-9777-6192
orcid: https://orcid.org/0000-0001-6817-1322 - family-names: "O'Donnell"
title: "OpenFold" given-names: "Timothy J"
doi: 10.5281/zenodo.5709539 orcid: https://orcid.org/0000-0002-9949-069X
- family-names: "Berenberg"
given-names: "Daniel"
orcid: https://orcid.org/0000-0003-4631-0947
- family-names: "Fisk"
given-names: "Ian"
- family-names: "Zanichelli"
given-names: "Niccolò"
orcid: https://orcid.org/0000-0002-3093-3587
- family-names: "Zhang"
given-names: "Bo"
orcid: https://orcid.org/0000-0002-9714-2827
- family-names: "Nowaczynski"
given-names: "Arkadiusz"
orcid: https://orcid.org/0000-0002-3351-9584
- family-names: "Wang"
given-names: "Bei"
orcid: https://orcid.org/0000-0003-4942-9652
- family-names: "Stepniewska-Dziubinska"
given-names: "Marta M"
orcid: https://orcid.org/0000-0003-4942-9652
- family-names: "Zhang"
given-names: "Shang"
orcid: https://orcid.org/0000-0003-0759-2080
- family-names: "Ojewole"
given-names: "Adegoke"
orcid: https://orcid.org/0000-0003-2661-4388
- family-names: "Guney"
given-names: "Murat Efe"
- family-names: "Biderman"
given-names: "Stella"
orcid: https://orcid.org/0000-0001-8228-1042
- family-names: "Watkins"
given-names: "Andrew M"
orcid: https://orcid.org/0000-0003-1617-1720
- family-names: "Ra"
given-names: "Stephen"
orcid: https://orcid.org/0000-0002-2820-0050
- family-names: "Lorenzo"
given-names: "Pablo Ribalta"
orcid: https://orcid.org/0000-0002-3657-8053
- family-names: "Nivon"
given-names: "Lucas"
- family-names: "Weitzner"
given-names: "Brian"
orcid: https://orcid.org/0000-0002-1909-0961
- family-names: "Ban"
given-names: "Yih-En"
orcid: https://orcid.org/0000-0003-3698-3574
- family-names: "Ban"
given-names: "Yih-En Andrew"
orcid: https://orcid.org/0000-0003-3698-3574
- family-names: "Sorger"
given-names: "Peter K"
orcid: https://orcid.org/0000-0002-3364-1838
- family-names: "Mostaque"
given-names: "Emad"
- family-names: "Zhang"
given-names: "Zhao"
orcid: https://orcid.org/0000-0001-5921-0035
- family-names: "Bonneau"
given-names: "Richard"
orcid: https://orcid.org/0000-0003-4354-7906
- family-names: "AlQuraishi"
given-names: "Mohammed"
orcid: https://orcid.org/0000-0001-6817-1322
title: "OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization"
type: article
doi: 10.1101/2022.11.20.517210
doi: 10.1101/2022.11.20.517210
date-released: 2021-11-12 date-released: 2021-11-12
url: "https://github.com/aqlaboratory/openfold" url: "https://doi.org/10.1101/2022.11.20.517210"
FROM nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04 FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04
RUN apt-get update && apt-get install -y wget cuda-minimal-build-10-2 git # metainformation
LABEL org.opencontainers.image.version = "1.0.0"
LABEL org.opencontainers.image.authors = "Gustaf Ahdritz"
LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold"
LABEL org.opencontainers.image.licenses = "Apache License 2.0"
LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04"
RUN apt-key del 7fa2af80
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \ RUN wget -P /tmp \
"https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \ "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
&& bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \ && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
......
This diff is collapsed.
...@@ -4,9 +4,19 @@ channels: ...@@ -4,9 +4,19 @@ channels:
- bioconda - bioconda
- pytorch - pytorch
dependencies: dependencies:
- conda-forge::python=3.7
- conda-forge::setuptools=59.5.0
- conda-forge::pip
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- conda-forge::cudatoolkit==11.3.*
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pytorch::pytorch=1.12.*
- pip: - pip:
- biopython==1.79 - biopython==1.79
- deepspeed==0.5.9 - deepspeed==0.5.10
- dm-tree==0.1.6 - dm-tree==0.1.6
- ml-collections==0.1.0 - ml-collections==0.1.0
- numpy==1.21.2 - numpy==1.21.2
...@@ -16,15 +26,5 @@ dependencies: ...@@ -16,15 +26,5 @@ dependencies:
- tqdm==4.62.2 - tqdm==4.62.2
- typing-extensions==3.10.0.2 - typing-extensions==3.10.0.2
- pytorch_lightning==1.5.10 - pytorch_lightning==1.5.10
- wandb==0.12.21
- git+https://github.com/NVIDIA/dllogger.git - git+https://github.com/NVIDIA/dllogger.git
- pytorch::pytorch=1.10.*
- conda-forge::python=3.7
- conda-forge::setuptools=59.5.0
- conda-forge::pip
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- conda-forge::cudatoolkit==10.2.*
- conda-forge::cudatoolkit-dev==10.*
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
This diff is collapsed.
name: openfold_venv
channels:
- conda-forge
- bioconda
dependencies:
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pip:
- biopython==1.79
- dm-tree==0.1.6
- ml-collections==0.1.0
- PyYAML==5.4.1
- requests==2.26.0
- typing-extensions==3.10.0.2
import copy import copy
import importlib
import ml_collections as mlc import ml_collections as mlc
...@@ -10,20 +11,92 @@ def set_inf(c, inf): ...@@ -10,20 +11,92 @@ def set_inf(c, inf):
c[k] = inf c[k] = inf
def model_config(name, train=False, low_prec=False): def enforce_config_constraints(config):
def string_to_setting(s):
path = s.split('.')
setting = config
for p in path:
setting = setting.get(p)
return setting
mutually_exclusive_bools = [
(
"model.template.average_templates",
"model.template.offload_templates"
),
(
"globals.use_lma",
"globals.use_flash",
),
]
for s1, s2 in mutually_exclusive_bools:
s1_setting = string_to_setting(s1)
s2_setting = string_to_setting(s2)
if(s1_setting and s2_setting):
raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(config.globals.use_flash and not fa_is_installed):
raise ValueError("use_flash requires that FlashAttention is installed")
if(
config.globals.offload_inference and
not config.model.template.average_templates
):
config.model.template.offload_templates = True
def model_config(
name,
train=False,
low_prec=False,
long_sequence_inference=False
):
c = copy.deepcopy(config) c = copy.deepcopy(config)
# TRAINING PRESETS
if name == "initial_training": if name == "initial_training":
# AF2 Suppl. Table 4, "initial training" setting # AF2 Suppl. Table 4, "initial training" setting
pass pass
elif name == "finetuning": elif name == "finetuning":
# AF2 Suppl. Table 4, "finetuning" setting # AF2 Suppl. Table 4, "finetuning" setting
c.data.common.max_extra_msa = 5120
c.data.train.crop_size = 384 c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "finetuning_ptm":
c.data.train.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "finetuning_no_templ":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512 c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1. c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "finetuning_no_templ_ptm":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
# INFERENCE PRESETS
elif name == "model_1": elif name == "model_1":
# AF2 Suppl. Table 5, Model 1.1.1 # AF2 Suppl. Table 5, Model 1.1.1
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True c.data.common.use_template_torsion_angles = True
...@@ -36,17 +109,20 @@ def model_config(name, train=False, low_prec=False): ...@@ -36,17 +109,20 @@ def model_config(name, train=False, low_prec=False):
c.model.template.enabled = True c.model.template.enabled = True
elif name == "model_3": elif name == "model_3":
# AF2 Suppl. Table 5, Model 1.2.1 # AF2 Suppl. Table 5, Model 1.2.1
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
elif name == "model_4": elif name == "model_4":
# AF2 Suppl. Table 5, Model 1.2.2 # AF2 Suppl. Table 5, Model 1.2.2
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
elif name == "model_5": elif name == "model_5":
# AF2 Suppl. Table 5, Model 1.2.3 # AF2 Suppl. Table 5, Model 1.2.3
c.model.template.enabled = False c.model.template.enabled = False
elif name == "model_1_ptm": elif name == "model_1_ptm":
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True c.data.common.use_template_torsion_angles = True
...@@ -61,12 +137,14 @@ def model_config(name, train=False, low_prec=False): ...@@ -61,12 +137,14 @@ def model_config(name, train=False, low_prec=False):
c.model.heads.tm.enabled = True c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif name == "model_3_ptm": elif name == "model_3_ptm":
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
c.model.heads.tm.enabled = True c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif name == "model_4_ptm": elif name == "model_4_ptm":
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
c.model.heads.tm.enabled = True c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
...@@ -76,6 +154,7 @@ def model_config(name, train=False, low_prec=False): ...@@ -76,6 +154,7 @@ def model_config(name, train=False, low_prec=False):
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif "multimer" in name: elif "multimer" in name:
c.globals.is_multimer = True c.globals.is_multimer = True
c.loss.masked_msa.num_classes = 22
for k,v in multimer_model_config_update.items(): for k,v in multimer_model_config_update.items():
c.model[k] = v c.model[k] = v
...@@ -89,16 +168,32 @@ def model_config(name, train=False, low_prec=False): ...@@ -89,16 +168,32 @@ def model_config(name, train=False, low_prec=False):
else: else:
raise ValueError("Invalid model name") raise ValueError("Invalid model name")
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
c.globals.use_lma = True
c.globals.use_flash = False
c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
c.model.evoformer_stack.tune_chunk_size = False
if train: if train:
c.globals.blocks_per_ckpt = 1 c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None c.globals.chunk_size = None
c.globals.use_lma = False
c.globals.offload_inference = False
c.model.template.average_templates = False
c.model.template.offload_templates = False
if low_prec: if low_prec:
c.globals.eps = 1e-4 c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be # If we want exact numerical parity with the original, inf can't be
# a global constant # a global constant
set_inf(c, 1e4) set_inf(c, 1e4)
enforce_config_constraints(c)
return c return c
...@@ -114,6 +209,7 @@ tm_enabled = mlc.FieldReference(False, field_type=bool) ...@@ -114,6 +209,7 @@ tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float) eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool) templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool) embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
tune_chunk_size = mlc.FieldReference(True, field_type=bool)
NUM_RES = "num residues placeholder" NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder" NUM_MSA_SEQ = "msa placeholder"
...@@ -195,7 +291,6 @@ config = mlc.ConfigDict( ...@@ -195,7 +291,6 @@ config = mlc.ConfigDict(
"same_prob": 0.1, "same_prob": 0.1,
"uniform_prob": 0.1, "uniform_prob": 0.1,
}, },
"max_extra_msa": 1024,
"max_recycling_iters": 3, "max_recycling_iters": 3,
"msa_cluster_features": True, "msa_cluster_features": True,
"reduce_msa_clusters_by_max_templates": False, "reduce_msa_clusters_by_max_templates": False,
...@@ -233,7 +328,8 @@ config = mlc.ConfigDict( ...@@ -233,7 +328,8 @@ config = mlc.ConfigDict(
"fixed_size": True, "fixed_size": True,
"subsample_templates": False, # We want top templates. "subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 512,
"max_extra_msa": 1024,
"max_template_hits": 4, "max_template_hits": 4,
"max_templates": 4, "max_templates": 4,
"crop": False, "crop": False,
...@@ -246,6 +342,7 @@ config = mlc.ConfigDict( ...@@ -246,6 +342,7 @@ config = mlc.ConfigDict(
"subsample_templates": False, # We want top templates. "subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4, "max_template_hits": 4,
"max_templates": 4, "max_templates": 4,
"crop": False, "crop": False,
...@@ -258,6 +355,7 @@ config = mlc.ConfigDict( ...@@ -258,6 +355,7 @@ config = mlc.ConfigDict(
"subsample_templates": True, "subsample_templates": True,
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4, "max_template_hits": 4,
"max_templates": 4, "max_templates": 4,
"shuffle_top_k_prefiltered": 20, "shuffle_top_k_prefiltered": 20,
...@@ -274,6 +372,7 @@ config = mlc.ConfigDict( ...@@ -274,6 +372,7 @@ config = mlc.ConfigDict(
"data_loaders": { "data_loaders": {
"batch_size": 1, "batch_size": 1,
"num_workers": 16, "num_workers": 16,
"pin_memory": True,
}, },
}, },
}, },
...@@ -281,6 +380,13 @@ config = mlc.ConfigDict( ...@@ -281,6 +380,13 @@ config = mlc.ConfigDict(
"globals": { "globals": {
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size, "chunk_size": chunk_size,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
# exclusive with use_flash.
"use_lma": False,
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma. Doesn't work that well on long sequences (>1000 residues).
"use_flash": False,
"offload_inference": False,
"c_z": c_z, "c_z": c_z,
"c_m": c_m, "c_m": c_m,
"c_t": c_t, "c_t": c_t,
...@@ -333,6 +439,7 @@ config = mlc.ConfigDict( ...@@ -333,6 +439,7 @@ config = mlc.ConfigDict(
"dropout_rate": 0.25, "dropout_rate": 0.25,
"tri_mul_first": False, "tri_mul_first": False,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9, "inf": 1e9,
}, },
"template_pointwise_attention": { "template_pointwise_attention": {
...@@ -349,6 +456,17 @@ config = mlc.ConfigDict( ...@@ -349,6 +456,17 @@ config = mlc.ConfigDict(
"enabled": templates_enabled, "enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles, "embed_angles": embed_template_torsion_angles,
"use_unit_vector": False, "use_unit_vector": False,
# Approximate template computation, saving memory.
# In our experiments, results are equivalent to or better than
# the stock implementation. Should be enabled for all new
# training runs.
"average_templates": False,
# Offload template embeddings to CPU memory. Vastly reduced
# memory consumption at the cost of a modest increase in
# runtime. Useful for inference on very long sequences.
# Mutually exclusive with average_templates. Automatically
# enabled if offload_inference is set.
"offload_templates": False,
}, },
"extra_msa": { "extra_msa": {
"extra_msa_embedder": { "extra_msa_embedder": {
...@@ -369,7 +487,8 @@ config = mlc.ConfigDict( ...@@ -369,7 +487,8 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15, "msa_dropout": 0.15,
"pair_dropout": 0.25, "pair_dropout": 0.25,
"opm_first": False, "opm_first": False,
"clear_cache_between_blocks": True, "clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9, "inf": 1e9,
"eps": eps, # 1e-10, "eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None, "ckpt": blocks_per_ckpt is not None,
...@@ -393,6 +512,7 @@ config = mlc.ConfigDict( ...@@ -393,6 +512,7 @@ config = mlc.ConfigDict(
"opm_first": False, "opm_first": False,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False, "clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9, "inf": 1e9,
"eps": eps, # 1e-10, "eps": eps, # 1e-10,
}, },
...@@ -473,7 +593,7 @@ config = mlc.ConfigDict( ...@@ -473,7 +593,7 @@ config = mlc.ConfigDict(
"eps": 1e-4, "eps": 1e-4,
"weight": 1.0, "weight": 1.0,
}, },
"lddt": { "plddt_loss": {
"min_resolution": 0.1, "min_resolution": 0.1,
"max_resolution": 3.0, "max_resolution": 3.0,
"cutoff": 15.0, "cutoff": 15.0,
...@@ -482,6 +602,7 @@ config = mlc.ConfigDict( ...@@ -482,6 +602,7 @@ config = mlc.ConfigDict(
"weight": 0.01, "weight": 0.01,
}, },
"masked_msa": { "masked_msa": {
"num_classes": 23,
"eps": eps, # 1e-8, "eps": eps, # 1e-8,
"weight": 2.0, "weight": 2.0,
}, },
...@@ -503,7 +624,7 @@ config = mlc.ConfigDict( ...@@ -503,7 +624,7 @@ config = mlc.ConfigDict(
"min_resolution": 0.1, "min_resolution": 0.1,
"max_resolution": 3.0, "max_resolution": 3.0,
"eps": eps, # 1e-8, "eps": eps, # 1e-8,
"weight": 0.0, "weight": 0.,
"enabled": tm_enabled, "enabled": tm_enabled,
}, },
"eps": eps, "eps": eps,
...@@ -607,6 +728,23 @@ multimer_model_config_update = { ...@@ -607,6 +728,23 @@ multimer_model_config_update = {
"inf": 1e9, "inf": 1e9,
"eps": eps, # 1e-10, "eps": eps, # 1e-10,
}, },
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 20,
"epsilon": eps, # 1e-12,
"inf": 1e5,
},
"heads": { "heads": {
"lddt": { "lddt": {
"no_bins": 50, "no_bins": 50,
......
...@@ -28,16 +28,18 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -28,16 +28,18 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_mmcif_dir: str, template_mmcif_dir: str,
max_template_date: str, max_template_date: str,
config: mlc.ConfigDict, config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4, max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None, shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True, treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None, filter_path: Optional[str] = None,
mode: str = "train", mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False, _output_raw: bool = False,
_alignment_index: Optional[Any] = None _structure_index: Optional[Any] = None,
): ):
""" """
Args: Args:
...@@ -55,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -55,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
Path to a directory containing template mmCIF files. Path to a directory containing template mmCIF files.
config: config:
A dataset config object. See openfold.config A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
kalign_binary_path: kalign_binary_path:
Path to kalign binary. Path to kalign binary.
max_template_hits: max_template_hits:
...@@ -79,12 +84,22 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -79,12 +84,22 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
""" """
super(OpenFoldSingleDataset, self).__init__() super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir self.data_dir = data_dir
self.chain_data_cache = None
if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.alignment_dir = alignment_dir self.alignment_dir = alignment_dir
self.config = config self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode self.mode = mode
self.alignment_index = alignment_index
self._output_raw = _output_raw self._output_raw = _output_raw
self._alignment_index = _alignment_index self._structure_index = _structure_index
self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"] valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes): if(mode not in valid_modes):
...@@ -96,14 +111,42 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -96,14 +111,42 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"scripts/generate_mmcif_cache.py before running OpenFold" "scripts/generate_mmcif_cache.py before running OpenFold"
) )
if(_alignment_index is not None): if(alignment_index is not None):
self._chain_ids = list(_alignment_index.keys()) self._chain_ids = list(alignment_index.keys())
elif(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
else: else:
with open(mapping_path, "r") as f: self._chain_ids = list(os.listdir(alignment_dir))
self._chain_ids = [l.strip() for l in f.readlines()]
if(filter_path is not None):
with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()])
self._chain_ids = [
c for c in self._chain_ids if c in chains_to_include
]
if self.chain_data_cache is not None:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids = self._chain_ids
self._chain_ids = [
c for c in self._chain_ids if c in self.chain_data_cache
]
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids
if c not in self.chain_data_cache
]
max_to_print = 10
missing_examples = ", ".join(missing[:max_to_print])
if len(missing) > max_to_print:
missing_examples += ", ..."
logging.warning(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = { self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids) chain: i for i, chain in enumerate(self._chain_ids)
} }
...@@ -125,7 +168,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -125,7 +168,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(not self._output_raw): if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index): def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
with open(path, 'r') as f: with open(path, 'r') as f:
mmcif_string = f.read() mmcif_string = f.read()
...@@ -144,7 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -144,7 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object, mmcif=mmcif_object,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
chain_id=chain_id, chain_id=chain_id,
_alignment_index=_alignment_index alignment_index=alignment_index
) )
return data return data
...@@ -159,10 +202,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -159,10 +202,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
name = self.idx_to_chain_id(idx) name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name) alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None alignment_index = None
if(self._alignment_index is not None): if(self.alignment_index is not None):
alignment_dir = self.alignment_dir alignment_dir = self.alignment_dir
_alignment_index = self._alignment_index[name] alignment_index = self.alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'): if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1) spl = name.rsplit('_', 1)
...@@ -173,30 +216,51 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -173,30 +216,51 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain_id = None chain_id = None
path = os.path.join(self.data_dir, file_id) path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")): structure_index_entry = None
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break
if(ext is None):
raise ValueError("Invalid file type")
path += ext
if(ext == ".cif"):
data = self._parse_mmcif( data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir, _alignment_index, path, file_id, chain_id, alignment_dir, alignment_index,
) )
elif(os.path.exists(path + ".core")): elif(ext == ".core"):
data = self.data_pipeline.process_core( data = self.data_pipeline.process_core(
path + ".core", alignment_dir, _alignment_index, path, alignment_dir, alignment_index,
) )
elif(os.path.exists(path + ".pdb")): elif(ext == ".pdb"):
structure_index = None
if(self._structure_index is not None):
structure_index = self._structure_index[name]
data = self.data_pipeline.process_pdb( data = self.data_pipeline.process_pdb(
pdb_path=path + ".pdb", pdb_path=path,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation, is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id, chain_id=chain_id,
_alignment_index=_alignment_index, alignment_index=alignment_index,
_structure_index=structure_index,
) )
else: else:
raise ValueError("Invalid file type") raise ValueError("Extension branch missing")
else: else:
path = os.path.join(name, name + ".fasta") path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta( data = self.data_pipeline.process_fasta(
fasta_path=path, fasta_path=path,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
_alignment_index=_alignment_index, alignment_index=alignment_index,
) )
if(self._output_raw): if(self._output_raw):
...@@ -206,6 +270,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -206,6 +270,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data, self.mode data, self.mode
) )
feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats return feats
def __len__(self): def __len__(self):
...@@ -265,9 +334,8 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -265,9 +334,8 @@ class OpenFoldDataset(torch.utils.data.Dataset):
""" """
def __init__(self, def __init__(self,
datasets: Sequence[OpenFoldSingleDataset], datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int], probabilities: Sequence[float],
epoch_len: int, epoch_len: int,
chain_data_cache_paths: List[str],
generator: torch.Generator = None, generator: torch.Generator = None,
_roll_at_init: bool = True, _roll_at_init: bool = True,
): ):
...@@ -275,11 +343,6 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -275,11 +343,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.probabilities = probabilities self.probabilities = probabilities
self.epoch_len = epoch_len self.epoch_len = epoch_len
self.generator = generator self.generator = generator
self.chain_data_caches = []
for path in chain_data_cache_paths:
with open(path, "r") as fp:
self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len): def looped_shuffled_dataset_idx(dataset_len):
while True: while True:
...@@ -298,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -298,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len = int(epoch_len * probabilities[dataset_idx]) max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx] dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset)) idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = self.chain_data_caches[dataset_idx] chain_data_cache = dataset.chain_data_cache
while True: while True:
weights = [] weights = []
idx = [] idx = []
...@@ -355,20 +418,9 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -355,20 +418,9 @@ class OpenFoldDataset(torch.utils.data.Dataset):
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __init__(self, config, stage="train"): def __call__(self, prots):
self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def __call__(self, raw_prots):
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage
)
processed_prots.append(features)
stack_fn = partial(torch.stack, dim=0) stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, processed_prots) return dict_multimap(stack_fn, prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader): class OpenFoldDataLoader(torch.utils.data.DataLoader):
...@@ -388,11 +440,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -388,11 +440,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
stage_cfg = self.config[self.stage] stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters max_iters = self.config.common.max_recycling_iters
if(stage_cfg.supervised):
clamp_prob = self.config.supervised.clamp_prob
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(stage_cfg.uniform_recycling): if(stage_cfg.uniform_recycling):
recycling_probs = [ recycling_probs = [
...@@ -480,13 +527,15 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -480,13 +527,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
predict_data_dir: Optional[str] = None, predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None, predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None, train_filter_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None, distillation_filter_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None, batch_seed: Optional[int] = None,
train_epoch_len: int = 50000, train_epoch_len: int = 50000,
_alignment_index_path: Optional[str] = None, _distillation_structure_index_path: Optional[str] = None,
alignment_index_path: Optional[str] = None,
distillation_alignment_index_path: Optional[str] = None,
**kwargs **kwargs
): ):
super(OpenFoldDataModule, self).__init__() super(OpenFoldDataModule, self).__init__()
...@@ -507,8 +556,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -507,8 +556,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_data_dir = predict_data_dir self.predict_data_dir = predict_data_dir
self.predict_alignment_dir = predict_alignment_dir self.predict_alignment_dir = predict_alignment_dir
self.kalign_binary_path = kalign_binary_path self.kalign_binary_path = kalign_binary_path
self.train_mapping_path = train_mapping_path self.train_filter_path = train_filter_path
self.distillation_mapping_path = distillation_mapping_path self.distillation_filter_path = distillation_filter_path
self.template_release_dates_cache_path = ( self.template_release_dates_cache_path = (
template_release_dates_cache_path template_release_dates_cache_path
) )
...@@ -539,10 +588,20 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -539,10 +588,20 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
# An ad-hoc measure for our particular filesystem restrictions # An ad-hoc measure for our particular filesystem restrictions
self._alignment_index = None self._distillation_structure_index = None
if(_alignment_index_path is not None): if(_distillation_structure_index_path is not None):
with open(_alignment_index_path, "r") as fp: with open(_distillation_structure_index_path, "r") as fp:
self._alignment_index = json.load(fp) self._distillation_structure_index = json.load(fp)
self.alignment_index = None
if(alignment_index_path is not None):
with open(alignment_index_path, "r") as fp:
self.alignment_index = json.load(fp)
self.distillation_alignment_index = None
if(distillation_alignment_index_path is not None):
with open(distillation_alignment_index_path, "r") as fp:
self.distillation_alignment_index = json.load(fp)
def setup(self): def setup(self):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
...@@ -560,27 +619,29 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -560,27 +619,29 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.training_mode): if(self.training_mode):
train_dataset = dataset_gen( train_dataset = dataset_gen(
data_dir=self.train_data_dir, data_dir=self.train_data_dir,
chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir, alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path, filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered= shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered, self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False, treat_pdb_as_distillation=False,
mode="train", mode="train",
_output_raw=True, alignment_index=self.alignment_index,
_alignment_index=self._alignment_index,
) )
distillation_dataset = None distillation_dataset = None
if(self.distillation_data_dir is not None): if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen( distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir, data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir, alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path, filter_path=self.distillation_filter_path,
max_template_hits=self.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True, treat_pdb_as_distillation=True,
mode="train", mode="train",
_output_raw=True, alignment_index=self.distillation_alignment_index,
_structure_index=self._distillation_structure_index,
) )
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
...@@ -588,23 +649,21 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -588,23 +649,21 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(distillation_dataset is not None): if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset] datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
probabilities = [1 - d_prob, d_prob] probabilities = [1. - d_prob, d_prob]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path,
]
else: else:
datasets = [train_dataset] datasets = [train_dataset]
probabilities = [1.] probabilities = [1.]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
]
generator = None
if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1)
self.train_dataset = OpenFoldDataset( self.train_dataset = OpenFoldDataset(
datasets=datasets, datasets=datasets,
probabilities=probabilities, probabilities=probabilities,
epoch_len=self.train_epoch_len, epoch_len=self.train_epoch_len,
chain_data_cache_paths=chain_data_cache_paths, generator=generator,
_roll_at_init=False, _roll_at_init=False,
) )
...@@ -612,10 +671,9 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -612,10 +671,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.eval_dataset = dataset_gen( self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir, data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir, alignment_dir=self.val_alignment_dir,
mapping_path=None, filter_path=None,
max_template_hits=self.config.eval.max_template_hits, max_template_hits=self.config.eval.max_template_hits,
mode="eval", mode="eval",
_output_raw=True,
) )
else: else:
self.eval_dataset = None self.eval_dataset = None
...@@ -623,7 +681,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -623,7 +681,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_dataset = dataset_gen( self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir, data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir, alignment_dir=self.predict_alignment_dir,
mapping_path=None, filter_path=None,
max_template_hits=self.config.predict.max_template_hits, max_template_hits=self.config.predict.max_template_hits,
mode="predict", mode="predict",
) )
...@@ -636,7 +694,6 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -636,7 +694,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
dataset = None dataset = None
if(stage == "train"): if(stage == "train"):
dataset = self.train_dataset dataset = self.train_dataset
# Filter the dataset, if necessary # Filter the dataset, if necessary
dataset.reroll() dataset.reroll()
elif(stage == "eval"): elif(stage == "eval"):
...@@ -646,7 +703,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -646,7 +703,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
else: else:
raise ValueError("Invalid stage") raise ValueError("Invalid stage")
batch_collator = OpenFoldBatchCollator(self.config, stage) batch_collator = OpenFoldBatchCollator()
dl = OpenFoldDataLoader( dl = OpenFoldDataLoader(
dataset, dataset,
......
This diff is collapsed.
...@@ -23,6 +23,9 @@ import torch ...@@ -23,6 +23,9 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
from openfold.utils.rigid_utils import Rotation, Rigid from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -669,7 +672,7 @@ def make_atom14_masks(protein): ...@@ -669,7 +672,7 @@ def make_atom14_masks(protein):
def make_atom14_masks_np(batch): def make_atom14_masks_np(batch):
batch = tree_map( batch = tree_map(
lambda n: torch.tensor(n, device="cpu"), lambda n: torch.tensor(n, device="cpu"),
batch, batch,
np.ndarray np.ndarray
) )
...@@ -736,6 +739,7 @@ def make_atom14_positions(protein): ...@@ -736,6 +739,7 @@ def make_atom14_positions(protein):
for index, correspondence in enumerate(correspondences): for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.0 renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix all_matrices[resname] = renaming_matrix
renaming_matrices = torch.stack( renaming_matrices = torch.stack(
[all_matrices[restype] for restype in restype_3] [all_matrices[restype] for restype in restype_3]
) )
...@@ -781,10 +785,14 @@ def make_atom14_positions(protein): ...@@ -781,10 +785,14 @@ def make_atom14_positions(protein):
def atom37_to_frames(protein, eps=1e-8): def atom37_to_frames(protein, eps=1e-8):
is_multimer = "asym_id" in protein
aatype = protein["aatype"] aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"] all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"] all_atom_mask = protein["all_atom_mask"]
if is_multimer:
all_atom_positions = Vec3Array.from_array(all_atom_positions)
batch_dims = len(aatype.shape[:-1]) batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object) restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
...@@ -831,19 +839,37 @@ def atom37_to_frames(protein, eps=1e-8): ...@@ -831,19 +839,37 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=batch_dims, no_batch_dims=batch_dims,
) )
base_atom_pos = batched_gather( if is_multimer:
all_atom_positions, base_atom_pos = [batched_gather(
residx_rigidgroup_base_atom37_idx, pos,
dim=-2, residx_rigidgroup_base_atom37_idx,
no_batch_dims=len(all_atom_positions.shape[:-2]), dim=-1,
) no_batch_dims=len(all_atom_positions.shape[:-1]),
) for pos in all_atom_positions]
base_atom_pos = Vec3Array.from_array(torch.stack(base_atom_pos, dim=-1))
else:
base_atom_pos = batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim=-2,
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
gt_frames = Rigid.from_3_points( if is_multimer:
p_neg_x_axis=base_atom_pos[..., 0, :], point_on_neg_x_axis = base_atom_pos[:, :, 0]
origin=base_atom_pos[..., 1, :], origin = base_atom_pos[:, :, 1]
p_xy_plane=base_atom_pos[..., 2, :], point_on_xy_plane = base_atom_pos[:, :, 2]
eps=eps, gt_rotation = Rot3Array.from_two_vectors(
) origin - point_on_neg_x_axis, point_on_xy_plane - origin)
gt_frames = Rigid3Array(gt_rotation, origin)
else:
gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=eps,
)
group_exists = batched_gather( group_exists = batched_gather(
restype_rigidgroup_mask, restype_rigidgroup_mask,
...@@ -864,9 +890,13 @@ def atom37_to_frames(protein, eps=1e-8): ...@@ -864,9 +890,13 @@ def atom37_to_frames(protein, eps=1e-8):
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1)) rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1 rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1 rots[..., 0, 2, 2] = -1
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(Rigid(rots, None)) if is_multimer:
gt_frames = gt_frames.compose_rotation(
Rot3Array.from_array(rots))
else:
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8 *((1,) * batch_dims), 21, 8
...@@ -900,12 +930,18 @@ def atom37_to_frames(protein, eps=1e-8): ...@@ -900,12 +930,18 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=batch_dims, no_batch_dims=batch_dims,
) )
residx_rigidgroup_ambiguity_rot = Rotation( if is_multimer:
rot_mats=residx_rigidgroup_ambiguity_rot ambiguity_rot = Rot3Array.from_array(residx_rigidgroup_ambiguity_rot)
)
alt_gt_frames = gt_frames.compose( # Create the alternative ground truth frames.
Rigid(residx_rigidgroup_ambiguity_rot, None) alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
) else:
residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
gt_frames_tensor = gt_frames.to_tensor_4x4() gt_frames_tensor = gt_frames.to_tensor_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4() alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
......
...@@ -103,6 +103,21 @@ def np_example_to_features( ...@@ -103,6 +103,21 @@ def np_example_to_features(
cfg[mode], cfg[mode],
) )
if mode == "train":
p = torch.rand(1).item()
use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=use_clamped_fape_value,
dtype=torch.float32,
)
else:
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=0.0,
dtype=torch.float32,
)
return {k: v for k, v in features.items()} return {k: v for k, v in features.items()}
......
...@@ -84,7 +84,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -84,7 +84,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
pad_msa_clusters = mode_cfg.max_msa_clusters pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None msa_seed = None
if(not common_cfg.resample_msa_in_recycling): if(not common_cfg.resample_msa_in_recycling):
...@@ -137,7 +137,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -137,7 +137,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
data_transforms.make_fixed_size( data_transforms.make_fixed_size(
crop_feats, crop_feats,
pad_msa_clusters, pad_msa_clusters,
common_cfg.max_extra_msa, mode_cfg.max_extra_msa,
mode_cfg.crop_size, mode_cfg.crop_size,
mode_cfg.max_templates, mode_cfg.max_templates,
) )
......
...@@ -46,7 +46,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -46,7 +46,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
pad_msa_clusters = mode_cfg.max_msa_clusters pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None msa_seed = None
if(not common_cfg.resample_msa_in_recycling): if(not common_cfg.resample_msa_in_recycling):
......
...@@ -434,7 +434,7 @@ def _is_set(data: str) -> bool: ...@@ -434,7 +434,7 @@ def _is_set(data: str) -> bool:
def get_atom_coords( def get_atom_coords(
mmcif_object: MmcifObject, mmcif_object: MmcifObject,
chain_id: str, chain_id: str,
_zero_center_positions: bool = True _zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain # Locate the right chain
chains = list(mmcif_object.structure.get_chains()) chains = list(mmcif_object.structure.get_chains())
......
...@@ -89,6 +89,8 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: ...@@ -89,6 +89,8 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
descriptions.append(line[1:]) # Remove the '>' at the beginning. descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append("") sequences.append("")
continue continue
elif line.startswith("#"):
continue
elif not line: elif not line:
continue # Skip blank lines. continue # Skip blank lines.
sequences[index] += line sequences[index] += line
......
...@@ -128,6 +128,22 @@ def _is_after_cutoff( ...@@ -128,6 +128,22 @@ def _is_after_cutoff(
return False return False
def _replace_obsolete_references(obsolete_mapping) -> Mapping[str, str]:
"""Generates a new obsolete by tracing all cross-references and store the latest leaf to all referencing nodes"""
obsolete_new = {}
obsolete_keys = obsolete_mapping.keys()
def _new_target(k):
v = obsolete_mapping[k]
if v in obsolete_keys:
return _new_target(v)
return v
for k in obsolete_keys:
obsolete_new[k] = _new_target(k)
return obsolete_new
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]: def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
"""Parses the data file from PDB that lists which PDB ids are obsolete.""" """Parses the data file from PDB that lists which PDB ids are obsolete."""
with open(obsolete_file_path) as f: with open(obsolete_file_path) as f:
...@@ -141,7 +157,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]: ...@@ -141,7 +157,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
from_id = line[20:24].lower() from_id = line[20:24].lower()
to_id = line[29:33].lower() to_id = line[29:33].lower()
result[from_id] = to_id result[from_id] = to_id
return result return _replace_obsolete_references(result)
def generate_release_dates_cache(mmcif_dir: str, out_path: str): def generate_release_dates_cache(mmcif_dir: str, out_path: str):
...@@ -495,7 +511,7 @@ def _get_atom_positions( ...@@ -495,7 +511,7 @@ def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject, mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str, auth_chain_id: str,
max_ca_ca_distance: float, max_ca_ca_distance: float,
_zero_center_positions: bool = True, _zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues.""" """Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords( coords_with_mask = mmcif_parsing.get_atom_coords(
...@@ -912,6 +928,56 @@ def _process_single_hit( ...@@ -912,6 +928,56 @@ def _process_single_hit(
return SingleHitResult(features=None, error=error, warning=None) return SingleHitResult(features=None, error=error, warning=None)
def get_custom_template_features(
    mmcif_path: str,
    query_sequence: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str):
    """Builds template features from a single user-supplied mmCIF file.

    Args:
        mmcif_path: Path to the mmCIF file to use as the template.
        query_sequence: The query amino-acid sequence.
        pdb_id: PDB ID used to label the parsed mmCIF object.
        chain_id: Chain within the mmCIF file to use as the template.
        kalign_binary_path: Path to the kalign executable used for
            realigning the template to the query.

    Returns:
        A TemplateSearchResult containing the features for this single
        template (leading template dimension of size 1).
    """
    # NOTE: do not shadow the path parameter with the file handle — keep
    # the handle in its own name so the path stays available.
    with open(mmcif_path, "r") as fp:
        cif_string = fp.read()

    mmcif_parse_result = mmcif_parsing.parse(
        file_id=pdb_id, mmcif_string=cif_string
    )
    template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]

    # Identity mapping: each query residue index maps to itself in the
    # template; the realignment is delegated to _extract_template_features.
    mapping = {x: x for x, _ in enumerate(query_sequence)}

    features, warnings = _extract_template_features(
        mmcif_object=mmcif_parse_result.mmcif_object,
        pdb_id=pdb_id,
        mapping=mapping,
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=chain_id,
        kalign_binary_path=kalign_binary_path,
        _zero_center_positions=True
    )
    # A custom template has no search score; use a fixed sum_probs of 1.
    features["template_sum_probs"] = [1.0]

    # Stack each feature along a new leading "template" dimension (size 1)
    # and cast to the dtype expected for that feature.
    template_features = {
        name: np.stack([features[name]], axis=0).astype(TEMPLATE_FEATURES[name])
        for name in TEMPLATE_FEATURES
    }

    return TemplateSearchResult(
        features=template_features, errors=None, warnings=warnings
    )
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TemplateSearchResult: class TemplateSearchResult:
features: Mapping[str, Any] features: Mapping[str, Any]
...@@ -1041,6 +1107,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer): ...@@ -1041,6 +1107,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
filtered = list( filtered = list(
sorted(filtered, key=lambda x: x.sum_probs, reverse=True) sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
) )
idx = list(range(len(filtered))) idx = list(range(len(filtered)))
if(self._shuffle_top_k_prefiltered): if(self._shuffle_top_k_prefiltered):
stk = self._shuffle_top_k_prefiltered stk = self._shuffle_top_k_prefiltered
......
import os
import glob
import importlib as importlib

# Auto-discover every sibling Python module in this package directory.
_module_paths = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))

# Public API: the bare module names (extension stripped), excluding this
# __init__ file itself.
__all__ = [
    os.path.basename(_path)[:-3]
    for _path in _module_paths
    if os.path.isfile(_path) and not _path.endswith("__init__.py")
]

# Import each discovered module and expose it as an attribute of the package.
for _name in __all__:
    globals()[_name] = importlib.import_module("." + _name, __name__)

# Avoid needlessly cluttering the global namespace
del _module_paths, _name
...@@ -17,7 +17,7 @@ from functools import partial ...@@ -17,7 +17,7 @@ from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Tuple from typing import Tuple, Optional
from openfold.utils import all_atom_multimer from openfold.utils import all_atom_multimer
from openfold.utils.feats import ( from openfold.utils.feats import (
...@@ -32,7 +32,7 @@ from openfold.model.template import ( ...@@ -32,7 +32,7 @@ from openfold.model.template import (
TemplatePointwiseAttention, TemplatePointwiseAttention,
) )
from openfold.utils import geometry from openfold.utils import geometry
from openfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap from openfold.utils.tensor_utils import add, one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module): class InputEmbedder(nn.Module):
...@@ -95,11 +95,22 @@ class InputEmbedder(nn.Module): ...@@ -95,11 +95,22 @@ class InputEmbedder(nn.Module):
d = ri[..., None] - ri[..., None, :] d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange( boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
) )
oh = one_hot(d, boundaries).type(ri.dtype) reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
return self.linear_relpos(oh) d = d[..., None] - reshaped_bins
d = torch.abs(d)
d = torch.argmin(d, dim=-1)
d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
d = d.to(ri.dtype)
return self.linear_relpos(d)
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]: def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
batch: Dict containing batch: Dict containing
...@@ -116,17 +127,20 @@ class InputEmbedder(nn.Module): ...@@ -116,17 +127,20 @@ class InputEmbedder(nn.Module):
[*, N_res, N_res, C_z] pair embedding [*, N_res, N_res, C_z] pair embedding
""" """
tf = batch["target_feat"]
ri = batch["residue_index"]
msa = batch["msa_feat"]
# [*, N_res, c_z] # [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf) tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf) tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z] # [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype)) pair_emb = add(pair_emb,
tf_emb_i[..., None, :],
inplace=inplace_safe
)
pair_emb = add(pair_emb,
tf_emb_j[..., None, :, :],
inplace=inplace_safe
)
# [*, N_clust, N_res, c_m] # [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3] n_clust = msa.shape[-3]
...@@ -302,7 +316,6 @@ class RecyclingEmbedder(nn.Module): ...@@ -302,7 +316,6 @@ class RecyclingEmbedder(nn.Module):
Implements Algorithm 32. Implements Algorithm 32.
""" """
def __init__( def __init__(
self, self,
c_m: int, c_m: int,
...@@ -344,6 +357,7 @@ class RecyclingEmbedder(nn.Module): ...@@ -344,6 +357,7 @@ class RecyclingEmbedder(nn.Module):
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
...@@ -359,6 +373,19 @@ class RecyclingEmbedder(nn.Module): ...@@ -359,6 +373,19 @@ class RecyclingEmbedder(nn.Module):
z: z:
[*, N_res, N_res, C_z] pair embedding update [*, N_res, N_res, C_z] pair embedding update
""" """
# [*, N, C_m]
m_update = self.layer_norm_m(m)
if(inplace_safe):
m.copy_(m_update)
m_update = m
# [*, N, N, C_z]
z_update = self.layer_norm_z(z)
if(inplace_safe):
z.copy_(z_update)
z_update = z
# This squared method might become problematic in FP16 mode.
bins = torch.linspace( bins = torch.linspace(
self.min_bin, self.min_bin,
self.max_bin, self.max_bin,
...@@ -367,13 +394,6 @@ class RecyclingEmbedder(nn.Module): ...@@ -367,13 +394,6 @@ class RecyclingEmbedder(nn.Module):
device=x.device, device=x.device,
requires_grad=False, requires_grad=False,
) )
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = bins ** 2 squared_bins = bins ** 2
upper = torch.cat( upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1 [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
...@@ -387,7 +407,7 @@ class RecyclingEmbedder(nn.Module): ...@@ -387,7 +407,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z] # [*, N, N, C_z]
d = self.linear(d) d = self.linear(d)
z_update = d + self.layer_norm_z(z) z_update = add(z_update, d, inplace_safe)
return m_update, z_update return m_update, z_update
...@@ -485,7 +505,6 @@ class ExtraMSAEmbedder(nn.Module): ...@@ -485,7 +505,6 @@ class ExtraMSAEmbedder(nn.Module):
Implements Algorithm 2, line 15 Implements Algorithm 2, line 15
""" """
def __init__( def __init__(
self, self,
c_in: int, c_in: int,
...@@ -544,30 +563,31 @@ class TemplateEmbedder(nn.Module): ...@@ -544,30 +563,31 @@ class TemplateEmbedder(nn.Module):
pair_mask, pair_mask,
templ_dim, templ_dim,
chunk_size, chunk_size,
_mask_trans=True _mask_trans=True,
use_lma=False,
inplace_safe=False
): ):
# Embed the templates one at a time (with a poor man's vmap) # Embed the templates one at a time (with a poor man's vmap)
template_embeds = [] pair_embeds = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim] n_templ = batch["template_aatype"].shape[templ_dim]
if (inplace_safe):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.config.template_pair_embedder.c_t)
)
for i in range(n_templ): for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i) idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map( single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx), lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch, batch,
) )
single_template_embeds = {} # [*, N, N, C_t]
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat( t = build_template_pair_feat(
single_template_feats, single_template_feats,
use_unit_vector=self.config.use_unit_vector, use_unit_vector=self.config.use_unit_vector,
...@@ -577,38 +597,64 @@ class TemplateEmbedder(nn.Module): ...@@ -577,38 +597,64 @@ class TemplateEmbedder(nn.Module):
).to(z.dtype) ).to(z.dtype)
t = self.template_pair_embedder(t) t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t}) if (inplace_safe):
t_pair[..., i, :, :, :] = t
else:
pair_embeds.append(t)
template_embeds.append(single_template_embeds) del t
template_embeds = dict_multimap( if (not inplace_safe):
partial(torch.cat, dim=templ_dim), t_pair = torch.stack(pair_embeds, dim=templ_dim)
template_embeds,
) del pair_embeds
# [*, S_t, N, N, C_z] # [*, S_t, N, N, C_z]
t = self.template_pair_stack( t = self.template_pair_stack(
template_embeds["pair"], t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size, chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
) )
del t_pair
# [*, N, N, C_z] # [*, N, N, C_z]
t = self.template_pointwise_att( t = self.template_pointwise_att(
t, t,
z, z,
template_mask=batch["template_mask"].to(dtype=z.dtype), template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size, use_lma=use_lma,
)
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
) )
t = t * (torch.sum(batch["template_mask"]) > 0)
if (inplace_safe):
t *= t_mask
else:
t = t * t_mask
ret = {} ret = {}
if self.config.embed_angles:
ret["template_single_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t}) ret.update({"template_pair_embedding": t})
del t
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
ret["template_single_embedding"] = a
return ret return ret
...@@ -751,6 +797,8 @@ class TemplateEmbedderMultimer(nn.Module): ...@@ -751,6 +797,8 @@ class TemplateEmbedderMultimer(nn.Module):
templ_dim, templ_dim,
chunk_size, chunk_size,
multichain_mask_2d, multichain_mask_2d,
use_lma=False,
inplace_safe=False
): ):
template_embeds = [] template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim] n_templ = batch["template_aatype"].shape[templ_dim]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment