Commit 4f53624d authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add tracing, TensorFloat32 utilization, and FlashAttention

parent db43f4ec
...@@ -23,7 +23,11 @@ def enforce_config_constraints(config): ...@@ -23,7 +23,11 @@ def enforce_config_constraints(config):
( (
"model.template.average_templates", "model.template.average_templates",
"model.template.offload_templates" "model.template.offload_templates"
) ),
(
"globals.use_lma",
"globals.use_flash",
),
] ]
for s1, s2 in mutually_exclusive_bools: for s1, s2 in mutually_exclusive_bools:
...@@ -315,7 +319,12 @@ config = mlc.ConfigDict( ...@@ -315,7 +319,12 @@ config = mlc.ConfigDict(
"globals": { "globals": {
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size, "chunk_size": chunk_size,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
# exclusive with use_flash.
"use_lma": False, "use_lma": False,
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma.
"use_flash": True,
"offload_inference": False, "offload_inference": False,
"c_z": c_z, "c_z": c_z,
"c_m": c_m, "c_m": c_m,
......
...@@ -361,6 +361,7 @@ class EvoformerBlock(nn.Module): ...@@ -361,6 +361,7 @@ class EvoformerBlock(nn.Module):
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None, _attn_chunk_size: Optional[int] = None,
...@@ -390,12 +391,14 @@ class EvoformerBlock(nn.Module): ...@@ -390,12 +391,14 @@ class EvoformerBlock(nn.Module):
), ),
inplace=inplace_safe, inplace=inplace_safe,
) )
m = add(m, m = add(m,
self.msa_att_col( self.msa_att_col(
m, m,
mask=msa_mask, mask=msa_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash,
), ),
inplace=inplace_safe, inplace=inplace_safe,
) )
...@@ -666,6 +669,7 @@ class EvoformerStack(nn.Module): ...@@ -666,6 +669,7 @@ class EvoformerStack(nn.Module):
z: torch.Tensor, z: torch.Tensor,
chunk_size: int, chunk_size: int,
use_lma: bool, use_lma: bool,
use_flash: bool,
msa_mask: Optional[torch.Tensor], msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor], pair_mask: Optional[torch.Tensor],
inplace_safe: bool, inplace_safe: bool,
...@@ -678,6 +682,7 @@ class EvoformerStack(nn.Module): ...@@ -678,6 +682,7 @@ class EvoformerStack(nn.Module):
pair_mask=pair_mask, pair_mask=pair_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
) )
...@@ -756,6 +761,7 @@ class EvoformerStack(nn.Module): ...@@ -756,6 +761,7 @@ class EvoformerStack(nn.Module):
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: int, chunk_size: int,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
...@@ -773,6 +779,9 @@ class EvoformerStack(nn.Module): ...@@ -773,6 +779,9 @@ class EvoformerStack(nn.Module):
Inference-time subbatch size. Acts as a minimum if Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference use_lma: Whether to use low-memory attention during inference
use_flash:
Whether to use FlashAttention where possible. Mutually
exclusive with use_lma.
Returns: Returns:
m: m:
[*, N_seq, N_res, C_m] MSA embedding [*, N_seq, N_res, C_m] MSA embedding
...@@ -786,6 +795,7 @@ class EvoformerStack(nn.Module): ...@@ -786,6 +795,7 @@ class EvoformerStack(nn.Module):
z=z, z=z,
chunk_size=chunk_size, chunk_size=chunk_size,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
...@@ -947,10 +957,10 @@ class ExtraMSAStack(nn.Module): ...@@ -947,10 +957,10 @@ class ExtraMSAStack(nn.Module):
def forward(self, def forward(self,
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
chunk_size: int, chunk_size: int,
use_lma: bool = False, use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -380,6 +380,7 @@ class AlphaFold(nn.Module): ...@@ -380,6 +380,7 @@ class AlphaFold(nn.Module):
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype), pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
...@@ -392,6 +393,7 @@ class AlphaFold(nn.Module): ...@@ -392,6 +393,7 @@ class AlphaFold(nn.Module):
pair_mask=pair_mask.to(dtype=z.dtype), pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
......
...@@ -18,6 +18,9 @@ from typing import Optional, Callable, List, Tuple, Sequence ...@@ -18,6 +18,9 @@ from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np import numpy as np
import deepspeed import deepspeed
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
import torch import torch
import torch.nn as nn import torch.nn as nn
from scipy.stats import truncnorm from scipy.stats import truncnorm
...@@ -407,8 +410,10 @@ class Attention(nn.Module): ...@@ -407,8 +410,10 @@ class Attention(nn.Module):
biases: Optional[List[torch.Tensor]] = None, biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_lma: bool = False, use_lma: bool = False,
q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE, lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE, lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -427,25 +432,34 @@ class Attention(nn.Module): ...@@ -427,25 +432,34 @@ class Attention(nn.Module):
Whether to use low-memory attention (Staats & Rabe 2021). If Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead implementation is used instead
q_chunk_size: lma_q_chunk_size:
Query chunk size (for LMA) Query chunk size (for LMA)
kv_chunk_size: lma_kv_chunk_size:
Key/Value chunk size (for LMA) Key/Value chunk size (for LMA)
Returns Returns
[*, Q, C_q] attention update [*, Q, C_q] attention update
""" """
if(biases is None):
biases = []
if(use_lma and (q_chunk_size is None or kv_chunk_size is None)): if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
raise ValueError( raise ValueError(
"If use_lma is specified, q_chunk_size and kv_chunk_size must " "If use_lma is specified, q_chunk_size and kv_chunk_size must "
"be provided" "be provided"
) )
if(use_memory_efficient_kernel and use_lma):
if(use_flash and biases is not None):
raise ValueError( raise ValueError(
"Choose one of use_memory_efficient_kernel and use_lma" "use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead"
) )
attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
if(sum(attn_options) > 1):
raise ValueError(
"Choose at most one alternative attention algorithm"
)
if(biases is None):
biases = []
# [*, H, Q/K, C_hidden] # [*, H, Q/K, C_hidden]
q, k, v = self._prep_qkv(q_x, kv_x) q, k, v = self._prep_qkv(q_x, kv_x)
...@@ -463,8 +477,10 @@ class Attention(nn.Module): ...@@ -463,8 +477,10 @@ class Attention(nn.Module):
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases for b in biases
] ]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size) o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
o = o.transpose(-2, -3) o = o.transpose(-2, -3)
elif(use_flash):
o = _flash_attn(q, k, v, flash_mask)
else: else:
o = _attention(q, k, v, biases) o = _attention(q, k, v, biases)
o = o.transpose(-2, -3) o = o.transpose(-2, -3)
...@@ -623,3 +639,64 @@ def _lma( ...@@ -623,3 +639,64 @@ def _lma(
o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out
return o return o
@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
    """Attention via the FlashAttention kernel (kv-packed, unpadded variant).

    Args:
        q: [*, H, Q, C] query tensor. Assumed to be pre-scaled by the
            caller (softmax_scale is 1. below) -- confirm against the
            calling Attention module.
        k: [*, H, K, C] key tensor
        v: [*, H, K, C] value tensor
        kv_mask: mask over the key/value sequence dimension, used by
            unpad_input to drop padded positions. NOTE(review): presumably
            [*, K] with 1 for valid positions -- verify against callers.
    Returns:
        [*, Q, H, C] attention output, cast back to the input dtype.
        NOTE(review): unlike the other attention paths, the head dim is
        NOT transposed back here -- the caller must account for that.
    """
    batch_dims = q.shape[:-3]
    no_heads, n, c = q.shape[-3:]
    dtype = q.dtype

    # FlashAttention kernels require half precision
    q = q.half()
    k = k.half()
    v = v.half()
    kv_mask = kv_mask.half()

    # [*, B, N, H, C]
    q = q.transpose(-2, -3)
    k = k.transpose(-2, -3)
    v = v.transpose(-2, -3)

    # [B_flat, N, H, C]
    q = q.reshape(-1, *q.shape[-3:])
    k = k.reshape(-1, *k.shape[-3:])
    v = v.reshape(-1, *v.shape[-3:])

    # Flattened batch size
    batch_size = q.shape[0]

    # [B_flat * N, H, C]
    q = q.reshape(-1, *q.shape[-2:])

    # All query sequences are full-length, so the cumulative sequence
    # lengths are just multiples of n
    q_max_s = n
    q_cu_seqlens = torch.arange(
        0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device
    )

    # [B_flat, N, 2, H, C]
    kv = torch.stack([k, v], dim=-3)
    kv_shape = kv.shape

    # [B_flat, N, 2 * H * C] -- unpad_input expects a 3-D [B, N, D] input
    kv = kv.reshape(*kv.shape[:-3], -1)

    kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
    # Restore the packed (2, H, C) trailing dims expected by the kernel
    kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])

    out = flash_attn_unpadded_kvpacked_func(
        q,
        kv_unpad,
        q_cu_seqlens,
        kv_cu_seqlens,
        q_max_s,
        kv_max_s,
        dropout_p = 0.,
        softmax_scale = 1., # q has been scaled already
    )

    # [*, B, N, H, C]
    out = out.reshape(*batch_dims, n, no_heads, c)

    # Restore the caller's dtype (inputs were downcast to half above)
    out = out.to(dtype=dtype)

    return out
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from functools import partialmethod
import numpy as np
import torch
from openfold.utils.tensor_utils import tensor_tree_map
def pad_feature_dict_seq(feature_dict, seqlen):
    """ Pads the sequence length of a feature dict. Used for tracing. """
    # Which axis of each feature indexes residues. Features absent from
    # this mapping have no residue dimension and are passed through as-is.
    seq_dim_by_feat = {
        "aatype": -2,
        "between_segment_residues": -1,
        "residue_index": -1,
        "seq_length": -1,
        "deletion_matrix_int": -1,
        "msa": -1,
        "num_alignments": -1,
        "template_aatype": -2,
        "template_all_atom_mask": -2,
        "template_all_atom_positions": -3,
    }

    # The real sequence length can't be longer than the desired one
    true_n = feature_dict["aatype"].shape[-2]
    assert(true_n <= seqlen)

    def zero_pad(name, value):
        # Zero-extend the residue axis of a single feature to seqlen
        axis = seq_dim_by_feat.get(name)
        if(axis is None):
            return value
        target_shape = list(value.shape)
        target_shape[axis] = seqlen
        padded = np.zeros(target_shape, dtype=value.dtype)
        padded[tuple(slice(0, d) for d in value.shape)] = value
        return padded

    padded_dict = {
        name: zero_pad(name, value) for name, value in feature_dict.items()
    }

    # Report the padded length so downstream shapes stay consistent
    padded_dict["seq_length"][0] = seqlen

    return padded_dict
