Commit dba44612 authored by Gustaf Ahdritz

Resolve merge conflicts

parents 4bd1b4d5 576174f0
```diff
-FROM nvidia/cuda:11.0-cudnn8-runtime-ubuntu18.04
+FROM nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04

-RUN apt-get update && apt-get install -y wget cuda-minimal-build-11-0 git
+RUN apt-get update && apt-get install -y wget cuda-minimal-build-10-2 git
 RUN wget -P /tmp \
     "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
     && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
@@ -21,4 +21,5 @@ COPY lib/openmm.patch /opt/openfold/lib/openmm.patch
 RUN wget -q -P /opt/openfold/openfold/resources \
     https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
 RUN patch -p0 -d /opt/conda/lib/python3.7/site-packages/ < /opt/openfold/lib/openmm.patch
-RUN python3 /opt/openfold/setup.py install
+WORKDIR /opt/openfold
+RUN python3 setup.py install
```
````diff
@@ -15,12 +15,26 @@ cases where the *Nature* paper differs from the source, we always defer to the
 latter.

 OpenFold is built to support inference with AlphaFold's original JAX weights.
-Try it out with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
+It's also faster than the official code on GPU. Try it out for yourself with
+our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).

 Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
 with [DeepSpeed](https://github.com/microsoft/deepspeed) and with either `fp16`
 or `bfloat16` half-precision.
+
+OpenFold is equipped with an implementation of low-memory attention
+([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)), which
+enables inference on extremely long chains.
+
+We've modified [FastFold](https://github.com/hpcaitech/FastFold)'s custom CUDA
+kernels to support in-place attention during inference and training. These use
+4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
+implementations, respectively.
+
+We also make available efficient scripts for generating alignments. We've
+used them to generate millions of alignments that will be released alongside
+original OpenFold weights, trained from scratch using our code (more on that soon).

 ## Installation (Linux)

 All Python dependencies are specified in `environment.yml`. For producing sequence
@@ -48,6 +62,12 @@ To deactivate it, run:
 source scripts/deactivate_conda_env.sh
 ```

+With the environment active, compile OpenFold's CUDA kernels with
+
+```bash
+python3 setup.py install
+```
+
 To install the HH-suite to `/usr/bin`, run

 ```bash
@@ -129,13 +149,6 @@ to `None` in the config.

 ### Training

-After activating the OpenFold environment with
-`source scripts/activate_conda_env.sh`, install OpenFold by running
-
-```bash
-python setup.py install
-```
-
 To train the model, you will first need to precompute protein alignments.

 You have two options. You can use the same procedure DeepMind used by running
````
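The low-memory attention announced in the README isn't itself part of this diff. For context, a minimal, self-contained sketch of the chunked-softmax idea from Rabe & Staats 2021, with hypothetical chunk sizes; OpenFold's actual `_lma` routine additionally handles bias terms, and `q` is assumed pre-scaled by `1/sqrt(d)` as in `_prep_qkv`:

```python
import torch

def lma(q, k, v, q_chunk_size=1024, kv_chunk_size=4096):
    # q, k, v: [*, H, N, C]. The [N x N] attention matrix is processed in
    # [q_chunk_size x kv_chunk_size] tiles and never fully materialized.
    out = []
    for i in range(0, q.shape[-2], q_chunk_size):
        q_chunk = q[..., i:i + q_chunk_size, :]
        maxes, weights, values = [], [], []
        for j in range(0, k.shape[-2], kv_chunk_size):
            k_chunk = k[..., j:j + kv_chunk_size, :]
            v_chunk = v[..., j:j + kv_chunk_size, :]
            a = torch.einsum("...qd,...kd->...qk", q_chunk, k_chunk)
            max_a = a.max(dim=-1, keepdim=True).values.detach()
            exp_a = torch.exp(a - max_a)  # numerically stable exponentials
            exp_v = torch.einsum("...qk,...kd->...qd", exp_a, v_chunk)
            maxes.append(max_a.squeeze(-1))
            weights.append(exp_a.sum(dim=-1))
            values.append(exp_v)
        # Rescale each chunk's partial sums by a shared global max, then combine.
        chunk_max = torch.stack(maxes, dim=-2)                   # [*, H, n, Qc]
        global_max = chunk_max.max(dim=-2, keepdim=True).values  # [*, H, 1, Qc]
        scale = torch.exp(chunk_max - global_max)
        num = (torch.stack(values, dim=-3) * scale.unsqueeze(-1)).sum(dim=-3)
        den = (torch.stack(weights, dim=-2) * scale).sum(dim=-2)
        out.append(num / den.unsqueeze(-1))
    return torch.cat(out, dim=-2)
```

Peak memory scales with the tile size rather than with the full attention matrix, which is what makes inference on extremely long chains feasible.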
```diff
 {
-    "optimizer": {
-        "type": "Adam",
-        "params": {
-            "lr": 0.001,
-            "eps": 1e-05
-        }
-    },
     "fp16": {
-        "enabled": true,
+        "enabled": false,
         "min_loss_scale": 1
     },
     "amp": {
@@ -15,7 +8,7 @@
         "opt_level": "O2"
     },
     "bfloat16": {
-        "enabled": false
+        "enabled": true
     },
     "zero_optimization": {
         "stage": 2,
```
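This config now trains in `bfloat16` rather than `fp16` (DeepSpeed treats the two as mutually exclusive) and drops the inline Adam block, presumably in favor of an optimizer constructed in the training code. A small sanity-check sketch, assuming a hypothetical `deepspeed_config.json` path:

```python
import json

# Hypothetical path; verify that at most one half-precision mode is enabled.
with open("deepspeed_config.json") as f:
    cfg = json.load(f)

fp16 = cfg.get("fp16", {}).get("enabled", False)
bf16 = cfg.get("bfloat16", {}).get("enabled", False)
assert not (fp16 and bf16), "fp16 and bfloat16 are mutually exclusive"
print(f"fp16={fp16}, bfloat16={bf16}, ZeRO stage={cfg['zero_optimization']['stage']}")
```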
```diff
@@ -6,7 +6,7 @@ channels:
 dependencies:
   - pip:
     - biopython==1.79
-    - deepspeed==0.5.3
+    - deepspeed==0.5.9
     - dm-tree==0.1.6
     - ml-collections==0.1.0
    - numpy==1.21.2
@@ -15,7 +15,7 @@ dependencies:
     - scipy==1.7.1
     - tqdm==4.62.2
     - typing-extensions==3.10.0.2
-    - pytorch_lightning==1.5.0
+    - pytorch_lightning==1.5.10
     - git+https://github.com/NVIDIA/dllogger.git
   - pytorch::pytorch=1.10.*
   - conda-forge::python=3.7
@@ -23,6 +23,8 @@ dependencies:
   - conda-forge::pip
   - conda-forge::openmm=7.5.1
   - conda-forge::pdbfixer
+  - conda-forge::cudatoolkit==10.2.*
+  - conda-forge::cudatoolkit-dev==10.*
   - bioconda::hmmer==3.3.2
   - bioconda::hhsuite==3.3.0
   - bioconda::kalign2==2.04
```
```diff
@@ -267,6 +267,7 @@ config = mlc.ConfigDict(
                 "clamp_prob": 0.9,
                 "max_distillation_msa_clusters": 1000,
                 "uniform_recycling": True,
+                "distillation_prob": 0.75,
             },
             "data_module": {
                 "use_small_bfd": False,
```
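`distillation_prob` appears to set how often a training example is drawn from the self-distillation set rather than from the PDB. A minimal sketch of that sampling rule, with illustrative names only (the real selection logic lives in the data module, not shown here):

```python
import random

def choose_dataset(distillation_prob: float = 0.75) -> str:
    # With probability distillation_prob, draw the next training example
    # from the self-distillation set; otherwise use an experimental
    # PDB structure. Dataset names are hypothetical.
    if random.random() < distillation_prob:
        return "distillation"
    return "pdb"
```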
```diff
@@ -176,8 +176,8 @@ def make_protein_features(
 def make_pdb_features(
     protein_object: protein.Protein,
     description: str,
-    confidence_threshold: float = 0.5,
     is_distillation: bool = True,
+    confidence_threshold: float = 50.,
 ) -> FeatureDict:
     pdb_feats = make_protein_features(
         protein_object, description, _is_distillation=True
@@ -186,9 +186,7 @@ def make_pdb_features(
     if(is_distillation):
         high_confidence = protein_object.b_factors > confidence_threshold
         high_confidence = np.any(high_confidence, axis=-1)
-        for i, confident in enumerate(high_confidence):
-            if(not confident):
-                pdb_feats["all_atom_mask"][i] = 0
+        pdb_feats["all_atom_mask"] *= high_confidence[..., None]

     return pdb_feats
@@ -832,13 +830,24 @@ class DataPipeline:
         alignment_dir: str,
         is_distillation: bool = True,
         chain_id: Optional[str] = None,
+        _structure_index: Optional[str] = None,
         _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
            Assembles features for a protein in a PDB file.
         """
-        with open(pdb_path, 'r') as f:
-            pdb_str = f.read()
+        if(_structure_index is not None):
+            db_dir = os.path.dirname(pdb_path)
+            db = _structure_index["db"]
+            db_path = os.path.join(db_dir, db)
+            fp = open(db_path, "rb")
+            _, offset, length = _structure_index["files"][0]
+            fp.seek(offset)
+            pdb_str = fp.read(length).decode("utf-8")
+            fp.close()
+        else:
+            with open(pdb_path, 'r') as f:
+                pdb_str = f.read()

         protein_object = protein.from_pdb_string(pdb_str, chain_id)
         input_sequence = _aatype_to_str_sequence(protein_object.aatype)
@@ -846,7 +855,7 @@ class DataPipeline:
         pdb_feats = make_pdb_features(
             protein_object,
             description,
-            is_distillation
+            is_distillation=is_distillation
         )

         hits = self._parse_template_hits(
```
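The new `_structure_index` path reads a PDB record out of a packed database file by byte offset instead of opening one file per chain. A sketch of the index layout this code implies; field names are taken from the diff, the concrete file names are hypothetical:

```python
import os

# Implied index entry for one chain: the flat-file DB it lives in, plus
# (name, offset, length) records locating each member inside that DB.
structure_index = {
    "db": "pdb_chains.db",          # hypothetical DB file name
    "files": [["1abc_A.pdb", 0, 81234]],
}

def read_indexed_pdb(pdb_path: str, structure_index: dict) -> str:
    db_path = os.path.join(os.path.dirname(pdb_path), structure_index["db"])
    _, offset, length = structure_index["files"][0]
    with open(db_path, "rb") as fp:
        fp.seek(offset)
        return fp.read(length).decode("utf-8")
```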
```diff
@@ -50,7 +50,7 @@ def cast_to_64bit_ints(protein):
 def make_one_hot(x, num_classes):
-    x_one_hot = torch.zeros(*x.shape, num_classes)
+    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
     x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
     return x_one_hot
@@ -92,9 +92,9 @@ def fix_templates_aatype(protein):
     )
     # Map hhsearch-aatype to our aatype.
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
-    new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
-        num_templates, -1
-    )
+    new_order = torch.tensor(
+        new_order_list, dtype=torch.int64, device=protein["aatype"].device,
+    ).expand(num_templates, -1)
     protein["template_aatype"] = torch.gather(
         new_order, 1, index=protein["template_aatype"]
     )
@@ -106,7 +106,8 @@ def correct_msa_restypes(protein):
     """Correct MSA restype to have the same order as rc."""
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
     new_order = torch.tensor(
-        [new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
+        [new_order_list] * protein["msa"].shape[1],
+        device=protein["msa"].device,
     ).transpose(0, 1)
     protein["msa"] = torch.gather(new_order, 0, protein["msa"])
@@ -187,7 +188,10 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
     if seed is not None:
         g.manual_seed(seed)
     shuffled = torch.randperm(num_seq - 1, generator=g) + 1
-    index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
+    index_order = torch.cat(
+        (torch.tensor([0], device=shuffled.device), shuffled),
+        dim=0
+    )
     num_sel = min(max_seq, num_seq)
     sel_seq, not_sel_seq = torch.split(
         index_order, [num_sel, num_seq - num_sel]
@@ -242,7 +246,7 @@ def delete_extra_msa(protein):
 def block_delete_msa(protein, config):
     num_seq = protein["msa"].shape[0]
     block_num_seq = torch.floor(
-        torch.tensor(num_seq, dtype=torch.float32)
+        torch.tensor(num_seq, dtype=torch.float32, device=protein["msa"].device)
         * config.msa_fraction_per_block
     ).to(torch.int32)
@@ -275,7 +279,11 @@ def block_delete_msa(protein, config):
 @curry1
 def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
     weights = torch.cat(
-        [torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
+        [
+            torch.ones(21, device=protein["msa"].device),
+            gap_agreement_weight * torch.ones(1, device=protein["msa"].device),
+            torch.zeros(1, device=protein["msa"].device)
+        ],
         0,
     )
@@ -324,7 +332,10 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
     )
     segment_ids = segment_ids.expand(data.shape)
     shape = [num_segments] + list(data.shape[1:])
-    tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float())
+    tensor = (
+        torch.zeros(*shape, device=segment_ids.device)
+        .scatter_add_(0, segment_ids, data.float())
+    )
     tensor = tensor.type(data.dtype)
     return tensor
@@ -401,7 +412,7 @@ def make_pseudo_beta(protein, prefix=""):
 @curry1
 def add_constant_field(protein, key, value):
-    protein[key] = torch.tensor(value)
+    protein[key] = torch.tensor(value, device=protein["msa"].device)
     return protein
@@ -454,6 +465,7 @@ def make_masked_msa(protein, config, replace_fraction, seed):
         1.0 - config.profile_prob - config.same_prob - config.uniform_prob
     )
     assert mask_prob >= 0.0
+
     categorical_probs = torch.nn.functional.pad(
         categorical_probs, pad_shapes, value=mask_prob,
     )
@@ -656,7 +668,11 @@ def make_atom14_masks(protein):
 def make_atom14_masks_np(batch):
-    batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
+    batch = tree_map(
+        lambda n: torch.tensor(n, device=batch["aatype"].device),
+        batch,
+        np.ndarray
+    )
     out = make_atom14_masks(batch)
     out = tensor_tree_map(lambda t: np.array(t), out)
     return out
```
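Most of the changes above add explicit `device=` arguments so that tensors created inside the transforms land on the same device as their inputs. Without them, a CUDA-resident pipeline fails as soon as a freshly created CPU tensor meets a GPU tensor, e.g.:

```python
import torch

x = torch.ones(4, dtype=torch.int64, device="cuda")

# Before: the new tensor defaults to CPU, so scatter_ raises a device mismatch.
# one_hot = torch.zeros(*x.shape, 10).scatter_(-1, x.unsqueeze(-1), 1)

# After: allocate on the input's device, mirroring make_one_hot above.
one_hot = torch.zeros(*x.shape, 10, device=x.device).scatter_(
    -1, x.unsqueeze(-1), 1
)
```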
```diff
@@ -40,10 +40,11 @@ def np_to_tensor_dict(
     Returns:
         A dictionary of features mapping feature names to features. Only the given
         features are returned, all other ones are filtered out.
     """
     tensor_dict = {
         k: torch.tensor(v) for k, v in np_example.items() if k in features
     }
+
     return tensor_dict
```
```diff
@@ -327,8 +327,6 @@ class RecyclingEmbedder(nn.Module):
         self.no_bins = no_bins
         self.inf = inf

-        self.bins = None
-
         self.linear = Linear(self.no_bins, self.c_z)
         self.layer_norm_m = LayerNorm(self.c_m)
         self.layer_norm_z = LayerNorm(self.c_z)
@@ -353,15 +351,14 @@ class RecyclingEmbedder(nn.Module):
             z:
                 [*, N_res, N_res, C_z] pair embedding update
         """
-        if self.bins is None:
-            self.bins = torch.linspace(
-                self.min_bin,
-                self.max_bin,
-                self.no_bins,
-                dtype=x.dtype,
-                device=x.device,
-                requires_grad=False,
-            )
+        bins = torch.linspace(
+            self.min_bin,
+            self.max_bin,
+            self.no_bins,
+            dtype=x.dtype,
+            device=x.device,
+            requires_grad=False,
+        )

         # [*, N, C_m]
         m_update = self.layer_norm_m(m)
@@ -369,7 +366,7 @@ class RecyclingEmbedder(nn.Module):
         # This squared method might become problematic in FP16 mode.
         # I'm using it because my homegrown method had a stubborn discrepancy I
         # couldn't find in time.
-        squared_bins = self.bins ** 2
+        squared_bins = bins ** 2
         upper = torch.cat(
             [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
         )
```
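Recomputing `bins` on every forward pass, rather than caching `self.bins`, keeps the buffer on the right device and dtype across precision and device changes. For context, a condensed sketch of how these bins turn pairwise distances into a one-hot distogram, following the squared-distance trick in the code above; the default bin values are assumptions:

```python
import torch

def recycling_dgram(x, min_bin=3.25, max_bin=20.75, no_bins=15, inf=1e8):
    # x: [*, N, 3] pseudo-beta coordinates. Compare squared distances to
    # squared bin edges so no sqrt is needed (see the FP16 caveat in the
    # diff's comments).
    bins = torch.linspace(min_bin, max_bin, no_bins, dtype=x.dtype, device=x.device)
    squared_bins = bins ** 2
    upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([inf])], dim=-1)
    d2 = torch.sum(
        (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    # [*, N, N, no_bins] one-hot over distance bins
    return ((d2 > squared_bins) * (d2 < upper)).type(x.dtype)
```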
```diff
@@ -352,20 +352,31 @@ class ExtraMSABlock(nn.Module):
         chunk_size: Optional[int] = None,
         _chunk_logits: Optional[int] = 1024,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        m = m + self.msa_dropout_layer(
+        def add(m1, m2):
+            # The first operation in a checkpoint can't be in-place, but it's
+            # nice to have in-place addition during inference. Thus...
+            if(torch.is_grad_enabled()):
+                m1 = m1 + m2
+            else:
+                m1 += m2
+            return m1
+
+        m = add(m, self.msa_dropout_layer(
             self.msa_att_row(
-                m.clone(),
-                z=z.clone(),
+                m.clone() if torch.is_grad_enabled() else m,
+                z=z.clone() if torch.is_grad_enabled() else z,
                 mask=msa_mask,
                 chunk_size=chunk_size,
+                use_memory_efficient_kernel=not _chunk_logits,
                 _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
                 _checkpoint_chunks=
                     self.ckpt if torch.is_grad_enabled() else False,
             )
-        )
+        ))

         def fn(m, z):
-            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+            m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
             m, z = self.core(
                 m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
             )
@@ -548,11 +559,14 @@ class ExtraMSAStack(nn.Module):
         eps: float,
         ckpt: bool,
         clear_cache_between_blocks: bool = False,
+        chunk_msa_attn: bool = False,
         **kwargs,
     ):
         super(ExtraMSAStack, self).__init__()
+        self.ckpt = ckpt
         self.clear_cache_between_blocks = clear_cache_between_blocks
+        self.chunk_msa_attn = chunk_msa_attn
         self.blocks = nn.ModuleList()
         for _ in range(no_blocks):
             block = ExtraMSABlock(
@@ -569,7 +583,7 @@ class ExtraMSAStack(nn.Module):
                 pair_dropout=pair_dropout,
                 inf=inf,
                 eps=eps,
-                ckpt=ckpt,
+                ckpt=ckpt if chunk_msa_attn else False,
             )
             self.blocks.append(block)
@@ -593,28 +607,36 @@ class ExtraMSAStack(nn.Module):
                 Optional [*, N_res, N_res] pair mask
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        #checkpoint_fn = get_checkpoint_fn()
-        #blocks = [
-        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
-        #]
-
-        #def dodo(b, *args):
-        #    torch.cuda.empty_cache()
-        #    return b(*args)
-
-        #blocks = [partial(dodo, b) for b in blocks]
-
-        #for b in blocks:
-        #    if(torch.is_grad_enabled()):
-        #        m, z = checkpoint_fn(b, *(m, z))
-        #    else:
-        #        m, z = b(m, z)
-
-        for b in self.blocks:
-            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
-
-            if(self.clear_cache_between_blocks):
-                torch.cuda.empty_cache()
+        if(not self.chunk_msa_attn):
+            checkpoint_fn = get_checkpoint_fn()
+            blocks = [
+                partial(
+                    b,
+                    msa_mask=msa_mask,
+                    pair_mask=pair_mask,
+                    chunk_size=chunk_size,
+                    _chunk_logits=None
+                ) for b in self.blocks
+            ]
+
+            def clear_cache(b, *args):
+                torch.cuda.empty_cache()
+                return b(*args)
+
+            if(self.clear_cache_between_blocks):
+                blocks = [partial(clear_cache, b) for b in blocks]
+
+            for b in blocks:
+                if(self.ckpt and torch.is_grad_enabled()):
+                    m, z = checkpoint_fn(b, *(m, z))
+                else:
+                    m, z = b(m, z)
+        else:
+            for b in self.blocks:
+                m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+
+                if(self.clear_cache_between_blocks):
+                    torch.cuda.empty_cache()

         return z
```
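The refactored forward pass binds each block's static arguments with `functools.partial` so that only the recomputable tensors `(m, z)` flow through the activation checkpoint. A reduced sketch of that pattern, assuming the standard `torch.utils.checkpoint` rather than OpenFold's `get_checkpoint_fn` wrapper:

```python
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint

def run_blocks(blocks, m, z, msa_mask, pair_mask, chunk_size, ckpt=True):
    # Bind everything except (m, z) so the checkpointed callable takes only
    # tensors; checkpointing recomputes each block during backward, trading
    # compute for activation memory.
    bound = [
        partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size)
        for b in blocks
    ]
    for b in bound:
        if ckpt and torch.is_grad_enabled():
            m, z = checkpoint(b, m, z)
        else:
            m, z = b(m, z)
    return m, z
```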
```diff
@@ -173,9 +173,10 @@ class AlphaFold(nn.Module):
         # altogether. We zero them this way instead of computing them
         # conditionally to avoid leaving parameters unused, which has annoying
         # implications for DDP training.
-        if(not _recycle):
-            m_1_prev_emb *= 0
-            z_prev_emb *= 0
+        # EDIT: This has since been removed from the official codebase (2cd61a)
+        # if(not _recycle):
+        #     m_1_prev_emb *= 0
+        #     z_prev_emb *= 0

         # [*, S_c, N, C_m]
         m[..., 0, :, :] += m_1_prev_emb
```
```diff
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 import math

 import torch
 import torch.nn as nn
@@ -79,20 +79,33 @@ class MSAAttention(nn.Module):
         )

         self.mha = Attention(
-            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
+            self.c_in,
+            self.c_in,
+            self.c_in,
+            self.c_hidden,
+            self.no_heads,
         )

     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
         biases: List[torch.Tensor],
+        use_memory_efficient_kernel: bool,
         chunk_size: int,
     ) -> torch.Tensor:
+        mha = partial(
+            self.mha,
+            use_memory_efficient_kernel=use_memory_efficient_kernel
+        )
         return chunk_layer(
-            self.mha,
-            {"q_x": m, "kv_x": m, "biases": biases},
+            mha,
+            {
+                "q_x": m,
+                "kv_x": m,
+                "biases": biases,
+            },
             chunk_size=chunk_size,
-            no_batch_dims=len(m.shape[:-2]),
+            no_batch_dims=len(m.shape[:-2])
         )

     def _prep_inputs(self,
@@ -113,13 +126,6 @@ class MSAAttention(nn.Module):
         # [*, N_seq, 1, 1, N_res]
         mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

-        # This step simply returns a larger view of the bias, and does not
-        # consume additional memory.
-        # [*, N_seq, no_heads, N_res, N_res]
-        #bias = bias.expand(
-        #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
-        #)
-
         if (self.pair_bias and
             z is not None and                       # For the
             self.layer_norm_z is not None and       # benefit of
@@ -144,6 +150,11 @@ class MSAAttention(nn.Module):
         chunk_logits: int,
         checkpoint: bool,
     ) -> torch.Tensor:
+        """
+        MSA attention with training-time chunking of the softmax computation.
+        Saves memory in the extra MSA stack. Probably obviated by our fused
+        attention kernel, which is now used by default.
+        """
         MSA_DIM = -4

         def _get_qkv(m, z):
@@ -181,6 +192,7 @@ class MSAAttention(nn.Module):
         z: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_memory_efficient_kernel: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
     ) -> torch.Tensor:
@@ -212,12 +224,13 @@ class MSAAttention(nn.Module):
             biases.append(z)

         if chunk_size is not None:
-            m = self._chunk(m, biases, chunk_size)
+            m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
         else:
             m = self.mha(
                 q_x=m,
                 kv_x=m,
-                biases=biases
+                biases=biases,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
             )

         return m
@@ -291,7 +304,8 @@ class MSAColumnAttention(nn.Module):
     def forward(self,
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        use_memory_efficient_kernel: bool = False,
     ) -> torch.Tensor:
         """
         Args:
```
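`_chunk` now threads the `use_memory_efficient_kernel` flag through `functools.partial` before handing the attention module to `chunk_layer`, which applies it over slices of the leading batch dimensions. A reduced, generic sketch of that chunking idea (not OpenFold's actual `chunk_layer`, which handles nested inputs and multiple batch dimensions):

```python
import torch

def chunk_over_batch(fn, q_x, kv_x, biases, chunk_size):
    # Apply fn to chunk_size-sized slices of the leading dimension and
    # concatenate the results; peak memory scales with the chunk rather
    # than with the full batch.
    outs = []
    for i in range(0, q_x.shape[0], chunk_size):
        s = slice(i, i + chunk_size)
        outs.append(fn(q_x=q_x[s], kv_x=kv_x[s], biases=[b[s] for b in biases]))
    return torch.cat(outs, dim=0)
```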
```diff
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
 import math
 from typing import Optional, Callable, List, Tuple, Sequence
@@ -24,6 +23,7 @@ import torch.nn as nn
 from scipy.stats import truncnorm

 from openfold.utils.checkpointing import get_checkpoint_fn
+from openfold.utils.kernel.attention_core import attention_core
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
@@ -199,8 +199,9 @@ class LayerNorm(nn.Module):
         return out

 @torch.jit.ignore
-def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     """
         Softmax, but without automatic casting to fp32 when the input is of
         type bfloat16
@@ -217,14 +218,8 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
 #@torch.jit.script
 def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
-    # [*, H, Q, C_hidden]
-    query = permute_final_dims(query, (1, 0, 2))
-
     # [*, H, C_hidden, K]
-    key = permute_final_dims(key, (1, 2, 0))
-
-    # [*, H, V, C_hidden]
-    value = permute_final_dims(value, (1, 0, 2))
+    key = permute_final_dims(key, (1, 0))

     # [*, H, Q, K]
     a = torch.matmul(query, key)
@@ -232,14 +227,11 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
     for b in biases:
         a += b

-    a = softmax(a, -1)
+    a = softmax_no_cast(a, -1)

     # [*, H, Q, C_hidden]
     a = torch.matmul(a, value)

-    # [*, Q, H, C_hidden]
-    a = a.transpose(-2, -3)
-
     return a
@@ -254,7 +246,8 @@ def _attention_chunked_trainable(
     def _checkpointable_attention(q, k, v, b1, b2):
         bs = [b for b in [b1, b2] if b is not None]
-        return _attention(q, k, v, bs)
+        a = _attention(q, k, v, bs)
+        return a

     o_chunks = []
     checkpoint_fn = get_checkpoint_fn()
@@ -289,7 +282,8 @@ def _attention_chunked_trainable(
         ]
         o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
+        o_chunk = o_chunk.transpose(-2, -3)

         o_chunks.append(o_chunk)

     o = torch.cat(o_chunks, dim=chunk_dim)
@@ -374,6 +368,11 @@ class Attention(nn.Module):
         k = k.view(k.shape[:-1] + (self.no_heads, -1))
         v = v.view(v.shape[:-1] + (self.no_heads, -1))

+        # [*, H, Q/K, C_hidden]
+        q = q.transpose(-2, -3)
+        k = k.transpose(-2, -3)
+        v = v.transpose(-2, -3)
+
         q /= math.sqrt(self.c_hidden)

         return q, k, v
@@ -402,6 +401,7 @@ class Attention(nn.Module):
         q_x: torch.Tensor,
         kv_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
+        use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
         q_chunk_size: Optional[int] = None,
         kv_chunk_size: Optional[int] = None,
@@ -414,8 +414,15 @@ class Attention(nn.Module):
                 [*, K, C_k] key data
             biases:
                 List of biases that broadcast to [*, H, Q, K]
+            use_memory_efficient_kernel:
+                Whether to use a custom memory-efficient attention kernel.
+                This should be the default choice for most. If none of the
+                "use_<...>" flags are True, a stock PyTorch implementation
+                is used instead
             use_lma:
-                Whether to use low-memory attention
+                Whether to use low-memory attention (Rabe & Staats 2021). If
+                none of the "use_<...>" flags are True, a stock PyTorch
+                implementation is used instead
             q_chunk_size:
                 Query chunk size (for LMA)
             kv_chunk_size:
@@ -430,18 +437,32 @@ class Attention(nn.Module):
                 "If use_lma is specified, q_chunk_size and kv_chunk_size must "
                 "be provided"
             )
+        if(use_memory_efficient_kernel and use_lma):
+            raise ValueError(
+                "Choose one of use_memory_efficient_kernel and use_lma"
+            )

+        # [*, H, Q/K, C_hidden]
         q, k, v = self._prep_qkv(q_x, kv_x)

-        if(use_lma):
+        # [*, Q, H, C_hidden]
+        if(use_memory_efficient_kernel):
+            if(len(biases) > 2):
+                raise ValueError(
+                    "If use_memory_efficient_kernel is True, you may only "
+                    "provide up to two bias terms"
+                )
+            o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
+            o = o.transpose(-2, -3)
+        elif(use_lma):
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                 for b in biases
             ]
             o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
         else:
             o = _attention(q, k, v, biases)
+            o = o.transpose(-2, -3)

         o = self._wrap_up(o, q_x)
@@ -497,7 +518,7 @@ class GlobalAttention(nn.Module):
         )
         bias = (self.inf * (mask - 1))[..., :, None, :]
         a += bias
-        a = softmax_no_cast(a)
+        a = softmax_no_cast(a)

         # [*, N_res, H, C_hidden]
         o = torch.matmul(
```
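With the new flags, callers choose between the fused in-place kernel, low-memory attention, and the stock implementation. A smoke-test sketch of the stock path; the constructor arguments follow the `Attention(c_q, c_k, c_v, c_hidden, no_heads)` pattern used by `MSAAttention` above, and the shapes are hypothetical:

```python
import torch
from openfold.model.primitives import Attention

mha = Attention(128, 128, 128, 32, 4)     # c_q, c_k, c_v, c_hidden, no_heads
x = torch.randn(2, 64, 128)               # [batch, N_res, C]
bias = torch.zeros(2, 1, 1, 64)           # broadcasts to [*, H, Q, K]

out = mha(q_x=x, kv_x=x, biases=[bias])   # stock PyTorch path
# out = mha(q_x=x, kv_x=x, biases=[bias], use_memory_efficient_kernel=True)
# would require a CUDA device and the compiled attn_core_inplace_cuda extension.
assert out.shape == (2, 64, 128)
```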
```diff
@@ -553,60 +553,3 @@ def run_pipeline(
         )

         iteration += 1

     return ret
-
-
-def get_initial_energies(
-    pdb_strs: Sequence[str],
-    stiffness: float = 0.0,
-    restraint_set: str = "non_hydrogen",
-    exclude_residues: Optional[Sequence[int]] = None,
-):
-    """Returns initial potential energies for a sequence of PDBs.
-
-    Assumes the input PDBs are ready for minimization, and all have the same
-    topology.
-    Allows time to be saved by not pdbfixing / rebuilding the system.
-
-    Args:
-        pdb_strs: List of PDB strings.
-        stiffness: kcal/mol A**2, spring constant of heavy atom restraining
-            potential.
-        restraint_set: Which atom types to restrain.
-        exclude_residues: An optional list of zero-indexed residues to exclude from
-            restraints.
-
-    Returns:
-        A list of initial energies in the same order as pdb_strs.
-    """
-    exclude_residues = exclude_residues or []
-    openmm_pdbs = [
-        openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
-    ]
-    force_field = openmm_app.ForceField("amber99sb.xml")
-    system = force_field.createSystem(
-        openmm_pdbs[0].topology, constraints=openmm_app.HBonds
-    )
-    stiffness = stiffness * ENERGY / (LENGTH ** 2)
-    if stiffness > 0 * ENERGY / (LENGTH ** 2):
-        _add_restraints(
-            system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
-        )
-    simulation = openmm_app.Simulation(
-        openmm_pdbs[0].topology,
-        system,
-        openmm.LangevinIntegrator(0, 0.01, 0.0),
-        openmm.Platform.getPlatformByName("CPU"),
-    )
-    energies = []
-    for pdb in openmm_pdbs:
-        try:
-            simulation.context.setPositions(pdb.positions)
-            state = simulation.context.getState(getEnergy=True)
-            energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
-        except Exception as e:  # pylint: disable=broad-except
-            logging.error(
-                "Error getting initial energy, returning large value %s", e
-            )
-            energies.append(unit.Quantity(1e20, ENERGY))
-    return energies
```
```diff
@@ -2,12 +2,14 @@ import os
 import glob
 import importlib as importlib

+from . import kernel
+
 _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
 __all__ = [
     os.path.basename(f)[:-3]
     for f in _files
     if os.path.isfile(f) and not f.endswith("__init__.py")
-]
+] + ["kernel"]
 _modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
 for _m in _modules:
     globals()[_m[0]] = _m[1]
```
```diff
@@ -107,7 +107,10 @@ def dgram_from_positions(
 def build_template_pair_feat(
-    batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8
+    batch,
+    min_bin, max_bin, no_bins,
+    use_unit_vector=False,
+    eps=1e-20, inf=1e8
 ):
     template_mask = batch["template_pseudo_beta_mask"]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
```
```python
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from functools import reduce
from operator import mul

import torch

attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")

SUPPORTED_DTYPES = [torch.float32, torch.bfloat16]


class AttentionCoreFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, bias_1=None, bias_2=None):
        if(bias_1 is None and bias_2 is not None):
            raise ValueError("bias_1 must be specified before bias_2")
        if(q.dtype not in SUPPORTED_DTYPES):
            raise ValueError("Unsupported datatype")

        q = q.contiguous()
        k = k.contiguous()

        # [*, H, Q, K]
        attention_logits = torch.matmul(
            q, k.transpose(-1, -2),
        )

        if(bias_1 is not None):
            attention_logits += bias_1
        if(bias_2 is not None):
            attention_logits += bias_2

        attn_core_inplace_cuda.forward_(
            attention_logits,
            reduce(mul, attention_logits.shape[:-1]),
            attention_logits.shape[-1],
        )

        o = torch.matmul(attention_logits, v)

        ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None
        ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None
        ctx.save_for_backward(q, k, v, attention_logits)

        return o

    @staticmethod
    def backward(ctx, grad_output):
        q, k, v, attention_logits = ctx.saved_tensors
        grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None

        grad_v = torch.matmul(
            attention_logits.transpose(-1, -2),
            grad_output
        )

        attn_core_inplace_cuda.backward_(
            attention_logits,
            grad_output.contiguous(),
            v.contiguous(),  # v is implicitly transposed in the kernel
            reduce(mul, attention_logits.shape[:-1]),
            attention_logits.shape[-1],
            grad_output.shape[-1],
        )

        if(ctx.bias_1_shape is not None):
            grad_bias_1 = torch.sum(
                attention_logits,
                dim=tuple(i for i, d in enumerate(ctx.bias_1_shape) if d == 1),
                keepdim=True,
            )

        if(ctx.bias_2_shape is not None):
            grad_bias_2 = torch.sum(
                attention_logits,
                dim=tuple(i for i, d in enumerate(ctx.bias_2_shape) if d == 1),
                keepdim=True,
            )

        grad_q = torch.matmul(
            attention_logits, k
        )
        grad_k = torch.matmul(
            q.transpose(-1, -2), attention_logits,
        ).transpose(-1, -2)

        return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2

attention_core = AttentionCoreFunction.apply
```
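A usage sketch for the new autograd wrapper. `attention_core` consumes head-major tensors (`[*, H, Q, C_hidden]`, matching the updated `_prep_qkv`) and up to two bias terms; it needs a CUDA device and the compiled `attn_core_inplace_cuda` extension, and the shapes below are hypothetical:

```python
import torch
from openfold.utils.kernel.attention_core import attention_core

# batch 1, 4 heads, 64 residues, 32 channels per head
q = torch.randn(1, 4, 64, 32, device="cuda", requires_grad=True)
k = torch.randn(1, 4, 64, 32, device="cuda", requires_grad=True)
v = torch.randn(1, 4, 64, 32, device="cuda", requires_grad=True)
mask_bias = torch.zeros(1, 1, 1, 64, device="cuda")  # broadcasts to [*, H, Q, K]

o = attention_core(q, k, v, mask_bias, None)  # [*, H, Q, C_hidden]
o.sum().backward()                            # exercises the in-place backward
```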
```cpp
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
```
```cpp
// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp

#include <torch/extension.h>

void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
);
void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward_",
        &attn_softmax_inplace_forward_,
        "Softmax forward (CUDA)"
    );
    m.def(
        "backward_",
        &attn_softmax_inplace_backward_,
        "Softmax backward (CUDA)"
    );
}
```
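These sources compile into the `attn_core_inplace_cuda` module imported by the autograd wrapper above. The build wiring itself isn't part of this diff; a minimal sketch of how such an extension is typically declared with `torch.utils.cpp_extension` (file paths and the `.cu` kernel name are assumptions):

```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="attn_core_inplace_cuda",
    ext_modules=[
        CUDAExtension(
            name="attn_core_inplace_cuda",
            sources=[
                # Illustrative paths; match them to the repository layout.
                "openfold/utils/kernel/csrc/softmax_cuda.cpp",
                "openfold/utils/kernel/csrc/softmax_cuda_kernel.cu",
            ],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```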