def trace_model_(model, sample_input):
    """In-place conversion of a model's Evoformer blocks to TorchScript.

    Runs dry passes to let the chunk-size tuners settle, traces each
    Evoformer block with torch.jit.trace, and swaps the traced blocks
    into the model. Mutates ``model`` in place; returns nothing.

    Args:
        model: an AlphaFold-style module exposing ``evoformer``,
            ``extra_msa_stack``, ``template_pair_stack``, ``globals``,
            ``config`` and ``template_config`` attributes
        sample_input: a processed feature dict whose tensors carry a
            trailing recycling dimension (indexed with ``[..., i]``).
            NOTE(review): tracing bakes in shapes, so this presumably
            requires fixed_size mode -- confirm with the caller.
    """
    # Grab the inputs to the final recycling iteration
    feats = tensor_tree_map(lambda t: t[..., -1], sample_input)

    # Gather some metadata
    n = feats["aatype"].shape[-1]
    msa_depth = feats["true_msa"].shape[-2]
    extra_msa_depth = feats["extra_msa"].shape[-2]
    no_templates = feats["template_aatype"].shape[-2]
    device = feats["aatype"].device

    seq_mask = feats["seq_mask"].to(device)
    # [*, N, N] pairwise mask as the outer product of the sequence mask
    pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
    msa_mask = feats["msa_mask"].to(device)
    extra_msa_mask = feats["extra_msa_mask"].to(device)
    template_pair_mask = torch.stack([pair_mask] * no_templates, dim=-3)

    # Create some fake representations with the correct shapes
    m = torch.rand(msa_depth, n, model.globals.c_m).to(device)
    z = torch.rand(n, n, model.globals.c_z).to(device)
    t = torch.rand(no_templates, n, n, model.globals.c_t).to(device)
    a = torch.rand(extra_msa_depth, n, model.globals.c_e).to(device)

    # We need to do a dry run through the model so the chunk size tuners'
    # trial runs (which run during the first-ever model iteration) aren't
    # baked into the trace. There's no need to run the entire thing,
    # though; we just need to run one block from each transformer stack.
    evoformer_blocks = model.evoformer.blocks
    model.evoformer.blocks = evoformer_blocks[:1]
    extra_msa_blocks = model.extra_msa_stack.blocks
    model.extra_msa_stack.blocks = extra_msa_blocks[:1]
    if(model.template_config.enabled):
        template_pair_stack_blocks = model.template_pair_stack.blocks
        model.template_pair_stack.blocks = template_pair_stack_blocks[:1]

    # One recycling iteration is enough for the tuners to run
    single_recycling_iter_input = tensor_tree_map(
        lambda t: t[..., :1], sample_input,
    )

    with torch.no_grad():
        _ = model(single_recycling_iter_input)

    # Restore the full stacks now that the tuners have cached their results
    model.evoformer.blocks = evoformer_blocks
    model.extra_msa_stack.blocks = extra_msa_blocks
    del evoformer_blocks, extra_msa_blocks
    if(model.template_config.enabled):
        model.template_pair_stack.blocks = template_pair_stack_blocks
        del template_pair_stack_blocks

    def get_tuned_chunk_size(module):
        # Read the chunk size cached by the module's tuner during the dry run
        tuner = module.chunk_size_tuner
        chunk_size = tuner.cached_chunk_size
        # After our trial run above, this should always be set
        assert(chunk_size is not None)
        return chunk_size

    # Fetch the resulting chunk sizes (fall back to the global default when
    # a stack has no tuner)
    evoformer_chunk_size = model.globals.chunk_size
    if(model.evoformer.chunk_size_tuner is not None):
        evoformer_chunk_size = get_tuned_chunk_size(model.evoformer)
    extra_msa_chunk_size = model.globals.chunk_size
    if(model.extra_msa_stack.chunk_size_tuner is not None):
        extra_msa_chunk_size = get_tuned_chunk_size(model.extra_msa_stack)
    if(model.template_config.enabled):
        template_pair_stack_chunk_size = model.globals.chunk_size
        if(model.template_pair_stack.chunk_size_tuner is not None):
            template_pair_stack_chunk_size = get_tuned_chunk_size(
                model.template_pair_stack
            )

    def trace_block(block, block_inputs):
        # Trace one block and wrap it so callers can keep passing
        # non-tensor arguments
        # Yes, yes, I know
        with contextlib.redirect_stderr(None):
            traced_block = torch.jit.trace(block, block_inputs)
            traced_block = torch.jit.optimize_for_inference(traced_block)

        # All trace inputs need to be tensors. This wrapper takes care of that
        def traced_block_wrapper(*args, **kwargs):
            to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t
            args = [to_tensor(a) for a in args]
            kwargs = {k: to_tensor(v) for k, v in kwargs.items()}
            return traced_block(*args, **kwargs)

        return traced_block_wrapper

    def verify_arg_order(fn, arg_list):
        """ Because it's difficult to specify keyword arguments of Module
            functions during tracing, we need to pass them as a tuple. As a
            sanity check, we manually verify their order here.
        """
        fn_arg_names = fn.__code__.co_varnames
        # Remove the "self" parameter
        assert(fn_arg_names[0] == "self")
        fn_arg_names = fn_arg_names[1:]
        # Trim unspecified arguments
        fn_arg_names = fn_arg_names[:len(arg_list)]
        name_tups = zip(fn_arg_names, [n for n, _ in arg_list])
        assert(all([n1 == n2 for n1, n2 in name_tups]))

    # Attention is chunked more finely than the rest of the block;
    # the // 4 ratio mirrors the stacks' internal _attn_chunk_size choice
    evoformer_attn_chunk_size = max(
        model.globals.chunk_size, evoformer_chunk_size // 4
    )

    # Positional (name, value) pairs matching EvoformerBlock.forward;
    # non-tensor arguments are wrapped in tensors for tracing
    evoformer_arg_tuples = [
        ("m", m),
        ("z", z),
        ("msa_mask", msa_mask),
        ("pair_mask", pair_mask),
        ("chunk_size", torch.tensor(evoformer_chunk_size)),
        ("use_lma", torch.tensor(model.globals.use_lma)),
        ("use_flash", torch.tensor(model.globals.use_flash)),
        ("inplace_safe", torch.tensor(1)),
        ("_mask_trans", torch.tensor(model.config._mask_trans)),
        ("_attn_chunk_size", torch.tensor(evoformer_attn_chunk_size)),
    ]

    verify_arg_order(model.evoformer.blocks[0].forward, evoformer_arg_tuples)

    evoformer_args = [arg for _, arg in evoformer_arg_tuples]
    with torch.no_grad():
        traced_evoformer_stack = []
        for b in model.evoformer.blocks:
            traced_block = trace_block(b, evoformer_args)
            traced_evoformer_stack.append(traced_block)

    del model.evoformer.blocks
    model.evoformer.blocks = traced_evoformer_stack

    # NOTE: tracing of the extra-MSA and template-pair stacks is
    # intentionally disabled below; only the Evoformer stack is traced.
    # extra_msa_attn_chunk_size = max(
    #     model.globals.chunk_size, extra_msa_chunk_size // 4
    # )

    # extra_msa_arg_tuples = [
    #     ("m", a),
    #     ("z", z),
    #     ("msa_mask", extra_msa_mask),
    #     ("pair_mask", pair_mask),
    #     ("chunk_size", torch.tensor(extra_msa_chunk_size)),
    #     ("use_lma", torch.tensor(model.globals.use_lma)),
    #     ("inplace_safe", torch.tensor(1)),
    #     ("_mask_trans", torch.tensor(model.config._mask_trans)),
    #     ("_attn_chunk_size", torch.tensor(extra_msa_attn_chunk_size)),
    # ]

    # verify_arg_order(
    #     model.extra_msa_stack.blocks[0].forward, extra_msa_arg_tuples
    # )

    # extra_msa_args = [arg for _, arg in extra_msa_arg_tuples]
    # with torch.no_grad():
    #     traced_extra_msa_stack = []
    #     for b in model.extra_msa_stack.blocks:
    #         traced_block = trace_block(b, extra_msa_args)
    #         traced_extra_msa_stack.append(traced_block)
    #
    # del model.extra_msa_stack.blocks
    # model.extra_msa_stack.blocks = traced_extra_msa_stack

    # if(model.template_config.enabled):
    #     template_pair_stack_attn_chunk_size = max(
    #         model.globals.chunk_size, template_pair_stack_chunk_size // 4
    #     )
    #     template_pair_stack_arg_tuples = [
    #         ("z", t),
    #         ("mask", template_pair_mask),
    #         ("chunk_size", torch.tensor(template_pair_stack_chunk_size)),
    #         ("use_lma", torch.tensor(model.globals.use_lma)),
    #         ("inplace_safe", torch.tensor(1)),
    #         ("_mask_trans", torch.tensor(model.config._mask_trans)),
    #         ("_attn_chunk_size", torch.tensor(
    #             template_pair_stack_attn_chunk_size
    #         )),
    #     ]
    #     verify_arg_order(
    #         model.template_pair_stack.blocks[0].forward,
    #         template_pair_stack_arg_tuples
    #     )
    #     template_pair_stack_args = [
    #         arg for _, arg in template_pair_stack_arg_tuples
    #     ]
    #
    #     with torch.no_grad():
    #         traced_template_pair_stack = []
    #         for b in model.template_pair_stack.blocks:
    #             traced_block = trace_block(b, template_pair_stack_args)
    #             traced_template_pair_stack.append(traced_block)
    #
    #     del model.template_pair_stack.blocks
    #     model.template_pair_stack.blocks = traced_template_pair_stack

    # We need to do another dry run after tracing to allow the model to reach
    # top speeds. Why, I don't know.
    two_recycling_iter_input = tensor_tree_map(
        lambda t: t[..., :2], sample_input,
    )

    with torch.no_grad():
        _ = model(two_recycling_iter_input)
...@@ -12,14 +12,17 @@ ...@@ -12,14 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
from copy import deepcopy
from datetime import date from datetime import date
import gc
import logging import logging
import math
import numpy as np import numpy as np
import os import os
from copy import deepcopy
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
import pickle import pickle
from pytorch_lightning.utilities.deepspeed import ( from pytorch_lightning.utilities.deepspeed import (
...@@ -31,7 +34,19 @@ import time ...@@ -31,7 +34,19 @@ import time
import torch import torch
import re import re
from openfold.config import model_config torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if(
torch_major_version > 1 or
(torch_major_version == 1 and torch_minor_version >= 12)
):
# Gives a large speedup on Ampere-class GPUs
torch.set_float32_matmul_precision("high")
torch.set_grad_enabled(False)
from openfold.config import model_config, NUM_RES
from openfold.data import templates, feature_pipeline, data_pipeline from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_ from openfold.model.torchscript import script_preset_
...@@ -43,13 +58,14 @@ from openfold.utils.import_weights import ( ...@@ -43,13 +58,14 @@ from openfold.utils.import_weights import (
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tensor_tree_map, tensor_tree_map,
) )
from openfold.utils.trace_utils import (
pad_feature_dict_seq,
trace_model_,
)
from scripts.utils import add_data_args from scripts.utils import add_data_args
logging.basicConfig() TRACING_INTERVAL = 50
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
def precompute_alignments(tags, seqs, alignment_dir, args): def precompute_alignments(tags, seqs, alignment_dir, args):
...@@ -59,10 +75,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -59,10 +75,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
fp.write(f">{tag}\n{seq}") fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join(alignment_dir, tag) local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None): if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)):
logger.info(f"Generating alignments for {tag}...") logger.info(f"Generating alignments for {tag}...")
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir) os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner( alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path, jackhmmer_binary_path=args.jackhmmer_binary_path,
...@@ -78,18 +94,21 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -78,18 +94,21 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
alignment_runner.run( alignment_runner.run(
tmp_fasta_path, local_alignment_dir tmp_fasta_path, local_alignment_dir
) )
else:
logger.info(
f"Using precomputed alignments for {tag} at {alignment_dir}..."
)
# Remove temporary FASTA file # Remove temporary FASTA file
os.remove(tmp_fasta_path) os.remove(tmp_fasta_path)
def round_up_seqlen(seqlen):
    # Round seqlen up to the next multiple of TRACING_INTERVAL, so traced
    # models are reused across targets of similar length instead of being
    # re-traced for every distinct sequence length
    return int(math.ceil(seqlen / TRACING_INTERVAL)) * TRACING_INTERVAL
def run_model(model, batch, tag, args): def run_model(model, batch, tag, args):
with torch.no_grad(): with torch.no_grad():
batch = {
k:torch.as_tensor(v, device=args.model_device)
for k,v in batch.items()
}
# Disable templates if there aren't any in the batch # Disable templates if there aren't any in the batch
model.config.template.enabled = model.config.template.enabled and any([ model.config.template.enabled = model.config.template.enabled and any([
"template_" in k for k in batch "template_" in k for k in batch
...@@ -99,7 +118,7 @@ def run_model(model, batch, tag, args): ...@@ -99,7 +118,7 @@ def run_model(model, batch, tag, args):
t = time.perf_counter() t = time.perf_counter()
out = model(batch) out = model(batch)
logger.info(f"Inference time: {time.perf_counter() - t}") logger.info(f"Inference time: {time.perf_counter() - t}")
return out return out
...@@ -208,6 +227,7 @@ def generate_feature_dict( ...@@ -208,6 +227,7 @@ def generate_feature_dict(
return feature_dict return feature_dict
def get_model_basename(model_path): def get_model_basename(model_path):
return os.path.splitext( return os.path.splitext(
os.path.basename( os.path.basename(
...@@ -215,6 +235,7 @@ def get_model_basename(model_path): ...@@ -215,6 +235,7 @@ def get_model_basename(model_path):
) )
)[0] )[0]
def make_output_directory(output_dir, model_name, multiple_model_mode): def make_output_directory(output_dir, model_name, multiple_model_mode):
if multiple_model_mode: if multiple_model_mode:
prediction_dir = os.path.join(output_dir, "predictions", model_name) prediction_dir = os.path.join(output_dir, "predictions", model_name)
...@@ -223,6 +244,7 @@ def make_output_directory(output_dir, model_name, multiple_model_mode): ...@@ -223,6 +244,7 @@ def make_output_directory(output_dir, model_name, multiple_model_mode):
os.makedirs(prediction_dir, exist_ok=True) os.makedirs(prediction_dir, exist_ok=True)
return prediction_dir return prediction_dir
def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path): def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
model_count = 0 model_count = 0
if openfold_checkpoint_path: if openfold_checkpoint_path:
...@@ -231,6 +253,7 @@ def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path): ...@@ -231,6 +253,7 @@ def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
model_count += len(jax_param_path.split(",")) model_count += len(jax_param_path.split(","))
return model_count return model_count
def load_models_from_command_line(args, config): def load_models_from_command_line(args, config):
# Create the output directory # Create the output directory
...@@ -295,14 +318,23 @@ def load_models_from_command_line(args, config): ...@@ -295,14 +318,23 @@ def load_models_from_command_line(args, config):
"be specified." "be specified."
) )
def list_files_with_extensions(dir, extensions): def list_files_with_extensions(dir, extensions):
return [f for f in os.listdir(dir) if f.endswith(extensions)] return [f for f in os.listdir(dir) if f.endswith(extensions)]
def main(args): def main(args):
# Create the output directory # Create the output directory
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
config = model_config(args.config_preset) config = model_config(args.config_preset)
if(args.trace_model):
if(not config.data.predict.fixed_size):
raise ValueError(
"Tracing requires that fixed_size mode be enabled in the config"
)
template_featurizer = templates.TemplateHitFeaturizer( template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir, mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date, max_template_date=args.max_template_date,
...@@ -319,7 +351,11 @@ def main(args): ...@@ -319,7 +351,11 @@ def main(args):
output_dir_base = args.output_dir output_dir_base = args.output_dir
random_seed = args.data_random_seed random_seed = args.data_random_seed
if random_seed is None: if random_seed is None:
random_seed = random.randrange(sys.maxsize) random_seed = random.randrange(2**32)
np.random.seed(random_seed)
torch.manual_seed(random_seed + 1)
feature_processor = feature_pipeline.FeaturePipeline(config.data) feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
...@@ -327,8 +363,9 @@ def main(args): ...@@ -327,8 +363,9 @@ def main(args):
alignment_dir = os.path.join(output_dir_base, "alignments") alignment_dir = os.path.join(output_dir_base, "alignments")
else: else:
alignment_dir = args.use_precomputed_alignments alignment_dir = args.use_precomputed_alignments
logger.info(f"Using precomputed alignments at {alignment_dir}...")
tag_list = []
seq_list = []
for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")): for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
# Gather input sequences # Gather input sequences
with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp: with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
...@@ -337,45 +374,84 @@ def main(args): ...@@ -337,45 +374,84 @@ def main(args):
tags, seqs = parse_fasta(data) tags, seqs = parse_fasta(data)
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique" # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags) tag = '-'.join(tags)
output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
precompute_alignments(tags, seqs, alignment_dir, args) tag_list.append(tag)
seq_list.append(seqs)
seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
feature_dicts = {}
for model, output_directory in load_models_from_command_line(args, config):
cur_tracing_interval = 0
for tag, seqs in sorted_targets:
output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
feature_dict = generate_feature_dict( # Does nothing if the alignments have already been computed
tags, precompute_alignments(tags, seqs, alignment_dir, args)
seqs,
alignment_dir, feature_dict = feature_dicts.get(tag, None)
data_processor, if(feature_dict is None):
args, feature_dict = generate_feature_dict(
) tags,
seqs,
alignment_dir,
data_processor,
args,
)
processed_feature_dict = feature_processor.process_features( if(args.trace_model):
feature_dict, mode='predict', n = feature_dict["aatype"].shape[-2]
) rounded_seqlen = round_up_seqlen(n)
feature_dict = pad_feature_dict_seq(
feature_dict, rounded_seqlen,
)
feature_dicts[tag] = feature_dict
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
)
processed_feature_dict = {
k:torch.as_tensor(v, device=args.model_device)
for k,v in processed_feature_dict.items()
}
if(args.trace_model):
if(rounded_seqlen > cur_tracing_interval):
logger.info(
f"Tracing model at {rounded_seqlen} residues..."
)
t = time.perf_counter()
trace_model_(model, processed_feature_dict)
logger.info(
f"Tracing time: {time.perf_counter() - t}"
)
cur_tracing_interval = rounded_seqlen
for model, output_directory in load_models_from_command_line(args, config): out = run_model(model, processed_feature_dict, tag, args)
working_batch = deepcopy(processed_feature_dict)
out = run_model(model, working_batch, tag, args)
# Toss out the recycling dimensions --- we don't need them anymore # Toss out the recycling dimensions --- we don't need them anymore
working_batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), working_batch) processed_feature_dict = tensor_tree_map(
lambda x: np.array(x[..., -1].cpu()),
processed_feature_dict
)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
unrelaxed_protein = prep_output( unrelaxed_protein = prep_output(
out, working_batch, feature_dict, feature_processor, args out,
processed_feature_dict,
feature_dict,
feature_processor,
args
) )
unrelaxed_output_path = os.path.join( unrelaxed_output_path = os.path.join(
output_directory, f'{output_name}_unrelaxed.pdb' output_directory, f'{output_name}_unrelaxed.pdb'
) )
# Output already exists
if os.path.exists(unrelaxed_output_path):
continue
with open(unrelaxed_output_path, 'w') as fp: with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein)) fp.write(protein.to_pdb(unrelaxed_protein))
...@@ -481,6 +557,12 @@ if __name__ == "__main__": ...@@ -481,6 +557,12 @@ if __name__ == "__main__":
"--multimer_ri_gap", type=int, default=200, "--multimer_ri_gap", type=int, default=200,
help="""Residue index offset between multiple sequences, if provided""" help="""Residue index offset between multiple sequences, if provided"""
) )
parser.add_argument(
"--trace_model", action="store_true", default=False,
help="""Whether to convert parts of each model to TorchScript.
Significantly improves runtime at the cost of lengthy
'compilation.' Useful for large batch jobs."""
)
add_data_args(parser) add_data_args(parser)
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment